falcon.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_reduce)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs import RWConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
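
# Illustrative values for _get_alibi_slopes (comment only, not executed):
# for total_num_heads = 8, closest_power_of_2 is 8, the base is 2**-1, and the
# slopes are [2**-1, 2**-2, ..., 2**-8]. For a non-power-of-two count such as
# 12, the first 8 slopes use base 2**-1 and the remaining 4 use base 2**-0.5
# raised to the odd powers 1, 3, 5, 7, following the standard ALiBi slope
# interpolation.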


class FalconAttention(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query

        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
        elif self.multi_query:
            self.total_num_kv_heads = 1
        else:
            self.total_num_kv_heads = self.total_num_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
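        # Example of the mapping above (comment only): with 8 total KV heads
        # and tp_size == 2, each rank keeps 4 KV heads; with a single KV head
        # (multi-query attention) and tp_size == 2, that one KV head is
        # replicated so every rank still has num_kv_heads == 1.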

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.bias,
            skip_bias_add=True,
            linear_method=linear_method,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            linear_method=linear_method,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads,
                                       alibi_slopes=alibi_slopes)
        else:
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       scale=self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads)
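
        # The three branches above differ only in positional handling: rotary
        # applies RoPE to q/k before attention, ALiBi passes per-head slopes
        # (pre-scaled by inv_norm_factor and sliced to this rank's heads) into
        # PagedAttention, and the fallback uses plain scaled dot-product
        # attention.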

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias
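
    # NOTE: forward returns the (possibly un-reduced) attention output together
    # with the dense layer's bias instead of adding the bias here;
    # FalconDecoderLayer decides when to add it based on
    # reduce_row_parallel_results.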


class FalconMLP(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  linear_method=linear_method)
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results,
            linear_method=linear_method)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias


class FalconDecoderLayer(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config, linear_method)
        self.mlp = FalconMLP(config, linear_method)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
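
        # When parallel_attn or the new decoder architecture is enabled, the
        # row-parallel linears in both the attention and MLP blocks skip their
        # internal all-reduce (reduce_results=False); forward() then performs
        # a single fused all-reduce over the summed outputs instead of two.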

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output
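
    # In the parallel-attention configurations (parallel_attn or the new
    # decoder architecture) this forward computes, per layer,
    #   output = residual + Attn(LN(x)) + MLP(LN(x))
    # (with separate ln_attn/ln_mlp norms in the new architecture) using a
    # single all-reduce; the sequential configuration follows the usual
    #   x -> x + Attn(LN(x)) -> x + MLP(LN(x))
    # pattern instead.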


class FalconModel(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, self.embed_dim, linear_method=linear_method)

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class FalconForCausalLM(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = FalconModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
        )
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens
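
    # The branch above selects QuantSampler when the quantized linear method
    # keeps its weights unmerged: logits are computed by calling lm_head on
    # the hidden states directly. Otherwise the regular Sampler receives
    # lm_head.weight together with the hidden states.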

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        total_num_heads = self.config.num_attention_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
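        # Illustrative numbers (comment only): a hypothetical config with 128
        # attention heads and 8 KV heads gives num_query_heads_per_kv_head of
        # 16, i.e. each group in the fused QKV checkpoint holds 16 query heads
        # followed by one key head and one value head.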
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                loaded_weight_shape = loaded_weight.shape
                if output_dim is not None:
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] +
                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
                         -1) + loaded_weight_shape[output_dim + 1:])
                    wq = loaded_weight.narrow(
                        output_dim + 1, 0,
                        num_query_heads_per_kv_head).reshape(
                            *loaded_weight_shape[:output_dim], -1,
                            *loaded_weight_shape[output_dim + 1:])
                    wk = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    wv = loaded_weight.narrow(
                        output_dim + 1, num_query_heads_per_kv_head + 1,
                        1).reshape(*loaded_weight_shape[:output_dim], -1,
                                   *loaded_weight_shape[output_dim + 1:])
                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
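                    # The fused checkpoint layout of
                    # (total_num_kv_heads, num_query_heads_per_kv_head + 2,
                    # head_dim) groups has now been re-ordered into contiguous
                    # [all Q | all K | all V] along output_dim, matching the
                    # [q_size, kv_size, kv_size] split used in
                    # FalconAttention.forward.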
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
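

# Minimal usage sketch (comments only; the serving engine normally drives
# these calls, tensor-parallel state must already be initialized, and the
# checkpoint name is just an example):
#
#   from transformers import FalconConfig
#   config = FalconConfig.from_pretrained("tiiuae/falcon-7b")
#   model = FalconForCausalLM(config)
#   model.load_weights("tiiuae/falcon-7b")
#
# At runtime the engine calls model(input_ids, positions, kv_caches,
# input_metadata) to obtain hidden states, then model.sample(hidden_states,
# sampling_metadata) to produce the next tokens.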