persimmon.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # coding=utf-8
  2. # adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
  3. # Copyright 2023 The PygmalionAI team.
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only persimmon model compatible with HuggingFace weights."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import PersimmonConfig
  28. from transformers.activations import ReLUSquaredActivation
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig
  31. from aphrodite.common.sequence import IntermediateTensors
  32. from aphrodite.distributed import get_tensor_model_parallel_world_size
  33. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  34. QKVParallelLinear,
  35. RowParallelLinear)
  36. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. ParallelLMHead, VocabParallelEmbedding)
  41. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  42. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  43. from aphrodite.quantization.base_config import QuantizationConfig
  44. class PersimmonMLP(nn.Module):
  45. def __init__(self,
  46. config: PersimmonConfig,
  47. quant_config: Optional[QuantizationConfig] = None):
  48. super().__init__()
  49. self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
  50. config.intermediate_size,
  51. quant_config=quant_config)
  52. self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
  53. config.hidden_size,
  54. quant_config=quant_config)
  55. self.act = ReLUSquaredActivation()
  56. def forward(self, hidden_states) -> torch.Tensor:
  57. hidden_states, _ = self.dense_h_to_4h(hidden_states)
  58. hidden_states = self.act(hidden_states)
  59. hidden_states, _ = self.dense_4h_to_h(hidden_states)
  60. return hidden_states
  61. class PersimmonAttention(nn.Module):
  62. def __init__(self,
  63. config: PersimmonConfig,
  64. cache_config: Optional[CacheConfig] = None,
  65. quant_config: Optional[QuantizationConfig] = None):
  66. super().__init__()
  67. self.config = config
  68. tensor_parallel_world_size = get_tensor_model_parallel_world_size()
  69. self.hidden_size = config.hidden_size
  70. self.total_num_heads = config.num_attention_heads
  71. self.num_heads = self.total_num_heads // tensor_parallel_world_size
  72. self.head_dim = self.hidden_size // self.total_num_heads
  73. self.max_position_embeddings = config.max_position_embeddings
  74. self.rope_theta = config.rope_theta
  75. self.partial_rotary_factor = config.partial_rotary_factor
  76. self.is_causal = True
  77. assert (self.head_dim * self.total_num_heads) == self.hidden_size
  78. assert self.total_num_heads % tensor_parallel_world_size == 0
  79. self.query_key_value = QKVParallelLinear(
  80. self.hidden_size,
  81. self.head_dim,
  82. self.total_num_heads,
  83. bias=True,
  84. quant_config=quant_config,
  85. )
  86. self.dense = RowParallelLinear(
  87. self.num_heads * self.head_dim,
  88. self.hidden_size,
  89. bias=True,
  90. quant_config=quant_config,
  91. )
  92. self.is_qk_layernorm = config.qk_layernorm
  93. if self.is_qk_layernorm:
  94. self.q_layernorm = nn.LayerNorm(self.head_dim)
  95. self.k_layernorm = nn.LayerNorm(self.head_dim)
  96. self.rotary_emb = get_rope(
  97. self.head_dim,
  98. rotary_dim=int(self.partial_rotary_factor * self.head_dim),
  99. max_position=self.max_position_embeddings,
  100. base=self.rope_theta,
  101. )
  102. self.scaling = self.head_dim**-0.5
  103. self.attn = Attention(self.num_heads,
  104. self.head_dim,
  105. scale=self.scaling,
  106. cache_config=cache_config,
  107. quant_config=quant_config)
  108. def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
  109. # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
  110. seq_length = x.shape[0]
  111. return x.view(seq_length, self.num_heads, self.head_dim)
  112. def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
  113. # [seq_length, num_heads, head_dim] -> [seq_length, hidden_size]
  114. seq_length = x.shape[0]
  115. return x.view(seq_length, self.num_heads * self.head_dim)
  116. def forward(
  117. self,
  118. position_ids: torch.Tensor,
  119. hidden_states: torch.Tensor,
  120. kv_cache: torch.Tensor,
  121. attn_metadata: AttentionMetadata,
  122. ) -> torch.Tensor:
  123. # [seq_length, 3 x hidden_size]
  124. qkv, _ = self.query_key_value(hidden_states)
  125. q, k, v = qkv.chunk(chunks=3, dim=-1)
  126. if self.is_qk_layernorm:
  127. # [seq_length, num_heads, head_dim]
  128. q = self._split_heads(q)
  129. k = self._split_heads(k)
  130. q = self.q_layernorm(q)
  131. k = self.k_layernorm(k)
  132. q = self._merge_heads(q)
  133. k = self._merge_heads(k)
  134. q, k = self.rotary_emb(position_ids, q, k)
  135. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  136. output, _ = self.dense(attn_output)
  137. return output
  138. class PersimmonDecoderLayer(nn.Module):
  139. def __init__(self,
  140. config: PersimmonConfig,
  141. cache_config: Optional[CacheConfig] = None,
  142. quant_config: Optional[QuantizationConfig] = None):
  143. super().__init__()
  144. self.hidden_size = config.hidden_size
  145. self.self_attn = PersimmonAttention(config=config,
  146. cache_config=cache_config,
  147. quant_config=quant_config)
  148. self.mlp = PersimmonMLP(config, quant_config=quant_config)
  149. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  150. eps=config.layer_norm_eps)
  151. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  152. eps=config.layer_norm_eps)
  153. def forward(
  154. self,
  155. position_ids: torch.Tensor,
  156. hidden_states: torch.Tensor,
  157. kv_cache: torch.Tensor,
  158. attn_metadata: AttentionMetadata,
  159. ) -> torch.Tensor:
  160. residual = hidden_states
  161. hidden_states = self.input_layernorm(hidden_states)
  162. # Self Attention
  163. hidden_states = self.self_attn(
  164. position_ids=position_ids,
  165. hidden_states=hidden_states,
  166. kv_cache=kv_cache,
  167. attn_metadata=attn_metadata,
  168. )
  169. hidden_states = residual + hidden_states
  170. # Fully Connected
  171. residual = hidden_states
  172. hidden_states = self.post_attention_layernorm(hidden_states)
  173. hidden_states = self.mlp(hidden_states)
  174. hidden_states = hidden_states + residual
  175. outputs = hidden_states
  176. return outputs
  177. class PersimmonModel(nn.Module):
  178. def __init__(self,
  179. config: PersimmonConfig,
  180. cache_config: Optional[CacheConfig] = None,
  181. quant_config: Optional[QuantizationConfig] = None):
  182. super().__init__()
  183. self.vocab_size = config.text_config.vocab_size
  184. self.embed_tokens = VocabParallelEmbedding(
  185. config.text_config.vocab_size, config.hidden_size)
  186. self.layers = nn.ModuleList([
  187. PersimmonDecoderLayer(config,
  188. cache_config=cache_config,
  189. quant_config=quant_config)
  190. for _ in range(config.num_hidden_layers)
  191. ])
  192. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  193. eps=config.layer_norm_eps)
  194. def forward(
  195. self,
  196. input_ids: torch.Tensor,
  197. positions: torch.Tensor,
  198. kv_caches: List[torch.Tensor],
  199. attn_metadata: AttentionMetadata,
  200. inputs_embeds: Optional[torch.Tensor] = None,
  201. ) -> torch.Tensor:
  202. if inputs_embeds is not None:
  203. hidden_states = inputs_embeds
  204. else:
  205. hidden_states = self.embed_tokens(input_ids)
  206. for i in range(len(self.layers)):
  207. hidden_states = self.layers[i](
  208. positions,
  209. hidden_states,
  210. kv_caches[i],
  211. attn_metadata,
  212. )
  213. hidden_states = self.final_layernorm(hidden_states)
  214. return hidden_states
  215. class PersimmonForCausalLM(nn.Module):
  216. def __init__(self,
  217. config,
  218. cache_config: Optional[CacheConfig] = None,
  219. quant_config: Optional[QuantizationConfig] = None):
  220. super().__init__()
  221. self.config = config
  222. self.vocab_size = config.text_config.vocab_size
  223. self.model = PersimmonModel(config,
  224. cache_config=cache_config,
  225. quant_config=quant_config)
  226. self.lm_head = ParallelLMHead(config.text_config.vocab_size,
  227. config.hidden_size,
  228. bias=False)
  229. self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
  230. self.sampler = Sampler()
  231. def forward(
  232. self,
  233. input_ids: torch.Tensor,
  234. positions: torch.Tensor,
  235. kv_caches: List[torch.Tensor],
  236. attn_metadata: AttentionMetadata,
  237. intermediate_tensors: Optional[IntermediateTensors] = None,
  238. inputs_embeds: Optional[torch.Tensor] = None,
  239. ):
  240. hidden_states = self.model(
  241. input_ids=input_ids,
  242. positions=positions,
  243. kv_caches=kv_caches,
  244. attn_metadata=attn_metadata,
  245. inputs_embeds=inputs_embeds,
  246. )
  247. return hidden_states
  248. def compute_logits(
  249. self,
  250. hidden_states: torch.Tensor,
  251. sampling_metadata: SamplingMetadata,
  252. ) -> Optional[torch.Tensor]:
  253. logits = self.logits_processor(self.lm_head, hidden_states,
  254. sampling_metadata)
  255. return logits
  256. def sample(
  257. self,
  258. logits: torch.Tensor,
  259. sampling_metadata: SamplingMetadata,
  260. ) -> Optional[SamplerOutput]:
  261. next_tokens = self.sampler(logits, sampling_metadata)
  262. return next_tokens
  263. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  264. params_dict = dict(self.named_parameters(remove_duplicate=False))
  265. for name, loaded_weight in weights:
  266. if "rotary_emb.inv_freq" in name:
  267. continue
  268. if ("rotary_emb.cos_cached" in name
  269. or "rotary_emb.sin_cached" in name):
  270. # Models trained using ColossalAI may include these tensors in
  271. # the checkpoint. Skip them.
  272. continue
  273. param = params_dict[name]
  274. if "query_key_value" in name:
  275. # copy from vllm/model_executor/models/bloom.py
  276. # NOTE: Persimmon's fused QKV's output_dim has the shape of
  277. # (num_heads * 3 * head_size), while the
  278. # required shape is (3 * num_heads * head_size).
  279. # Thus, we need weight conversion.
  280. output_dim = getattr(param, "output_dim", None)
  281. num_heads = self.config.num_attention_heads
  282. if output_dim is not None:
  283. loaded_weight_shape = loaded_weight.shape
  284. loaded_weight = loaded_weight.view(
  285. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  286. loaded_weight_shape[output_dim + 1:])
  287. loaded_weight = loaded_weight.transpose(
  288. output_dim, output_dim + 1)
  289. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  290. weight_loader = getattr(param, "weight_loader",
  291. default_weight_loader)
  292. weight_loader(param, loaded_weight)