bloom.py
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import BloomConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
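    """Compute the ALiBi slope for each attention head.

    Following the ALiBi formulation, head ``i`` (1-indexed) out of ``n``
    heads gets a slope of ``2 ** (-8 * i / n)`` when ``n`` is a power of
    two; e.g. ``_get_alibi_slopes(4)`` yields ``[2**-2, 2**-4, 2**-6,
    2**-8]``. For head counts that are not a power of two, the slopes of
    the closest lower power of two are kept and the remaining heads take
    every other slope from the next power of two, matching the
    HuggingFace BLOOM implementation.
    """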
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size
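        # Attention heads are partitioned evenly across tensor-parallel
        # ranks; each rank keeps total_num_heads / tp_world_size heads and
        # the matching slice of ALiBi slopes below.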
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # Unused.
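        # BLOOM has no positional embeddings; position information enters
        # through the ALiBi slopes registered with the attention backend.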
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
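    """BLOOM feed-forward block.

    Projects hidden_size -> 4 * hidden_size, applies GELU, and projects
    back to hidden_size; both projections are tensor-parallel.
    """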

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            quant_config=quant_config,
        )
        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config, cache_config,
                                             quant_config)
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, quant_config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


class BloomModel(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([
            BloomBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(hidden_states)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class BloomForCausalLM(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(config, cache_config, quant_config)
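        # BLOOM ties the output projection to the input embedding weights,
        # so reuse word_embeddings as lm_head; the checkpoint's
        # "lm_head.weight" is therefore skipped in load_weights().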
        self.lm_head = self.transformer.word_embeddings
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if name == "lm_head.weight":
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
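                # For example, with 2 heads the checkpoint rows are laid
                # out per head as [q0 k0 v0 q1 k1 v1]; the view + transpose
                # below reorders them per projection as [q0 q1 k0 k1 v0 v1]
                # along output_dim.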
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)