# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import BloomConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the ALiBi slope for every attention head.

    With ``p`` the largest power of two not exceeding ``total_num_heads``,
    the first ``p`` slopes are ``base ** k`` for ``k = 1..p`` where
    ``base = 2 ** (-8 / p)``. When ``total_num_heads`` is not a power of
    two, the remaining ``total_num_heads - p`` heads get
    ``extra_base ** k`` for odd ``k = 1, 3, 5, ...`` where
    ``extra_base = 2 ** (-8 / (2 * p))``.

    Args:
        total_num_heads: Number of attention heads summed over all
            tensor-parallel ranks.

    Returns:
        A 1-D float32 tensor with ``total_num_heads`` entries.
    """
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    # 2 ** -(log2(p) - 3) == 8 / p, so base == 2 ** (-8 / p).
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        # Same formula evaluated for 2 * p, i.e. extra_base == 2 ** (-4 / p).
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        # Odd exponents only: 1, 3, 5, ...
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention with a fused QKV projection and ALiBi slopes.

    The ``n_head`` heads are divided evenly across tensor-parallel ranks;
    each rank keeps ``n_head // world_size`` heads and the corresponding
    slice of the ALiBi slopes.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        # Explicit raises instead of `assert`: asserts disappear under
        # `python -O`, which would let an invalid config build silently.
        if self.head_dim * self.total_num_heads != self.hidden_size:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"n_head ({self.total_num_heads}).")
        tp_world_size = get_tensor_model_parallel_world_size()
        if self.total_num_heads % tp_world_size != 0:
            raise ValueError(
                f"n_head ({self.total_num_heads}) must be divisible by the "
                f"tensor-parallel world size ({tp_world_size}).")
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Create the alibi slopes and slice out this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result.

        ``position_ids`` is accepted for interface parity but ignored: the
        positional signal comes from the ALiBi slopes given to ``Attention``.
        """
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        # Split the fused projection into equal q / k / v thirds along the
        # last dim. `load_weights` re-lays out checkpoint QKV weights so
        # that this split is valid.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """Feed-forward block: hidden -> 4 * hidden -> GELU -> hidden."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            quant_config=quant_config,
        )
        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU, and down-projection to ``x``."""
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):
    """One BLOOM transformer layer: LayerNorm -> attention -> LayerNorm -> MLP.

    ``config.apply_residual_connection_post_layernorm`` decides whether the
    residual added after each sub-layer is the LayerNorm output or the
    LayerNorm input.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config, cache_config,
                                             quant_config)
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, quant_config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with their residuals."""
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Pick the residual for the attention sub-layer: either the
        # normalized or the raw input, depending on the config flag.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Pick the residual for the MLP sub-layer by the same rule.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


class BloomModel(nn.Module):
    """BLOOM decoder stack: embeddings (+ LayerNorm), blocks, final LayerNorm."""

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([
            BloomBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every block.

        Args:
            input_ids: Token ids to embed.
            position_ids: Forwarded to each block (ignored by attention).
            kv_caches: One KV-cache tensor per block, indexed by layer.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            Hidden states after the final LayerNorm.
        """
        hidden_states = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(hidden_states)
        for i, layer in enumerate(self.h):
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class BloomForCausalLM(nn.Module):
    """BLOOM causal LM: ``BloomModel`` plus a tied LM head and a sampler."""

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(config, cache_config, quant_config)
        # The LM head shares its weights with the input word embeddings.
        self.lm_head = self.transformer.word_embeddings
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return final hidden states for ``input_ids``.

        ``intermediate_tensors`` is accepted for interface compatibility but
        is not used by this model.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the tied head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace BLOOM checkpoint tensors into this model.

        ``lm_head.weight`` is skipped because ``lm_head`` is the tied word
        embedding. Names lacking the ``transformer.`` prefix have it
        prepended. Fused QKV tensors are re-laid-out before being passed to
        the parameter's weight loader.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a (prefixed) checkpoint name has no matching
                parameter in this model.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if name == "lm_head.weight":
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Unflatten output_dim into (num_heads, 3, head_size)...
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    # ...swap to (3, num_heads, head_size)...
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    # ...and flatten back to the original shape.
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)