# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The CacheFlow team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import BloomConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput

# Per-layer KV cache: a (key_cache, value_cache) pair of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Compute the ALiBi slope for every attention head.

    Let ``p`` be the largest power of two not exceeding ``total_num_heads``.
    The first ``p`` heads get ``base**1, ..., base**p`` with
    ``base = 2**(-8 / p)``. When ``total_num_heads`` is not a power of two,
    the remaining heads get the odd powers ``1, 3, 5, ...`` of the base
    computed for ``2 * p``.

    Args:
        total_num_heads: Number of attention heads summed over all
            tensor-parallel ranks.

    Returns:
        A float32 tensor with one slope per head
        (shape ``(total_num_heads,)``).
    """
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    # Equivalent to 2 ** (-8 / closest_power_of_2).
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention with ALiBi position biases.

    Heads are split evenly across tensor-parallel ranks; each rank keeps
    only the ALiBi slopes for its own heads.
    """

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            linear_method=linear_method,
        )

        # Create the alibi slopes and slice out the ones for this rank's
        # contiguous range of heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scaling,
                                   alibi_slopes=alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention; positions are encoded by ALiBi, not position_ids."""
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection is split into equal query/key/value parts
        # along the last dimension (see the QKV re-layout in load_weights).
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """Feed-forward block: hidden -> 4 * hidden -> GELU -> hidden."""

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            linear_method=linear_method,
        )
        # linear_method may be None (unquantized), in which case no
        # quant_config is passed to the activation factory.
        quant_config = getattr(linear_method, "quant_config", None)
        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):
    """One BLOOM transformer layer (LayerNorm -> attention -> LayerNorm -> MLP).

    ``config.apply_residual_connection_post_layernorm`` selects whether each
    residual connection uses the LayerNorm output or the LayerNorm input.
    """

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config, linear_method)
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, linear_method)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Pick the residual for the attention sub-layer.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Pick the residual for the MLP sub-layer.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


class BloomModel(nn.Module):
    """BLOOM decoder stack: embeddings + embedding LayerNorm, blocks, ln_f."""

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, self.embed_dim, linear_method=linear_method)
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([
            BloomBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run all blocks, using one KV cache per layer.

        Returns the hidden states after the final LayerNorm.
        """
        hidden_states = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(hidden_states)
        for i, layer in enumerate(self.h):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class BloomForCausalLM(nn.Module):
    """BLOOM decoder with a vocabulary-parallel LM head and samplers."""

    def __init__(
        self,
        config: BloomConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = BloomModel(config, linear_method)
        # lm_head is a separate parameter; load_weights fills it from the
        # word-embedding weight (and from an explicit lm_head weight if the
        # checkpoint has one).
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; logits are produced in ``sample``."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``hidden_states``.

        When a linear method is set and its quant config reports
        ``merge_weight() == False``, the LM head module is applied first and
        its output goes to ``quant_sampler``. Otherwise the raw
        ``lm_head.weight`` is passed to ``sampler`` along with the hidden
        states.
        """
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace BLOOM checkpoint weights into this model.

        Checkpoint names without the ``transformer.`` prefix are mapped
        onto ``self.transformer``. Names containing ``lm_head`` are loaded
        into the top-level ``lm_head`` when such a parameter exists, and are
        skipped otherwise. The word-embedding weight is also copied into the
        matching ``lm_head`` parameter (if present). Fused QKV weights are
        re-laid-out before loading (see the note below).

        Raises:
            KeyError: if a checkpoint tensor has no matching parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "lm_head" in name and name not in params_dict:
                continue
            # Any lm_head name reaching this point is already a top-level key
            # in params_dict. Prefixing it with "transformer." would produce a
            # non-existent key and raise KeyError for checkpoints that store
            # an explicit lm_head weight.
            if "lm_head" not in name and not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]

            if "word_embeddings" in name:
                # Copy word embedding to lm_head. The membership check also
                # filters out "word_embeddings_layernorm", whose derived name
                # has no lm_head counterpart.
                head_name = name.replace("transformer.word_embeddings",
                                         "lm_head")
                if head_name in params_dict:
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Expose the (num_heads, 3) interleaving, swap those two
                    # axes, then flatten back to the original shape.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)