persimmon.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # coding=utf-8
  2. # adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
  3. # Copyright 2023 The PygmalionAI team.
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only persimmon model compatible with HuggingFace weights."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import PersimmonConfig
  28. from transformers.activations import ReLUSquaredActivation
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig
  31. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  32. from aphrodite.common.utils import progress_bar
  33. from aphrodite.distributed import get_tensor_model_parallel_world_size
  34. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear)
  37. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  41. ParallelLMHead, VocabParallelEmbedding)
  42. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.quantization.base_config import QuantizationConfig
  45. class PersimmonMLP(nn.Module):
  46. def __init__(self,
  47. config: PersimmonConfig,
  48. quant_config: Optional[QuantizationConfig] = None):
  49. super().__init__()
  50. self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
  51. config.intermediate_size,
  52. quant_config=quant_config)
  53. self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
  54. config.hidden_size,
  55. quant_config=quant_config)
  56. self.act = ReLUSquaredActivation()
  57. def forward(self, hidden_states) -> torch.Tensor:
  58. hidden_states, _ = self.dense_h_to_4h(hidden_states)
  59. hidden_states = self.act(hidden_states)
  60. hidden_states, _ = self.dense_4h_to_h(hidden_states)
  61. return hidden_states
  62. class PersimmonAttention(nn.Module):
  63. def __init__(self,
  64. config: PersimmonConfig,
  65. cache_config: Optional[CacheConfig] = None,
  66. quant_config: Optional[QuantizationConfig] = None):
  67. super().__init__()
  68. self.config = config
  69. tensor_parallel_world_size = get_tensor_model_parallel_world_size()
  70. self.hidden_size = config.hidden_size
  71. self.total_num_heads = config.num_attention_heads
  72. self.num_heads = self.total_num_heads // tensor_parallel_world_size
  73. self.head_dim = self.hidden_size // self.total_num_heads
  74. self.max_position_embeddings = config.max_position_embeddings
  75. self.rope_theta = config.rope_theta
  76. self.partial_rotary_factor = config.partial_rotary_factor
  77. self.is_causal = True
  78. assert (self.head_dim * self.total_num_heads) == self.hidden_size
  79. assert self.total_num_heads % tensor_parallel_world_size == 0
  80. self.query_key_value = QKVParallelLinear(
  81. self.hidden_size,
  82. self.head_dim,
  83. self.total_num_heads,
  84. bias=True,
  85. quant_config=quant_config,
  86. )
  87. self.dense = RowParallelLinear(
  88. self.num_heads * self.head_dim,
  89. self.hidden_size,
  90. bias=True,
  91. quant_config=quant_config,
  92. )
  93. self.is_qk_layernorm = config.qk_layernorm
  94. if self.is_qk_layernorm:
  95. self.q_layernorm = nn.LayerNorm(self.head_dim)
  96. self.k_layernorm = nn.LayerNorm(self.head_dim)
  97. self.rotary_emb = get_rope(
  98. self.head_dim,
  99. rotary_dim=int(self.partial_rotary_factor * self.head_dim),
  100. max_position=self.max_position_embeddings,
  101. base=self.rope_theta,
  102. )
  103. self.scaling = self.head_dim**-0.5
  104. self.attn = Attention(self.num_heads,
  105. self.head_dim,
  106. scale=self.scaling,
  107. cache_config=cache_config,
  108. quant_config=quant_config)
  109. def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
  110. # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
  111. seq_length = x.shape[0]
  112. return x.view(seq_length, self.num_heads, self.head_dim)
  113. def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
  114. # [seq_length, num_heads, head_dim] -> [seq_length, hidden_size]
  115. seq_length = x.shape[0]
  116. return x.view(seq_length, self.num_heads * self.head_dim)
  117. def forward(
  118. self,
  119. position_ids: torch.Tensor,
  120. hidden_states: torch.Tensor,
  121. kv_cache: torch.Tensor,
  122. attn_metadata: AttentionMetadata,
  123. ) -> torch.Tensor:
  124. # [seq_length, 3 x hidden_size]
  125. qkv, _ = self.query_key_value(hidden_states)
  126. q, k, v = qkv.chunk(chunks=3, dim=-1)
  127. if self.is_qk_layernorm:
  128. # [seq_length, num_heads, head_dim]
  129. q = self._split_heads(q)
  130. k = self._split_heads(k)
  131. q = self.q_layernorm(q)
  132. k = self.k_layernorm(k)
  133. q = self._merge_heads(q)
  134. k = self._merge_heads(k)
  135. q, k = self.rotary_emb(position_ids, q, k)
  136. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  137. output, _ = self.dense(attn_output)
  138. return output
  139. class PersimmonDecoderLayer(nn.Module):
  140. def __init__(self,
  141. config: PersimmonConfig,
  142. cache_config: Optional[CacheConfig] = None,
  143. quant_config: Optional[QuantizationConfig] = None):
  144. super().__init__()
  145. self.hidden_size = config.hidden_size
  146. self.self_attn = PersimmonAttention(config=config,
  147. cache_config=cache_config,
  148. quant_config=quant_config)
  149. self.mlp = PersimmonMLP(config, quant_config=quant_config)
  150. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  151. eps=config.layer_norm_eps)
  152. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  153. eps=config.layer_norm_eps)
  154. def forward(
  155. self,
  156. position_ids: torch.Tensor,
  157. hidden_states: torch.Tensor,
  158. kv_cache: torch.Tensor,
  159. attn_metadata: AttentionMetadata,
  160. ) -> torch.Tensor:
  161. residual = hidden_states
  162. hidden_states = self.input_layernorm(hidden_states)
  163. # Self Attention
  164. hidden_states = self.self_attn(
  165. position_ids=position_ids,
  166. hidden_states=hidden_states,
  167. kv_cache=kv_cache,
  168. attn_metadata=attn_metadata,
  169. )
  170. hidden_states = residual + hidden_states
  171. # Fully Connected
  172. residual = hidden_states
  173. hidden_states = self.post_attention_layernorm(hidden_states)
  174. hidden_states = self.mlp(hidden_states)
  175. hidden_states = hidden_states + residual
  176. outputs = hidden_states
  177. return outputs
  178. class PersimmonModel(nn.Module):
  179. def __init__(self,
  180. config: PersimmonConfig,
  181. cache_config: Optional[CacheConfig] = None,
  182. quant_config: Optional[QuantizationConfig] = None):
  183. super().__init__()
  184. self.vocab_size = config.vocab_size
  185. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  186. config.hidden_size)
  187. self.layers = nn.ModuleList([
  188. PersimmonDecoderLayer(config,
  189. cache_config=cache_config,
  190. quant_config=quant_config)
  191. for _ in range(config.num_hidden_layers)
  192. ])
  193. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  194. eps=config.layer_norm_eps)
  195. def forward(
  196. self,
  197. input_ids: torch.Tensor,
  198. positions: torch.Tensor,
  199. kv_caches: List[torch.Tensor],
  200. attn_metadata: AttentionMetadata,
  201. inputs_embeds: Optional[torch.Tensor] = None,
  202. ) -> torch.Tensor:
  203. if inputs_embeds is not None:
  204. hidden_states = inputs_embeds
  205. else:
  206. hidden_states = self.embed_tokens(input_ids)
  207. for i in range(len(self.layers)):
  208. hidden_states = self.layers[i](
  209. positions,
  210. hidden_states,
  211. kv_caches[i],
  212. attn_metadata,
  213. )
  214. hidden_states = self.final_layernorm(hidden_states)
  215. return hidden_states
  216. class PersimmonForCausalLM(nn.Module):
  217. def __init__(self,
  218. config,
  219. cache_config: Optional[CacheConfig] = None,
  220. quant_config: Optional[QuantizationConfig] = None):
  221. super().__init__()
  222. self.config = config
  223. self.vocab_size = config.vocab_size
  224. self.model = PersimmonModel(config,
  225. cache_config=cache_config,
  226. quant_config=quant_config)
  227. self.lm_head = ParallelLMHead(config.vocab_size,
  228. config.hidden_size,
  229. bias=False)
  230. self.logits_processor = LogitsProcessor(config.vocab_size)
  231. self.sampler = Sampler()
  232. def forward(
  233. self,
  234. input_ids: torch.Tensor,
  235. positions: torch.Tensor,
  236. kv_caches: List[torch.Tensor],
  237. attn_metadata: AttentionMetadata,
  238. intermediate_tensors: Optional[IntermediateTensors] = None,
  239. inputs_embeds: Optional[torch.Tensor] = None,
  240. ):
  241. hidden_states = self.model(
  242. input_ids=input_ids,
  243. positions=positions,
  244. kv_caches=kv_caches,
  245. attn_metadata=attn_metadata,
  246. inputs_embeds=inputs_embeds,
  247. )
  248. return hidden_states
  249. def compute_logits(
  250. self,
  251. hidden_states: torch.Tensor,
  252. sampling_metadata: SamplingMetadata,
  253. ) -> Optional[torch.Tensor]:
  254. logits = self.logits_processor(self.lm_head, hidden_states,
  255. sampling_metadata)
  256. return logits
  257. def sample(
  258. self,
  259. logits: torch.Tensor,
  260. sampling_metadata: SamplingMetadata,
  261. ) -> Optional[SamplerOutput]:
  262. next_tokens = self.sampler(logits, sampling_metadata)
  263. return next_tokens
  264. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  265. params_dict = dict(self.named_parameters(remove_duplicate=False))
  266. weights_list = list(weights)
  267. for name, loaded_weight in progress_bar(weights_list,
  268. desc="Loading modules..."):
  269. if "rotary_emb.inv_freq" in name:
  270. continue
  271. if ("rotary_emb.cos_cached" in name
  272. or "rotary_emb.sin_cached" in name):
  273. # Models trained using ColossalAI may include these tensors in
  274. # the checkpoint. Skip them.
  275. continue
  276. param = params_dict[name]
  277. if "query_key_value" in name:
  278. # copy from vllm/model_executor/models/bloom.py
  279. # NOTE: Persimmon's fused QKV's output_dim has the shape of
  280. # (num_heads * 3 * head_size), while the
  281. # required shape is (3 * num_heads * head_size).
  282. # Thus, we need weight conversion.
  283. output_dim = getattr(param, "output_dim", None)
  284. num_heads = self.config.num_attention_heads
  285. if output_dim is not None:
  286. loaded_weight_shape = loaded_weight.shape
  287. loaded_weight = loaded_weight.view(
  288. loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
  289. loaded_weight_shape[output_dim + 1:])
  290. loaded_weight = loaded_weight.transpose(
  291. output_dim, output_dim + 1)
  292. loaded_weight = loaded_weight.reshape(loaded_weight_shape)
  293. weight_loader = getattr(param, "weight_loader",
  294. default_weight_loader)
  295. weight_loader(param, loaded_weight)