
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs import RWConfig

FalconConfig = Union[HF_FalconConfig, RWConfig]
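

# ALiBi (Press et al., 2022) replaces positional embeddings with a per-head
# linear bias on attention scores. For a power-of-two number of heads n, the
# slopes form the geometric sequence 2**(-8/n), 2**(-16/n), ...; when n is
# not a power of two, the remaining heads take slopes interpolated from the
# next power of two, as computed below.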
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)

    return slopes


class FalconAttention(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query
        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
        elif self.multi_query:
            self.total_num_kv_heads = 1
        else:
            self.total_num_kv_heads = self.total_num_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
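        # The row-parallel reduction of the attention output is deferred
        # when attention and MLP run in parallel (parallel_attn or the new
        # decoder architecture), so the decoder layer can combine both
        # branches with a single all-reduce.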
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  quant_config=quant_config)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  alibi_slopes=alibi_slopes,
                                  quant_config=quant_config)
        else:
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scale=self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias


class FalconMLP(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  quant_config=quant_config)
        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results,
            quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias


class FalconDecoderLayer(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config, cache_config,
                                              quant_config)
        self.mlp = FalconMLP(config, quant_config)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
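
    # Falcon variants differ in how the residual stream is wired:
    #  * new_decoder_architecture (e.g. Falcon-40B/180B): two layer norms feed
    #    the attention and MLP branches, which run in parallel off the same
    #    input.
    #  * parallel_attn (e.g. Falcon-7B): a single layer norm feeds both
    #    parallel branches.
    #  * otherwise: the classic serial layout with a post-attention layer norm.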
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output


class FalconModel(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class FalconForCausalLM(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = FalconModel(config, cache_config, quant_config)
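        # Falcon ties the output projection to the input embeddings, so the
        # word embedding weight doubles as the LM head.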
        self.lm_head = self.transformer.word_embeddings
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        total_num_heads = self.config.num_attention_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if name == "lm_head.weight":
                # Falcon uses tied embeddings.
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
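            # HF Falcon checkpoints store the fused query_key_value weight
            # grouped per KV head as [q_0 ... q_{m-1}, k, v]; rearrange it
            # into the [all Q | all K | all V] layout that QKVParallelLinear
            # expects before handing it to the weight loader.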
  384. if "query_key_value" in name:
  385. output_dim = getattr(param, "output_dim", None)
  386. loaded_weight_shape = loaded_weight.shape
  387. if output_dim is not None:
  388. loaded_weight = loaded_weight.view(
  389. loaded_weight_shape[:output_dim] +
  390. (total_num_kv_heads, num_query_heads_per_kv_head + 2,
  391. -1) + loaded_weight_shape[output_dim + 1:])
  392. wq = loaded_weight.narrow(
  393. output_dim + 1, 0,
  394. num_query_heads_per_kv_head).reshape(
  395. *loaded_weight_shape[:output_dim], -1,
  396. *loaded_weight_shape[output_dim + 1:])
  397. wk = loaded_weight.narrow(
  398. output_dim + 1, num_query_heads_per_kv_head,
  399. 1).reshape(*loaded_weight_shape[:output_dim], -1,
  400. *loaded_weight_shape[output_dim + 1:])
  401. wv = loaded_weight.narrow(
  402. output_dim + 1, num_query_heads_per_kv_head + 1,
  403. 1).reshape(*loaded_weight_shape[:output_dim], -1,
  404. *loaded_weight_shape[output_dim + 1:])
  405. loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
  406. weight_loader = getattr(param, "weight_loader",
  407. default_weight_loader)
  408. weight_loader(param, loaded_weight)