decilm.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 DeciAI Research Team. All rights reserved.
  5. # Copyright 2023 The PygmalionAI team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. """Inference-only DeciLM model compatible with HuggingFace weights."""
  26. from typing import Iterable, Optional, Tuple
  27. import torch
  28. from transformers import PretrainedConfig
  29. from aphrodite.common.config import LoRAConfig
  30. from aphrodite.modeling.layers.linear import LinearMethodBase
  31. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  32. from aphrodite.modeling.models.llama import LlamaForCausalLM
  33. class DeciLMForCausalLM(LlamaForCausalLM):
  34. """
  35. Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
  36. Based on the llama executor.
  37. The main difference is that DeciLM uses Variable Grouped Query Attention.
  38. The constant number of GQA heads in the decoder is overridden with a value
  39. per layer.
  40. Usually, in the HuggingFace implementation, instead of
  41. "config.num_key_value_heads", we use
  42. "config.num_key_value_heads_per_layer[i]" which varies.
  43. Currently, PagedAttention does not work well with variable GQA, so we
  44. normalize the weights upon loading, and use uniform GQA with the max value
  45. instead.
  46. """
  47. def __init__(
  48. self,
  49. config: Optional[PretrainedConfig] = None,
  50. linear_method: Optional[LinearMethodBase] = None,
  51. lora_config: Optional[LoRAConfig] = None,
  52. ) -> None:
  53. config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
  54. delattr(config, "num_key_value_heads_per_layer")
  55. super().__init__(config=config,
  56. linear_method=linear_method,
  57. lora_config=lora_config)
  58. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  59. stacked_params_mapping = [
  60. # (param_name, shard_name, shard_id)
  61. ("qkv_proj", "q_proj", "q"),
  62. ("qkv_proj", "k_proj", "k"),
  63. ("qkv_proj", "v_proj", "v"),
  64. ("gate_up_proj", "gate_proj", 0),
  65. ("gate_up_proj", "up_proj", 1),
  66. ]
  67. params_dict = dict(self.named_parameters())
  68. for name, loaded_weight in weights:
  69. if "rotary_emb.inv_freq" in name:
  70. continue
  71. if "k_proj" in name or "v_proj" in name:
  72. loaded_weight = self._degroup_weight(loaded_weight)
  73. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  74. if weight_name not in name:
  75. continue
  76. name = name.replace(weight_name, param_name)
  77. # Skip loading extra bias for GPTQ models.
  78. if name.endswith(".bias") and name not in params_dict:
  79. continue
  80. param = params_dict[name]
  81. weight_loader = param.weight_loader
  82. weight_loader(param, loaded_weight, shard_id)
  83. break
  84. else:
  85. # Skip loading extra bias for GPTQ models.
  86. if name.endswith(".bias") and name not in params_dict:
  87. continue
  88. param = params_dict[name]
  89. weight_loader = getattr(param, "weight_loader",
  90. default_weight_loader)
  91. weight_loader(param, loaded_weight)
  92. def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
  93. hidden_size = self.config.hidden_size
  94. head_size = self.config.hidden_size // self.config.num_attention_heads
  95. target_num_kv_heads = self.config.num_key_value_heads
  96. num_kv_heads = loaded_weight.shape[0] // head_size
  97. n_repeats = target_num_kv_heads / num_kv_heads
  98. assert n_repeats == int(n_repeats)
  99. n_repeats = int(n_repeats)
  100. loaded_weight = loaded_weight.view(num_kv_heads, head_size,
  101. hidden_size)
  102. loaded_weight = torch.repeat_interleave(loaded_weight,
  103. repeats=n_repeats,
  104. dim=0)
  105. loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
  106. hidden_size)
  107. return loaded_weight