# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The CacheFlow team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import BloomConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
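
# Illustrative note (added, not from the original source): for a power-of-two
# head count the ALiBi slopes form a simple geometric sequence, e.g. for
# 8 heads they are [1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256].
# For a non-power-of-two count such as 12, the first 8 slopes are that same
# sequence and the remaining 4 are interpolated from base 2**-0.5,
# approximately [0.7071, 0.3536, 0.1768, 0.0884].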


class BloomAttention(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scaling,
                                   alibi_slopes=alibi_slopes)
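
    # Illustrative note (added, not from the original source): with
    # total_num_heads=32 and a tensor-parallel world size of 4, rank 1 owns
    # heads 8..15 and passes only alibi_slopes[8:16] to its PagedAttention.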

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.dense(attn_output)
        return output
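
    # Illustrative note (added, not from the original source): a plain
    # chunk(3, dim=-1) suffices in forward() because load_weights() below
    # re-orders BLOOM's per-head-interleaved fused QKV weight into contiguous
    # [Q | K | V] blocks, matching what QKVParallelLinear produces at runtime.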


class BloomMLP(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config, linear_method)
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, linear_method)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output
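
    # Illustrative note (added, not from the original source): with the default
    # BLOOM config (apply_residual_connection_post_layernorm=False) the block
    # computes
    #   x -> LN1 -> attention -> (+ x) -> LN2 -> MLP -> (+ attention output)
    # i.e. each residual is the input of the preceding LayerNorm; when the flag
    # is set, the residual is taken from the LayerNorm output instead.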


class BloomModel(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, self.embed_dim, linear_method=linear_method)
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([
            BloomBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(hidden_states)
        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class BloomForCausalLM(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = BloomModel(config, linear_method)
        # self.lm_head_weight = self.transformer.word_embeddings.weight
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "lm_head" in name and name not in params_dict:
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            if "word_embeddings" in name:
                # Copy word embedding to lm_head
                head_name = name.replace("transformer.word_embeddings",
                                         "lm_head")
                if head_name in params_dict:
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)
            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
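

# Illustrative sketch (added, not from the original source): the
# view/transpose/reshape in load_weights() converts BLOOM's per-head
# interleaved fused QKV rows [q0 k0 v0 q1 k1 v1 ...] into the
# [q0 q1 ... k0 k1 ... v0 v1 ...] layout expected by QKVParallelLinear.
# For a 2-D weight with output_dim == 0 (head_size used here only for
# illustration):
#
#     w = loaded_weight                                  # (num_heads * 3 * head_size, hidden)
#     w = w.view(num_heads, 3, head_size, hidden)        # split the fused output dim
#     w = w.transpose(0, 1)                              # (3, num_heads, head_size, hidden)
#     w = w.reshape(num_heads * 3 * head_size, hidden)   # back to 2-D, rows re-ordered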