# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """Inference-only DeciLM model compatible with HuggingFace weights."""
- from typing import Optional
- import torch
- from transformers import PretrainedConfig
- from aphrodite.common.config import LoRAConfig
- from aphrodite.modeling.layers.linear import LinearMethodBase
- from aphrodite.modeling.models.llama import LlamaForCausalLM
- from aphrodite.modeling.hf_downloader import (default_weight_loader,
- hf_model_weights_iterator)


class DeciLMForCausalLM(LlamaForCausalLM):
    """
    Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
    Based on the llama executor.

    The main difference is that DeciLM uses Variable Grouped Query
    Attention: rather than using a single, constant number of GQA heads
    for the whole decoder, the HuggingFace implementation reads a
    per-layer value from "config.num_key_value_heads_per_layer[i]"
    instead of the usual "config.num_key_value_heads".

    Currently, PagedAttention does not work well with variable GQA, so we
    normalize the weights upon loading and use uniform GQA with the max
    value instead.
    """
    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        # Collapse the per-layer KV head counts into a single uniform
        # value (the maximum), then drop the per-layer attribute so the
        # base Llama implementation sees a standard GQA config.
        config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
        delattr(config, "num_key_value_heads_per_layer")
        super().__init__(config=config,
                         linear_method=linear_method,
                         lora_config=lora_config)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        # Maps fused parameters in this model to the separate projections
        # stored in the HuggingFace checkpoint; an example follows the
        # list below.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
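        # For example (illustrative checkpoint names, assuming standard
        # Llama-style naming): a checkpoint entry
        #   "model.layers.0.self_attn.k_proj.weight"
        # matches ("qkv_proj", "k_proj", "k"), so it is renamed to
        #   "model.layers.0.self_attn.qkv_proj.weight"
        # and loaded into the "k" shard of the fused QKV parameter.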
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "rotary_emb.inv_freq" in name:
                continue
            if "k_proj" in name or "v_proj" in name:
                # Repeat KV heads so every layer matches the uniform
                # (maximum) head count chosen in __init__.
                loaded_weight = self._degroup_weight(loaded_weight)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter; load it directly.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
        """Repeat the KV heads of a k/v projection weight up to the
        uniform (maximum) number of KV heads set on the config."""
        hidden_size = self.config.hidden_size
        head_size = self.config.hidden_size // self.config.num_attention_heads
        target_num_kv_heads = self.config.num_key_value_heads
        num_kv_heads = loaded_weight.shape[0] // head_size
        n_repeats = target_num_kv_heads / num_kv_heads
        # The target head count must be an integer multiple of this
        # layer's own head count for the repetition to be exact.
        assert n_repeats == int(n_repeats)

        n_repeats = int(n_repeats)
        loaded_weight = loaded_weight.view(num_kv_heads, head_size,
                                           hidden_size)
        loaded_weight = torch.repeat_interleave(loaded_weight,
                                                repeats=n_repeats,
                                                dim=0)
        loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
                                              hidden_size)
        return loaded_weight
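

if __name__ == "__main__":
    # Minimal runnable sketch of the degrouping transform above; the
    # sizes here are toy values, not DeciLM's real dimensions.
    head_size, hidden_size = 2, 4
    num_kv_heads, target_num_kv_heads = 2, 4
    w = torch.arange(num_kv_heads * head_size * hidden_size,
                     dtype=torch.float32).view(num_kv_heads * head_size,
                                               hidden_size)
    # Split into KV heads, repeat each head, and flatten back, exactly
    # as _degroup_weight does.
    w = w.view(num_kv_heads, head_size, hidden_size)
    w = torch.repeat_interleave(w,
                                repeats=target_num_kv_heads // num_kv_heads,
                                dim=0)
    w = w.reshape(target_num_kv_heads * head_size, hidden_size)
    assert w.shape == (target_num_kv_heads * head_size, hidden_size)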