bloom.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. """Inference-only BLOOM model compatible with HuggingFace weights."""
  19. import math
  20. from typing import Iterable, List, Optional, Tuple
  21. import torch
  22. from torch import nn
  23. from transformers import BloomConfig
  24. from aphrodite.attention import Attention, AttentionMetadata
  25. from aphrodite.common.config import CacheConfig
  26. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  27. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  28. get_tensor_model_parallel_world_size)
  29. from aphrodite.modeling.layers.activation import get_act_fn
  30. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding)
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.quantization.base_config import QuantizationConfig
  40. def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
  41. closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
  42. base = torch.tensor(
  43. 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
  44. dtype=torch.float32,
  45. )
  46. powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
  47. slopes = torch.pow(base, powers)
  48. if closest_power_of_2 != total_num_heads:
  49. extra_base = torch.tensor(
  50. 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
  51. dtype=torch.float32,
  52. )
  53. num_remaining_heads = min(closest_power_of_2,
  54. total_num_heads - closest_power_of_2)
  55. extra_powers = torch.arange(start=1,
  56. end=1 + 2 * num_remaining_heads,
  57. step=2,
  58. dtype=torch.int32)
  59. slopes = torch.cat(
  60. [slopes, torch.pow(extra_base, extra_powers)], dim=0)
  61. return slopes
  62. class BloomAttention(nn.Module):
  63. def __init__(
  64. self,
  65. config: BloomConfig,
  66. cache_config: Optional[CacheConfig] = None,
  67. quant_config: Optional[QuantizationConfig] = None,
  68. ):
  69. super().__init__()
  70. self.hidden_size = config.hidden_size
  71. self.total_num_heads = config.n_head
  72. self.head_dim = self.hidden_size // self.total_num_heads
  73. assert self.head_dim * self.total_num_heads == self.hidden_size
  74. tp_world_size = get_tensor_model_parallel_world_size()
  75. assert self.total_num_heads % tp_world_size == 0
  76. self.num_heads = self.total_num_heads // tp_world_size
  77. self.query_key_value = QKVParallelLinear(
  78. self.hidden_size,
  79. self.head_dim,
  80. self.total_num_heads,
  81. bias=True,
  82. quant_config=quant_config,
  83. )
  84. self.dense = RowParallelLinear(
  85. self.hidden_size,
  86. self.hidden_size,
  87. bias=True,
  88. quant_config=quant_config,
  89. )
  90. # Create the alibi slopes and slice them.
  91. tp_rank = get_tensor_model_parallel_rank()
  92. head_start = tp_rank * self.num_heads
  93. head_end = (tp_rank + 1) * self.num_heads
  94. alibi_slopes = _get_alibi_slopes(self.total_num_heads)
  95. alibi_slopes = alibi_slopes[head_start:head_end].tolist()
  96. scaling = self.head_dim**-0.5
  97. self.attn = Attention(self.num_heads,
  98. self.head_dim,
  99. scaling,
  100. alibi_slopes=alibi_slopes,
  101. cache_config=cache_config,
  102. quant_config=quant_config)
  103. def forward(
  104. self,
  105. position_ids: torch.Tensor,
  106. hidden_states: torch.Tensor,
  107. kv_cache: torch.Tensor,
  108. attn_metadata: AttentionMetadata,
  109. ) -> torch.Tensor:
  110. del position_ids # Unused.
  111. qkv, _ = self.query_key_value(hidden_states)
  112. q, k, v = qkv.chunk(chunks=3, dim=-1)
  113. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  114. output, _ = self.dense(attn_output)
  115. return output
  116. class BloomMLP(nn.Module):
  117. def __init__(
  118. self,
  119. config: BloomConfig,
  120. quant_config: Optional[QuantizationConfig] = None,
  121. ):
  122. super().__init__()
  123. hidden_size = config.hidden_size
  124. self.dense_h_to_4h = ColumnParallelLinear(
  125. hidden_size,
  126. 4 * hidden_size,
  127. quant_config=quant_config,
  128. )
  129. self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
  130. self.dense_4h_to_h = RowParallelLinear(
  131. 4 * hidden_size,
  132. hidden_size,
  133. quant_config=quant_config,
  134. )
  135. def forward(self, x: torch.Tensor) -> torch.Tensor:
  136. x, _ = self.dense_h_to_4h(x)
  137. x = self.gelu_impl(x)
  138. x, _ = self.dense_4h_to_h(x)
  139. return x
  140. class BloomBlock(nn.Module):
  141. def __init__(
  142. self,
  143. config: BloomConfig,
  144. cache_config: Optional[CacheConfig] = None,
  145. quant_config: Optional[QuantizationConfig] = None,
  146. ):
  147. super().__init__()
  148. hidden_size = config.hidden_size
  149. self.input_layernorm = nn.LayerNorm(hidden_size,
  150. eps=config.layer_norm_epsilon)
  151. self.self_attention = BloomAttention(config, cache_config,
  152. quant_config)
  153. self.post_attention_layernorm = nn.LayerNorm(
  154. hidden_size, eps=config.layer_norm_epsilon)
  155. self.mlp = BloomMLP(config, quant_config)
  156. self.apply_residual_connection_post_layernorm = (
  157. config.apply_residual_connection_post_layernorm)
  158. def forward(
  159. self,
  160. position_ids: torch.Tensor,
  161. hidden_states: torch.Tensor,
  162. kv_cache: torch.Tensor,
  163. attn_metadata: AttentionMetadata,
  164. ) -> torch.Tensor:
  165. # Layer norm at the beginning of the transformer layer.
  166. layernorm_output = self.input_layernorm(hidden_states)
  167. # Layer norm post the self attention.
  168. if self.apply_residual_connection_post_layernorm:
  169. residual = layernorm_output
  170. else:
  171. residual = hidden_states
  172. # Self attention.
  173. attention_output = self.self_attention(
  174. position_ids=position_ids,
  175. hidden_states=layernorm_output,
  176. kv_cache=kv_cache,
  177. attn_metadata=attn_metadata,
  178. )
  179. attention_output = attention_output + residual
  180. layernorm_output = self.post_attention_layernorm(attention_output)
  181. # Get residual
  182. if self.apply_residual_connection_post_layernorm:
  183. residual = layernorm_output
  184. else:
  185. residual = attention_output
  186. # MLP.
  187. output = self.mlp(layernorm_output) + residual
  188. return output
  189. class BloomModel(nn.Module):
  190. def __init__(
  191. self,
  192. config: BloomConfig,
  193. cache_config: Optional[CacheConfig] = None,
  194. quant_config: Optional[QuantizationConfig] = None,
  195. ):
  196. super().__init__()
  197. self.embed_dim = config.hidden_size
  198. # Embedding + LN Embedding
  199. self.word_embeddings = VocabParallelEmbedding(
  200. config.vocab_size,
  201. self.embed_dim,
  202. )
  203. self.word_embeddings_layernorm = nn.LayerNorm(
  204. self.embed_dim, eps=config.layer_norm_epsilon)
  205. # Transformer blocks
  206. self.h = nn.ModuleList([
  207. BloomBlock(config, cache_config, quant_config)
  208. for _ in range(config.num_hidden_layers)
  209. ])
  210. # Final Layer Norm
  211. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  212. def forward(
  213. self,
  214. input_ids: torch.Tensor,
  215. position_ids: torch.Tensor,
  216. kv_caches: List[torch.Tensor],
  217. attn_metadata: AttentionMetadata,
  218. ) -> torch.Tensor:
  219. hidden_states = self.word_embeddings(input_ids)
  220. hidden_states = self.word_embeddings_layernorm(hidden_states)
  221. for i in range(len(self.h)):
  222. layer = self.h[i]
  223. hidden_states = layer(
  224. position_ids,
  225. hidden_states,
  226. kv_caches[i],
  227. attn_metadata,
  228. )
  229. hidden_states = self.ln_f(hidden_states)
  230. return hidden_states
  231. class BloomForCausalLM(nn.Module):
  232. def __init__(
  233. self,
  234. config: BloomConfig,
  235. cache_config: Optional[CacheConfig] = None,
  236. quant_config: Optional[QuantizationConfig] = None,
  237. ):
  238. super().__init__()
  239. self.config = config
  240. self.quant_config = quant_config
  241. self.transformer = BloomModel(config, cache_config, quant_config)
  242. self.lm_head = self.transformer.word_embeddings
  243. self.logits_processor = LogitsProcessor(config.vocab_size)
  244. self.sampler = Sampler()
  245. def forward(
  246. self,
  247. input_ids: torch.Tensor,
  248. positions: torch.Tensor,
  249. kv_caches: List[torch.Tensor],
  250. attn_metadata: AttentionMetadata,
  251. intermediate_tensors: Optional[IntermediateTensors] = None,
  252. ) -> torch.Tensor:
  253. hidden_states = self.transformer(input_ids, positions, kv_caches,
  254. attn_metadata)
  255. return hidden_states
  256. def compute_logits(
  257. self,
  258. hidden_states: torch.Tensor,
  259. sampling_metadata: SamplingMetadata,
  260. ) -> Optional[torch.Tensor]:
  261. logits = self.logits_processor(self.lm_head, hidden_states,
  262. sampling_metadata)
  263. return logits
  264. def sample(
  265. self,
  266. logits: torch.Tensor,
  267. sampling_metadata: SamplingMetadata,
  268. ) -> Optional[SamplerOutput]:
  269. next_tokens = self.sampler(logits, sampling_metadata)
  270. return next_tokens
  271. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  272. params_dict = dict(self.named_parameters(remove_duplicate=False))
  273. for name, loaded_weight in weights:
  274. if name == "lm_head.weight":
  275. continue
  276. if not name.startswith("transformer."):
  277. name = "transformer." + name
  278. param = params_dict[name]
  279. if "query_key_value" in name:
  280. # NOTE: BLOOM's fused QKV's output_dim has the shape of
  281. # (num_heads * 3 * head_size), while the
  282. # required shape is (3 * num_heads * head_size).
  283. # Thus, we need weight conversion.
  284. output_dim = getattr(param, "output_dim", None)
  285. num_heads = self.config.num_attention_heads
  286. if output_dim is not None:
  287. loaded_weight_shape = loaded_weight.shape
  288. loaded_weight = loaded_weight.view(
  289. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  290. loaded_weight_shape[output_dim + 1:])
  291. loaded_weight = loaded_weight.transpose(
  292. output_dim, output_dim + 1)
  293. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  294. weight_loader = getattr(param, "weight_loader",
  295. default_weight_loader)
  296. weight_loader(param, loaded_weight)