# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs import RWConfig

FalconConfig = Union[HF_FalconConfig, RWConfig]
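

# ALiBi attention-bias slopes, mirroring the reference HF Falcon
# implementation: a geometric sequence with ratio 2**(-8 / n) for the
# closest power-of-two head count n, plus interpolated extra slopes when
# the total number of heads is not a power of two.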
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)

    return slopes


class FalconAttention(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query
        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
        elif self.multi_query:
            self.total_num_kv_heads = 1
        else:
            self.total_num_kv_heads = self.total_num_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
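
        # With Falcon's parallel attention or the new decoder architecture,
        # the attention and MLP branch outputs are summed and all-reduced
        # once in the decoder layer, so the row-parallel projections skip
        # their own reduction in that case.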
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            quant_config=quant_config,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  quant_config=quant_config)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  alibi_slopes=alibi_slopes,
                                  quant_config=quant_config)
        else:
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scale=self.inv_norm_factor,
                                  num_kv_heads=self.num_kv_heads,
                                  cache_config=cache_config,
                                  quant_config=quant_config)
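
    # q, k and v are sliced with per-rank sizes: num_heads and num_kv_heads
    # were already divided by the tensor-parallel world size in __init__.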
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        qkv, bias = self.query_key_value(hidden_states)
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_rotary:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias


class FalconMLP(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  skip_bias_add=True,
                                                  quant_config=quant_config)
        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results,
            quant_config=quant_config)

    def forward(
            self,
            x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias


class FalconDecoderLayer(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config, cache_config,
                                              quant_config)
        self.mlp = FalconMLP(config, quant_config)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
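
    # In the new decoder architecture, attention and MLP read from two
    # separate layer norms over the same residual stream and their outputs
    # are summed. In the original architecture with parallel_attn, the MLP
    # reuses the attention layer-norm output; only the non-parallel path
    # adds the attention output into the residual before the MLP.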
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = tensor_model_parallel_all_reduce(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output


class FalconModel(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class FalconForCausalLM(nn.Module):

    def __init__(
        self,
        config: FalconConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = FalconModel(config, cache_config, quant_config)
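        # Falcon ties the output projection to the input embeddings, so the
        # word embedding weight doubles as the LM head.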
        self.lm_head = self.transformer.word_embeddings
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        total_num_heads = self.config.num_attention_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if name == "lm_head.weight":
                # Falcon uses tied embeddings.
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
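            # Falcon checkpoints store the fused QKV weight grouped per KV
            # head as [num_kv_heads, num_query_heads_per_kv_head + 2,
            # head_dim]: the query heads of each group followed by that
            # group's key and value head. Split it and re-concatenate as
            # [all Q | all K | all V] so QKVParallelLinear can shard it.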
  381. if "query_key_value" in name:
  382. output_dim = getattr(param, "output_dim", None)
  383. loaded_weight_shape = loaded_weight.shape
  384. if output_dim is not None:
  385. loaded_weight = loaded_weight.view(
  386. loaded_weight_shape[:output_dim] +
  387. (total_num_kv_heads, num_query_heads_per_kv_head + 2,
  388. -1) + loaded_weight_shape[output_dim + 1:])
  389. wq = loaded_weight.narrow(
  390. output_dim + 1, 0,
  391. num_query_heads_per_kv_head).reshape(
  392. *loaded_weight_shape[:output_dim], -1,
  393. *loaded_weight_shape[output_dim + 1:])
  394. wk = loaded_weight.narrow(
  395. output_dim + 1, num_query_heads_per_kv_head,
  396. 1).reshape(*loaded_weight_shape[:output_dim], -1,
  397. *loaded_weight_shape[output_dim + 1:])
  398. wv = loaded_weight.narrow(
  399. output_dim + 1, num_query_heads_per_kv_head + 1,
  400. 1).reshape(*loaded_weight_shape[:output_dim], -1,
  401. *loaded_weight_shape[output_dim + 1:])
  402. loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
  403. weight_loader = getattr(param, "weight_loader",
  404. default_weight_loader)
  405. weight_loader(param, loaded_weight)