# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""

import math
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs import RWConfig

FalconConfig = Union[HF_FalconConfig, RWConfig]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)

    return slopes
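

# Example of the slope schedule (editorial note, not in the original file):
# for 8 attention heads the geometric base is 2**-1, so _get_alibi_slopes(8)
# returns [2**-1, 2**-2, ..., 2**-8]; for a non-power-of-two head count, the
# remaining heads receive interpolated slopes derived from the next power of
# two, matching the original ALiBi formulation.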


class FalconAttention(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query
        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
        elif self.multi_query:
            self.total_num_kv_heads = 1
        else:
            self.total_num_kv_heads = self.total_num_heads
        if self.total_num_kv_heads >= tp_size:
            # The number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # The number of KV heads is less than the TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.bias,
            skip_bias_add=True,
            linear_method=linear_method,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
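        # Illustrative sizing (editorial note; the numbers are an assumed
        # example, not read from any config): with 128 query heads, 8 KV
        # heads, head_dim 64 and tensor parallel size 4, each rank keeps
        # 32 query heads and 2 KV heads, so the fused QKV output splits into
        # q_size = 2048 and kv_size = 128 columns per rank.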

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            linear_method=linear_method,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  alibi_slopes=alibi_slopes)
        else:
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scale=self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias


class FalconMLP(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  linear_method=linear_method)
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results,
            linear_method=linear_method)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias


class FalconDecoderLayer(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config, linear_method)
        self.mlp = FalconMLP(config, linear_method)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output
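

# Editorial note (not in the original file): as a rough guide,
# new_decoder_architecture selects the Falcon-40B/180B style block with
# separate ln_attn/ln_mlp and the fused all-reduce above, while parallel_attn
# without it corresponds to the Falcon-7B style block; the RWConfig fallback
# imported above covers the older "RefinedWeb"-era checkpoints.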


class FalconModel(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, self.embed_dim, linear_method=linear_method)

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class FalconForCausalLM(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = FalconModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        total_num_heads = self.config.num_attention_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if name == "lm_head.weight":
                # Falcon uses tied embeddings.
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
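            # Falcon checkpoints store the fused QKV weight grouped per KV
            # head: viewed as (..., total_num_kv_heads,
            # num_query_heads_per_kv_head + 2, head_dim, ...), each group
            # holds its query heads followed by one key head and one value
            # head. The block below unpacks that layout into contiguous Q, K
            # and V blocks so the QKVParallelLinear weight loader can shard
            # them. (Explanatory comment added editorially; not in the
            # original file.)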
            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                loaded_weight_shape = loaded_weight.shape
                if output_dim is not None:
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] +
                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
                         -1) + loaded_weight_shape[output_dim + 1:])
                    wq = loaded_weight.narrow(
                        output_dim + 1, 0,
                        num_query_heads_per_kv_head).reshape(
                            *loaded_weight_shape[:output_dim], -1,
                            *loaded_weight_shape[output_dim + 1:])
                    wk = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    wv = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head + 1,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
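

# Hypothetical usage sketch (editorial addition, not part of the original
# module). In practice the aphrodite engine constructs and loads the model;
# a minimal direct use would look roughly like the following, assuming the
# HF hub id "tiiuae/falcon-7b":
#
#     from transformers import FalconConfig
#     config = FalconConfig.from_pretrained("tiiuae/falcon-7b")
#     model = FalconForCausalLM(config)
#     model.load_weights("tiiuae/falcon-7b")
#
# A forward pass additionally requires per-layer kv_caches and an
# AttentionMetadata object, which are supplied by the engine's model runner.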