commandr.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. # coding=utf-8
  2. # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. # This file is based on the LLama model definition file in transformers
  21. """PyTorch Cohere model."""
  22. from typing import Iterable, List, Optional, Tuple
  23. import torch
  24. import torch.utils.checkpoint
  25. from torch import nn
  26. from torch.nn.parameter import Parameter
  27. from transformers import CohereConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.sequence import SamplerOutput
  30. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  31. get_tensor_model_parallel_world_size)
  32. from aphrodite.modeling.layers.activation import SiluAndMul
  33. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  34. MergedColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear)
  37. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  41. VocabParallelEmbedding
  42. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.modeling.utils import set_weight_attrs
  45. @torch.compile
  46. def layer_norm_func(hidden_states, weight, variance_epsilon):
  47. input_dtype = hidden_states.dtype
  48. hidden_states = hidden_states.to(torch.float32)
  49. mean = hidden_states.mean(-1, keepdim=True)
  50. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  51. hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
  52. variance_epsilon)
  53. hidden_states = weight.to(torch.float32) * hidden_states
  54. return hidden_states.to(input_dtype)
  55. class LayerNorm(nn.Module):
  56. def __init__(self, param_shape=None, eps=1e-5):
  57. super().__init__()
  58. self.weight = nn.Parameter(torch.ones(param_shape))
  59. self.variance_epsilon = eps
  60. set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
  61. def forward(self, hidden_states, residuals=None):
  62. hidden_states = layer_norm_func(hidden_states, self.weight,
  63. self.variance_epsilon)
  64. return hidden_states, residuals
  65. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  66. tp_rank = get_tensor_model_parallel_rank()
  67. shard_dim = 0 if param.dim() != 1 else None
  68. param_data = param.data
  69. if shard_dim is not None:
  70. shard_size = param_data.shape[shard_dim]
  71. start_idx = tp_rank * shard_size
  72. loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
  73. shard_size)
  74. assert param_data.shape == loaded_weight.shape
  75. param_data.copy_(loaded_weight)
  76. # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
  77. class CohereMLP(nn.Module):
  78. def __init__(
  79. self,
  80. config,
  81. linear_method: Optional[LinearMethodBase] = None,
  82. ):
  83. super().__init__()
  84. self.config = config
  85. self.hidden_size = config.hidden_size
  86. self.intermediate_size = config.intermediate_size
  87. self.gate_up_proj = MergedColumnParallelLinear(
  88. self.hidden_size,
  89. [self.intermediate_size] * 2,
  90. bias=False,
  91. linear_method=linear_method,
  92. )
  93. self.down_proj = RowParallelLinear(
  94. self.intermediate_size,
  95. self.hidden_size,
  96. bias=False,
  97. linear_method=linear_method,
  98. )
  99. self.act_fn = SiluAndMul()
  100. def forward(self, x):
  101. gate_up, _ = self.gate_up_proj(x)
  102. x = self.act_fn(gate_up)
  103. x, _ = self.down_proj(x)
  104. return x
  105. class CohereAttention(nn.Module):
  106. def __init__(
  107. self,
  108. config: CohereConfig,
  109. linear_method: Optional[LinearMethodBase] = None,
  110. ):
  111. super().__init__()
  112. tp_size = get_tensor_model_parallel_world_size()
  113. self.config = config
  114. self.attention_dropout = config.attention_dropout
  115. self.hidden_size = config.hidden_size
  116. self.total_num_heads = config.num_attention_heads
  117. self.num_heads = self.total_num_heads // tp_size
  118. self.head_dim = self.hidden_size // self.total_num_heads
  119. self.total_num_kv_heads = config.num_key_value_heads
  120. if self.total_num_kv_heads >= tp_size:
  121. # Number of KV heads is greater than TP size, so we partition
  122. # the KV heads across multiple tensor parallel GPUs.
  123. assert self.total_num_kv_heads % tp_size == 0
  124. else:
  125. # Number of KV heads is less than TP size, so we replicate
  126. # the KV heads across multiple tensor parallel GPUs.
  127. assert tp_size % self.total_num_kv_heads == 0
  128. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  129. self.q_size = self.num_heads * self.head_dim
  130. self.kv_size = self.num_kv_heads * self.head_dim
  131. self.scaling = self.head_dim**-0.5
  132. self.max_position_embeddings = getattr(
  133. config, "model_max_length", None) or getattr(
  134. config, "max_position_embeddings", 8192)
  135. self.rope_theta = config.rope_theta
  136. self.rope_scaling = getattr(config, "rope_scaling", None)
  137. self.use_qk_norm = getattr(config, "use_qk_norm", False)
  138. self.qkv_proj = QKVParallelLinear(
  139. self.hidden_size,
  140. self.head_dim,
  141. self.total_num_heads,
  142. self.total_num_kv_heads,
  143. bias=False,
  144. linear_method=linear_method,
  145. )
  146. self.o_proj = RowParallelLinear(
  147. self.total_num_heads * self.head_dim,
  148. self.hidden_size,
  149. bias=False,
  150. linear_method=linear_method,
  151. )
  152. self.rotary_emb = get_rope(
  153. self.head_dim,
  154. rotary_dim=self.head_dim,
  155. max_position=self.max_position_embeddings,
  156. base=self.rope_theta,
  157. rope_scaling=self.rope_scaling,
  158. is_neox_style=False,
  159. )
  160. self.attn = Attention(
  161. self.num_heads,
  162. self.head_dim,
  163. self.scaling,
  164. num_kv_heads=self.num_kv_heads,
  165. )
  166. if self.use_qk_norm:
  167. self.q_norm = LayerNorm(param_shape=(self.num_heads,
  168. self.head_dim),
  169. eps=config.layer_norm_eps)
  170. self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
  171. self.head_dim),
  172. eps=config.layer_norm_eps)
  173. def _apply_qk_norm(self, q, k):
  174. q = q.view(*q.shape[:-1], -1, self.head_dim)
  175. k = k.view(*k.shape[:-1], -1, self.head_dim)
  176. q, _ = self.q_norm(q)
  177. k, _ = self.k_norm(k)
  178. q = q.view(*q.shape[:-2], -1)
  179. k = k.view(*k.shape[:-2], -1)
  180. return q, k
  181. def forward(
  182. self,
  183. positions: torch.Tensor,
  184. hidden_states: torch.Tensor,
  185. kv_cache: torch.Tensor,
  186. attn_metadata: AttentionMetadata,
  187. ) -> torch.Tensor:
  188. qkv, _ = self.qkv_proj(hidden_states)
  189. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  190. if self.use_qk_norm:
  191. q, k = self._apply_qk_norm(q, k)
  192. q, k = self.rotary_emb(positions, q, k)
  193. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  194. output, _ = self.o_proj(attn_output)
  195. return output
  196. class CohereDecoderLayer(nn.Module):
  197. def __init__(self,
  198. config: CohereConfig,
  199. linear_method: Optional[LinearMethodBase] = None):
  200. super().__init__()
  201. self.hidden_size = config.hidden_size
  202. self.self_attn = CohereAttention(config, linear_method=linear_method)
  203. self.mlp = CohereMLP(config, linear_method=linear_method)
  204. self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
  205. eps=config.layer_norm_eps)
  206. def forward(
  207. self,
  208. positions: torch.Tensor,
  209. hidden_states: torch.Tensor,
  210. kv_cache: torch.Tensor,
  211. attn_metadata: AttentionMetadata,
  212. residual: Optional[torch.Tensor],
  213. ) -> Tuple[torch.Tensor, torch.Tensor]:
  214. # Self Attention
  215. residual = hidden_states
  216. hidden_states, residual = self.input_layernorm(hidden_states, residual)
  217. hidden_states_attention = self.self_attn(
  218. positions=positions,
  219. hidden_states=hidden_states,
  220. kv_cache=kv_cache,
  221. attn_metadata=attn_metadata,
  222. )
  223. hidden_states_mlp = self.mlp(hidden_states)
  224. # Add everything together
  225. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  226. return hidden_states, residual
  227. class CohereModel(nn.Module):
  228. def __init__(
  229. self,
  230. config: CohereConfig,
  231. linear_method: Optional[LinearMethodBase] = None,
  232. ):
  233. super().__init__()
  234. self.config = config
  235. self.vocab_size = config.vocab_size
  236. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  237. config.hidden_size)
  238. self.layers = nn.ModuleList([
  239. CohereDecoderLayer(config, linear_method=linear_method)
  240. for _ in range(config.num_hidden_layers)
  241. ])
  242. self.norm = LayerNorm(param_shape=(config.hidden_size),
  243. eps=config.layer_norm_eps)
  244. def forward(
  245. self,
  246. input_ids: torch.Tensor,
  247. positions: torch.Tensor,
  248. kv_caches: List[torch.Tensor],
  249. attn_metadata: AttentionMetadata,
  250. ) -> torch.Tensor:
  251. hidden_states = self.embed_tokens(input_ids)
  252. residual = None
  253. for i in range(len(self.layers)):
  254. layer = self.layers[i]
  255. hidden_states, residual = layer(
  256. positions,
  257. hidden_states,
  258. kv_caches[i],
  259. attn_metadata,
  260. residual,
  261. )
  262. hidden_states, _ = self.norm(hidden_states, residual)
  263. return hidden_states
  264. class CohereForCausalLM(nn.Module):
  265. def __init__(
  266. self,
  267. config: CohereConfig,
  268. linear_method: Optional[LinearMethodBase] = None,
  269. ) -> None:
  270. super().__init__()
  271. self.config = config
  272. self.linear_method = linear_method
  273. self.logits_processor = LogitsProcessor(config.vocab_size,
  274. scale=config.logit_scale)
  275. self.model = CohereModel(config, linear_method)
  276. self.sampler = Sampler()
  277. @torch.no_grad()
  278. def forward(
  279. self,
  280. input_ids: torch.Tensor,
  281. positions: torch.Tensor,
  282. kv_caches: List[torch.Tensor],
  283. attn_metadata: AttentionMetadata,
  284. ) -> torch.Tensor:
  285. hidden_states = self.model(input_ids, positions, kv_caches,
  286. attn_metadata)
  287. return hidden_states
  288. def compute_logits(self, hidden_states: torch.Tensor,
  289. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  290. logits = self.logits_processor(self.model.embed_tokens.weight,
  291. hidden_states, sampling_metadata)
  292. return logits
  293. def sample(
  294. self,
  295. logits: torch.Tensor,
  296. sampling_metadata: SamplingMetadata,
  297. ) -> Optional[SamplerOutput]:
  298. next_tokens = self.sampler(logits, sampling_metadata)
  299. return next_tokens
  300. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  301. stacked_params_mapping = [
  302. # (param_name, shard_name, shard_id)
  303. ("qkv_proj", "q_proj", "q"),
  304. ("qkv_proj", "k_proj", "k"),
  305. ("qkv_proj", "v_proj", "v"),
  306. ("gate_up_proj", "gate_proj", 0),
  307. ("gate_up_proj", "up_proj", 1),
  308. ]
  309. params_dict = dict(self.named_parameters())
  310. loaded_params = set()
  311. for name, loaded_weight in weights:
  312. for param_name, shard_name, shard_id in stacked_params_mapping:
  313. if shard_name not in name:
  314. continue
  315. name = name.replace(shard_name, param_name)
  316. # Skip loading extra bias for GPTQ models.
  317. if name.endswith(".bias") and name not in params_dict:
  318. continue
  319. param = params_dict[name]
  320. weight_loader = param.weight_loader
  321. weight_loader(param, loaded_weight, shard_id)
  322. break
  323. else:
  324. # lm_head is not used in Aphrodite as it is tied with
  325. # embed_token. To prevent errors, skip loading lm_head.weight.
  326. if "lm_head.weight" in name:
  327. continue
  328. # Skip loading extra bias for GPTQ models.
  329. if name.endswith(".bias") and name not in params_dict:
  330. continue
  331. param = params_dict[name]
  332. weight_loader = getattr(param, "weight_loader",
  333. default_weight_loader)
  334. weight_loader(param, loaded_weight)
  335. loaded_params.add(name)