falcon.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs import RWConfig

FalconConfig = Union[HF_FalconConfig, RWConfig]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)

    return slopes
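
# Worked example: for total_num_heads = 8 the closest power of two is 8,
# `base` evaluates to 2 ** -1, and the slopes come out to
# [2**-1, 2**-2, ..., 2**-8] -- one geometrically decaying slope per head.
# When the head count is not a power of two, the remaining heads are assigned
# the interpolated `extra_base` slopes, following the ALiBi slope schedule
# used by the HuggingFace Bloom/Falcon implementations.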


class FalconAttention(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query
        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
        elif self.multi_query:
            self.total_num_kv_heads = 1
        else:
            self.total_num_kv_heads = self.total_num_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  alibi_slopes=alibi_slopes)
        else:
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scale=self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias
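
# NOTE: because the projections above use `skip_bias_add=True`,
# `FalconAttention.forward` returns the row-parallel output together with its
# not-yet-applied bias. `FalconDecoderLayer` decides where that bias is added:
# right away when each sub-layer performs its own all-reduce
# (`reduce_row_parallel_results=True`), or only after the single fused
# all-reduce used by the parallel-attention and new-decoder variants.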


class FalconMLP(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  quant_config=quant_config)
        quant_config = getattr(quant_config, "quant_config", None)
        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results,
            quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias
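
# NOTE: the MLP uses the usual Megatron-style split: `dense_h_to_4h` is
# column-parallel and `dense_4h_to_h` is row-parallel, so no communication is
# needed between them. As with attention, the row-parallel result is reduced
# either inside the layer or by the decoder layer's fused all-reduce,
# depending on `reduce_row_parallel_results`.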


class FalconDecoderLayer(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config, quant_config)
        self.mlp = FalconMLP(config, quant_config)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
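
    # The forward pass below handles the three Falcon layer variants:
    #   * new_decoder_architecture: separate `ln_attn`/`ln_mlp` norms feed
    #     the attention and MLP branches in parallel, which share one fused
    #     all-reduce.
    #   * parallel_attn (e.g. Falcon-7B): a single `input_layernorm` output
    #     feeds both branches, again with one fused all-reduce.
    #   * otherwise: the classic sequential layout with a residual add and a
    #     post-attention LayerNorm between the two branches.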

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output


class FalconModel(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class FalconForCausalLM(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = FalconModel(config, quant_config)
        self.lm_head_weight = self.transformer.word_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
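
    # Falcon checkpoints store the fused `query_key_value` projection grouped
    # by KV head: each group holds `num_query_heads_per_kv_head` query heads
    # followed by one key head and one value head. `load_weights` below
    # de-interleaves that layout into the contiguous [Q | K | V] ordering
    # consumed by `QKVParallelLinear`'s weight loader.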

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        total_num_heads = self.config.num_attention_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if name == "lm_head.weight":
                # Falcon uses tied embeddings.
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                loaded_weight_shape = loaded_weight.shape
                if output_dim is not None:
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] +
                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
                         -1) + loaded_weight_shape[output_dim + 1:])
                    wq = loaded_weight.narrow(
                        output_dim + 1, 0,
                        num_query_heads_per_kv_head).reshape(
                            *loaded_weight_shape[:output_dim], -1,
                            *loaded_weight_shape[output_dim + 1:])
                    wk = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    wv = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head + 1,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)