bloom.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The CacheFlow team.
  6. # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """Inference-only BLOOM model compatible with HuggingFace weights."""
  20. import math
  21. from typing import List, Optional
  22. import torch
  23. from torch import nn
  24. from transformers import BloomConfig
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.modeling.layers.activation import get_act_fn
  27. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  28. LinearMethodBase,
  29. QKVParallelLinear,
  30. RowParallelLinear)
  31. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  32. from aphrodite.modeling.layers.sampler import Sampler
  33. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  34. VocabParallelEmbedding, ParallelLMHead)
  35. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  36. get_tensor_model_parallel_world_size)
  37. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  38. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  39. hf_model_weights_iterator)
  40. from aphrodite.common.sequence import SamplerOutput
  41. def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
  42. closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
  43. base = torch.tensor(
  44. 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
  45. dtype=torch.float32,
  46. )
  47. powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
  48. slopes = torch.pow(base, powers)
  49. if closest_power_of_2 != total_num_heads:
  50. extra_base = torch.tensor(
  51. 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
  52. dtype=torch.float32,
  53. )
  54. num_remaining_heads = min(closest_power_of_2,
  55. total_num_heads - closest_power_of_2)
  56. extra_powers = torch.arange(start=1,
  57. end=1 + 2 * num_remaining_heads,
  58. step=2,
  59. dtype=torch.int32)
  60. slopes = torch.cat(
  61. [slopes, torch.pow(extra_base, extra_powers)], dim=0)
  62. return slopes
  63. class BloomAttention(nn.Module):
  64. def __init__(
  65. self,
  66. config: BloomConfig,
  67. linear_method: Optional[LinearMethodBase] = None,
  68. ):
  69. super().__init__()
  70. self.hidden_size = config.hidden_size
  71. self.total_num_heads = config.n_head
  72. self.head_dim = self.hidden_size // self.total_num_heads
  73. assert self.head_dim * self.total_num_heads == self.hidden_size
  74. tp_world_size = get_tensor_model_parallel_world_size()
  75. assert self.total_num_heads % tp_world_size == 0
  76. self.num_heads = self.total_num_heads // tp_world_size
  77. self.query_key_value = QKVParallelLinear(
  78. self.hidden_size,
  79. self.head_dim,
  80. self.total_num_heads,
  81. bias=True,
  82. linear_method=linear_method,
  83. )
  84. self.dense = RowParallelLinear(
  85. self.hidden_size,
  86. self.hidden_size,
  87. bias=True,
  88. linear_method=linear_method,
  89. )
  90. # Create the alibi slopes and slice them.
  91. tp_rank = get_tensor_model_parallel_rank()
  92. head_start = tp_rank * self.num_heads
  93. head_end = (tp_rank + 1) * self.num_heads
  94. alibi_slopes = _get_alibi_slopes(self.total_num_heads)
  95. alibi_slopes = alibi_slopes[head_start:head_end].tolist()
  96. scaling = self.head_dim**-0.5
  97. self.attn = Attention(self.num_heads,
  98. self.head_dim,
  99. scaling,
  100. alibi_slopes=alibi_slopes)
  101. def forward(
  102. self,
  103. position_ids: torch.Tensor,
  104. hidden_states: torch.Tensor,
  105. kv_cache: torch.Tensor,
  106. attn_metadata: AttentionMetadata,
  107. ) -> torch.Tensor:
  108. del position_ids # Unused.
  109. qkv, _ = self.query_key_value(hidden_states)
  110. q, k, v = qkv.chunk(chunks=3, dim=-1)
  111. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  112. output, _ = self.dense(attn_output)
  113. return output
  114. class BloomMLP(nn.Module):
  115. def __init__(
  116. self,
  117. config: BloomConfig,
  118. linear_method: Optional[LinearMethodBase] = None,
  119. ):
  120. super().__init__()
  121. hidden_size = config.hidden_size
  122. self.dense_h_to_4h = ColumnParallelLinear(
  123. hidden_size,
  124. 4 * hidden_size,
  125. linear_method=linear_method,
  126. )
  127. quant_config = getattr(linear_method, "quant_config", None)
  128. self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
  129. self.dense_4h_to_h = RowParallelLinear(
  130. 4 * hidden_size,
  131. hidden_size,
  132. linear_method=linear_method,
  133. )
  134. def forward(self, x: torch.Tensor) -> torch.Tensor:
  135. x, _ = self.dense_h_to_4h(x)
  136. x = self.gelu_impl(x)
  137. x, _ = self.dense_4h_to_h(x)
  138. return x
  139. class BloomBlock(nn.Module):
  140. def __init__(
  141. self,
  142. config: BloomConfig,
  143. linear_method: Optional[LinearMethodBase] = None,
  144. ):
  145. super().__init__()
  146. hidden_size = config.hidden_size
  147. self.input_layernorm = nn.LayerNorm(hidden_size,
  148. eps=config.layer_norm_epsilon)
  149. self.self_attention = BloomAttention(config, linear_method)
  150. self.post_attention_layernorm = nn.LayerNorm(
  151. hidden_size, eps=config.layer_norm_epsilon)
  152. self.mlp = BloomMLP(config, linear_method)
  153. self.apply_residual_connection_post_layernorm = (
  154. config.apply_residual_connection_post_layernorm)
  155. def forward(
  156. self,
  157. position_ids: torch.Tensor,
  158. hidden_states: torch.Tensor,
  159. kv_cache: torch.Tensor,
  160. attn_metadata: AttentionMetadata,
  161. ) -> torch.Tensor:
  162. # Layer norm at the beginning of the transformer layer.
  163. layernorm_output = self.input_layernorm(hidden_states)
  164. # Layer norm post the self attention.
  165. if self.apply_residual_connection_post_layernorm:
  166. residual = layernorm_output
  167. else:
  168. residual = hidden_states
  169. # Self attention.
  170. attention_output = self.self_attention(
  171. position_ids=position_ids,
  172. hidden_states=layernorm_output,
  173. kv_cache=kv_cache,
  174. attn_metadata=attn_metadata,
  175. )
  176. attention_output = attention_output + residual
  177. layernorm_output = self.post_attention_layernorm(attention_output)
  178. # Get residual
  179. if self.apply_residual_connection_post_layernorm:
  180. residual = layernorm_output
  181. else:
  182. residual = attention_output
  183. # MLP.
  184. output = self.mlp(layernorm_output) + residual
  185. return output
  186. class BloomModel(nn.Module):
  187. def __init__(
  188. self,
  189. config: BloomConfig,
  190. linear_method: Optional[LinearMethodBase] = None,
  191. ):
  192. super().__init__()
  193. self.embed_dim = config.hidden_size
  194. # Embedding + LN Embedding
  195. self.word_embeddings = VocabParallelEmbedding(
  196. config.vocab_size, self.embed_dim, linear_method=linear_method)
  197. self.word_embeddings_layernorm = nn.LayerNorm(
  198. self.embed_dim, eps=config.layer_norm_epsilon)
  199. # Transformer blocks
  200. self.h = nn.ModuleList([
  201. BloomBlock(config, linear_method)
  202. for _ in range(config.num_hidden_layers)
  203. ])
  204. # Final Layer Norm
  205. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  206. def forward(
  207. self,
  208. input_ids: torch.Tensor,
  209. position_ids: torch.Tensor,
  210. kv_caches: List[torch.Tensor],
  211. attn_metadata: AttentionMetadata,
  212. ) -> torch.Tensor:
  213. hidden_states = self.word_embeddings(input_ids)
  214. hidden_states = self.word_embeddings_layernorm(hidden_states)
  215. for i in range(len(self.h)):
  216. layer = self.h[i]
  217. hidden_states = layer(
  218. position_ids,
  219. hidden_states,
  220. kv_caches[i],
  221. attn_metadata,
  222. )
  223. hidden_states = self.ln_f(hidden_states)
  224. return hidden_states
  225. class BloomForCausalLM(nn.Module):
  226. def __init__(
  227. self,
  228. config: BloomConfig,
  229. linear_method: Optional[LinearMethodBase] = None,
  230. ):
  231. super().__init__()
  232. self.config = config
  233. self.linear_method = linear_method
  234. self.transformer = BloomModel(config, linear_method)
  235. self.lm_head_weight = self.transformer.word_embeddings.weight
  236. self.lm_head = ParallelLMHead(config.vocab_size,
  237. config.hidden_size,
  238. linear_method=linear_method)
  239. self.logits_processor = LogitsProcessor(config.vocab_size,
  240. config.tokenizer_vocab_size)
  241. self.sampler = Sampler()
  242. def forward(
  243. self,
  244. input_ids: torch.Tensor,
  245. positions: torch.Tensor,
  246. kv_caches: List[torch.Tensor],
  247. attn_metadata: AttentionMetadata,
  248. ) -> torch.Tensor:
  249. hidden_states = self.transformer(input_ids, positions, kv_caches,
  250. attn_metadata)
  251. return hidden_states
  252. def compute_logits(self, hidden_states: torch.Tensor,
  253. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  254. logits = self.logits_processor(self.lm_head, hidden_states,
  255. sampling_metadata)
  256. return logits
  257. def sample(
  258. self,
  259. logits: torch.Tensor,
  260. sampling_metadata: SamplingMetadata,
  261. ) -> Optional[SamplerOutput]:
  262. next_tokens = self.sampler(logits, sampling_metadata)
  263. return next_tokens
  264. def load_weights(self,
  265. model_name_or_path: str,
  266. cache_dir: Optional[str] = None,
  267. load_format: str = "auto",
  268. revision: Optional[str] = None):
  269. params_dict = dict(self.named_parameters(remove_duplicate=False))
  270. for name, loaded_weight in hf_model_weights_iterator(
  271. model_name_or_path, cache_dir, load_format, revision,
  272. self.config):
  273. if "lm_head" in name and name not in params_dict:
  274. continue
  275. if not name.startswith("transformer."):
  276. name = "transformer." + name
  277. param = params_dict[name]
  278. if "word_embeddings" in name:
  279. # Copy word embedding to lm_head
  280. head_name = name.replace("transformer.word_embeddings",
  281. "lm_head")
  282. if head_name in params_dict:
  283. lm_head_param = params_dict[head_name]
  284. weight_loader = getattr(lm_head_param, "weight_loader",
  285. default_weight_loader)
  286. weight_loader(lm_head_param, loaded_weight)
  287. if "query_key_value" in name:
  288. # NOTE: BLOOM's fused QKV's output_dim has the shape of
  289. # (num_heads * 3 * head_size), while the
  290. # required shape is (3 * num_heads * head_size).
  291. # Thus, we need weight conversion.
  292. output_dim = getattr(param, "output_dim", None)
  293. num_heads = self.config.num_attention_heads
  294. if output_dim is not None:
  295. loaded_weight_shape = loaded_weight.shape
  296. loaded_weight = loaded_weight.view(
  297. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  298. loaded_weight_shape[output_dim + 1:])
  299. loaded_weight = loaded_weight.transpose(
  300. output_dim, output_dim + 1)
  301. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  302. weight_loader = getattr(param, "weight_loader",
  303. default_weight_loader)
  304. weight_loader(param, loaded_weight)