# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import BloomConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
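
# Illustrative values (added commentary, not from the upstream file): for
# total_num_heads=4, closest_power_of_2 is 4 and base is 2**-2, so the
# returned slopes are [0.25, 0.0625, 0.015625, 0.00390625]: one geometrically
# decaying ALiBi bias slope per attention head.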


class BloomAttention(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output
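
# Note (added commentary, not in the upstream file): position_ids can be
# discarded in BloomAttention.forward because BLOOM uses ALiBi; positional
# information enters only through the per-head bias slopes passed to
# Attention, not through positional embeddings. The plain
# qkv.chunk(3, dim=-1) split is valid because load_weights() below reorders
# the fused checkpoint weight from per-head (q, k, v) interleaving into
# contiguous [Q | K | V] blocks.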


class BloomMLP(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config, linear_method)
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, linear_method)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


class BloomModel(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([
            BloomBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(hidden_states)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class BloomForCausalLM(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = BloomModel(config, linear_method)
        self.lm_head_weight = self.transformer.word_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if name == "lm_head.weight":
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
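
# Worked example (added for clarity; the numbers are illustrative, not from
# the upstream file): with num_heads=2 and head_size=4, the checkpoint stores
# the fused QKV weight rows as [q0 k0 v0 q1 k1 v1] (interleaved per head).
# The view/transpose/reshape in load_weights() reorders the rows along
# output_dim to [q0 q1 k0 k1 v0 v1], the contiguous [Q | K | V] layout that
# QKVParallelLinear expects, so qkv.chunk(3, dim=-1) in
# BloomAttention.forward splits cleanly into q, k, and v.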