# coding=utf-8
# Adapted from
# https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/modeling_bitnet.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Bitnet model compatible with HuggingFace weights."""
# ruff: noqa: E501
from typing import Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from loguru import logger

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.common.utils import is_hip, print_warning_once


class BitnetConfig(PretrainedConfig):

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        weight_bits=1,
        input_bits=8,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.weight_bits = weight_bits
        self.input_bits = input_bits

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if (not isinstance(self.rope_scaling, dict)
                or len(self.rope_scaling) != 2):
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, "
                f"`type` and `factor`, got {self.rope_scaling}")
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in [
                "linear", "dynamic"
        ]:
            raise ValueError(
                "`rope_scaling`'s type field must be one of "
                f"['linear', 'dynamic'], got {rope_scaling_type}")
        if (rope_scaling_factor is None
                or not isinstance(rope_scaling_factor, float)
                or rope_scaling_factor <= 1.0):
            raise ValueError(
                "`rope_scaling`'s factor field must be a float > 1, "
                f"got {rope_scaling_factor}")


class BitnetMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        config: Optional[BitnetConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
        self.ffn_layernorm = RMSNorm(intermediate_size,
                                     eps=config.rms_norm_eps)

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.ffn_layernorm(x)
        x, _ = self.down_proj(x)
        return x


class BitnetAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, float]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        config: Optional[BitnetConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
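            # Replication requires the TP size to be a multiple of the number
            # of KV heads, so each KV head is shared by an equal number of
            # ranks (every rank still holds exactly one KV head).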
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = hidden_size // self.total_num_heads
        self.padded_head_dim = self.find_flash_attn_supported_head_dims(
            self.head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.padded_head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.inner_attn_ln = RMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)

    def find_flash_attn_supported_head_dims(self, head_dim: int) -> int:
        """
        Return the smallest head dimension supported by Flash Attention
        that is greater than or equal to the given head dimension.
        """
        from aphrodite.attention.backends.flash_attn import FlashAttentionBackend
        FLASHATTN_SUPPORTED_HEAD_DIMS = (
            FlashAttentionBackend.get_supported_head_sizes())
        for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS:
            if head_dim <= supported_head_dim:
                return supported_head_dim
        raise ValueError(
            f"Head dimension {head_dim} is not supported by Flash Attention. "
            f"Supported head dimensions are {FLASHATTN_SUPPORTED_HEAD_DIMS}.")

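    # NOTE: BitNet checkpoints can use head_dim values (e.g. 100) that the
    # paged/flash attention kernels do not support. forward() therefore
    # zero-pads Q, K and V from head_dim up to padded_head_dim before the
    # attention call and slices the output back down afterwards; see the
    # standalone sketch at the end of this file.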
Supported head dimensions are " f"{FLASHATTN_SUPPORTED_HEAD_DIMS}.") def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: # QKV projection cannot be grouped as the they # do not share the same scaling factor qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) # Padding as paged attention doesn't support head_dim == 100 q = torch.nn.functional.pad( q.view(-1, self.total_num_heads, self.head_dim), (0, self.padded_head_dim - self.head_dim), ).view(-1, self.total_num_heads * self.padded_head_dim) k = torch.nn.functional.pad( k.view(-1, self.num_kv_heads, self.head_dim), (0, self.padded_head_dim - self.head_dim), ).view(-1, self.total_num_kv_heads * self.padded_head_dim) v = torch.nn.functional.pad( v.view(-1, self.num_kv_heads, self.head_dim), (0, self.padded_head_dim - self.head_dim), ).view(-1, self.total_num_kv_heads * self.padded_head_dim) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = attn_output.view( -1, self.total_num_heads, self.padded_head_dim)[..., :self.head_dim].reshape( -1, self.total_num_heads * self.head_dim) attn_output = self.inner_attn_ln(attn_output) output, _ = self.o_proj(attn_output) return output class BitnetDecoderLayer(nn.Module): def __init__( self, config: BitnetConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( config, "original_max_position_embeddings", None): rope_scaling["original_max_position_embeddings"] = ( config.original_max_position_embeddings) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False) self.self_attn = BitnetAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=attention_bias, cache_config=cache_config, config=config, ) self.mlp = BitnetMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, bias=getattr(config, "mlp_bias", False), config=config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class BitnetModel(nn.Module): def __init__( 
        self,
        config: BitnetConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            BitnetDecoderLayer(config=config,
                               cache_config=cache_config,
                               quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class BitnetForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: BitnetConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.model = BitnetModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                # Reshape 0-dim scaling factors so they match the param shape.
                if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0:
                    loaded_weight = loaded_weight.view(param.data.shape)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remap the name of the FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but could not find the expected name in the "
                            f"model (e.g. {remapped_kv_scale_name}). The "
                            "kv-scale will not be loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Reshape 0-dim scaling factors so they match the param shape.
                if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0:
                    loaded_weight = loaded_weight.view(param.data.shape)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state.
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path,
                tp_rank,
                tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type,
        ):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")
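

if __name__ == "__main__":
    # Minimal standalone sketch, illustrative only and not used by the engine:
    # it demonstrates the zero-pad / slice-back round trip that
    # BitnetAttention.forward applies when head_dim (e.g. 100) is not one of
    # the head sizes supported by the flash-attention kernels. The sizes
    # below are assumptions chosen purely for the demonstration.
    num_tokens, num_heads, head_dim, padded_head_dim = 4, 32, 100, 128
    q = torch.randn(num_tokens, num_heads * head_dim)
    # Zero-pad each head from head_dim to padded_head_dim, then flatten.
    q_padded = torch.nn.functional.pad(
        q.view(-1, num_heads, head_dim),
        (0, padded_head_dim - head_dim),
    ).view(-1, num_heads * padded_head_dim)
    # Slice the padding back off and flatten to the original layout.
    q_restored = q_padded.view(
        -1, num_heads, padded_head_dim)[..., :head_dim].reshape(
            -1, num_heads * head_dim)
    # The round trip is lossless.
    assert torch.equal(q, q_restored)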