commandr.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. # coding=utf-8
  2. # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. # This file is based on the LLama model definition file in transformers
  21. """PyTorch Cohere model."""
  22. from typing import Iterable, List, Optional, Tuple
  23. import torch
  24. import torch.utils.checkpoint
  25. from torch import nn
  26. from torch.nn.parameter import Parameter
  27. from transformers import CohereConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  32. get_tensor_model_parallel_world_size)
  33. from aphrodite.modeling.layers.activation import SiluAndMul
  34. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear)
  37. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  41. VocabParallelEmbedding
  42. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.modeling.utils import set_weight_attrs
  45. from aphrodite.quantization.base_config import QuantizationConfig
  46. @torch.compile
  47. def layer_norm_func(hidden_states, weight, variance_epsilon):
  48. input_dtype = hidden_states.dtype
  49. hidden_states = hidden_states.to(torch.float32)
  50. mean = hidden_states.mean(-1, keepdim=True)
  51. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  52. hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
  53. variance_epsilon)
  54. hidden_states = weight.to(torch.float32) * hidden_states
  55. return hidden_states.to(input_dtype)
  56. class LayerNorm(nn.Module):
  57. def __init__(self, param_shape=None, eps=1e-5):
  58. super().__init__()
  59. self.weight = nn.Parameter(torch.ones(param_shape))
  60. self.variance_epsilon = eps
  61. set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
  62. def forward(self, hidden_states, residuals=None):
  63. hidden_states = layer_norm_func(hidden_states, self.weight,
  64. self.variance_epsilon)
  65. return hidden_states, residuals
  66. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  67. tp_rank = get_tensor_model_parallel_rank()
  68. shard_dim = 0 if param.dim() != 1 else None
  69. param_data = param.data
  70. if shard_dim is not None:
  71. shard_size = param_data.shape[shard_dim]
  72. start_idx = tp_rank * shard_size
  73. loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
  74. shard_size)
  75. assert param_data.shape == loaded_weight.shape
  76. param_data.copy_(loaded_weight)
  77. # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
  78. class CohereMLP(nn.Module):
  79. def __init__(
  80. self,
  81. config,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. ):
  84. super().__init__()
  85. self.config = config
  86. self.hidden_size = config.hidden_size
  87. self.intermediate_size = config.intermediate_size
  88. self.gate_up_proj = MergedColumnParallelLinear(
  89. self.hidden_size,
  90. [self.intermediate_size] * 2,
  91. bias=False,
  92. quant_config=quant_config,
  93. )
  94. self.down_proj = RowParallelLinear(
  95. self.intermediate_size,
  96. self.hidden_size,
  97. bias=False,
  98. quant_config=quant_config,
  99. )
  100. self.act_fn = SiluAndMul()
  101. def forward(self, x):
  102. gate_up, _ = self.gate_up_proj(x)
  103. x = self.act_fn(gate_up)
  104. x, _ = self.down_proj(x)
  105. return x
  106. class CohereAttention(nn.Module):
  107. def __init__(
  108. self,
  109. config: CohereConfig,
  110. cache_config: Optional[CacheConfig] = None,
  111. quant_config: Optional[QuantizationConfig] = None,
  112. ):
  113. super().__init__()
  114. tp_size = get_tensor_model_parallel_world_size()
  115. self.config = config
  116. self.attention_dropout = config.attention_dropout
  117. self.hidden_size = config.hidden_size
  118. self.total_num_heads = config.num_attention_heads
  119. self.num_heads = self.total_num_heads // tp_size
  120. self.head_dim = self.hidden_size // self.total_num_heads
  121. self.total_num_kv_heads = config.num_key_value_heads
  122. if self.total_num_kv_heads >= tp_size:
  123. # Number of KV heads is greater than TP size, so we partition
  124. # the KV heads across multiple tensor parallel GPUs.
  125. assert self.total_num_kv_heads % tp_size == 0
  126. else:
  127. # Number of KV heads is less than TP size, so we replicate
  128. # the KV heads across multiple tensor parallel GPUs.
  129. assert tp_size % self.total_num_kv_heads == 0
  130. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  131. self.q_size = self.num_heads * self.head_dim
  132. self.kv_size = self.num_kv_heads * self.head_dim
  133. self.scaling = self.head_dim**-0.5
  134. self.max_position_embeddings = getattr(
  135. config, "model_max_length", None) or getattr(
  136. config, "max_position_embeddings", 8192)
  137. self.rope_theta = config.rope_theta
  138. self.rope_scaling = getattr(config, "rope_scaling", None)
  139. self.use_qk_norm = getattr(config, "use_qk_norm", False)
  140. self.qkv_proj = QKVParallelLinear(
  141. self.hidden_size,
  142. self.head_dim,
  143. self.total_num_heads,
  144. self.total_num_kv_heads,
  145. bias=False,
  146. quant_config=quant_config,
  147. )
  148. self.o_proj = RowParallelLinear(
  149. self.total_num_heads * self.head_dim,
  150. self.hidden_size,
  151. bias=False,
  152. quant_config=quant_config,
  153. )
  154. self.rotary_emb = get_rope(
  155. self.head_dim,
  156. rotary_dim=self.head_dim,
  157. max_position=self.max_position_embeddings,
  158. base=self.rope_theta,
  159. rope_scaling=self.rope_scaling,
  160. is_neox_style=False,
  161. )
  162. self.attn = Attention(self.num_heads,
  163. self.head_dim,
  164. self.scaling,
  165. num_kv_heads=self.num_kv_heads,
  166. cache_config=cache_config,
  167. quant_config=quant_config)
  168. if self.use_qk_norm:
  169. self.q_norm = LayerNorm(param_shape=(self.num_heads,
  170. self.head_dim),
  171. eps=config.layer_norm_eps)
  172. self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
  173. self.head_dim),
  174. eps=config.layer_norm_eps)
  175. def _apply_qk_norm(self, q, k):
  176. q = q.view(*q.shape[:-1], -1, self.head_dim)
  177. k = k.view(*k.shape[:-1], -1, self.head_dim)
  178. q, _ = self.q_norm(q)
  179. k, _ = self.k_norm(k)
  180. q = q.view(*q.shape[:-2], -1)
  181. k = k.view(*k.shape[:-2], -1)
  182. return q, k
  183. def forward(
  184. self,
  185. positions: torch.Tensor,
  186. hidden_states: torch.Tensor,
  187. kv_cache: torch.Tensor,
  188. attn_metadata: AttentionMetadata,
  189. ) -> torch.Tensor:
  190. qkv, _ = self.qkv_proj(hidden_states)
  191. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  192. if self.use_qk_norm:
  193. q, k = self._apply_qk_norm(q, k)
  194. q, k = self.rotary_emb(positions, q, k)
  195. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  196. output, _ = self.o_proj(attn_output)
  197. return output
  198. class CohereDecoderLayer(nn.Module):
  199. def __init__(self,
  200. config: CohereConfig,
  201. cache_config: Optional[CacheConfig] = None,
  202. quant_config: Optional[QuantizationConfig] = None):
  203. super().__init__()
  204. self.hidden_size = config.hidden_size
  205. self.self_attn = CohereAttention(config,
  206. cache_config,
  207. quant_config=quant_config)
  208. self.mlp = CohereMLP(config, quant_config=quant_config)
  209. self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
  210. eps=config.layer_norm_eps)
  211. def forward(
  212. self,
  213. positions: torch.Tensor,
  214. hidden_states: torch.Tensor,
  215. kv_cache: torch.Tensor,
  216. attn_metadata: AttentionMetadata,
  217. residual: Optional[torch.Tensor],
  218. ) -> Tuple[torch.Tensor, torch.Tensor]:
  219. # Self Attention
  220. residual = hidden_states
  221. hidden_states, residual = self.input_layernorm(hidden_states, residual)
  222. hidden_states_attention = self.self_attn(
  223. positions=positions,
  224. hidden_states=hidden_states,
  225. kv_cache=kv_cache,
  226. attn_metadata=attn_metadata,
  227. )
  228. hidden_states_mlp = self.mlp(hidden_states)
  229. # Add everything together
  230. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  231. return hidden_states, residual
  232. class CohereModel(nn.Module):
  233. def __init__(
  234. self,
  235. config: CohereConfig,
  236. cache_config: Optional[CacheConfig] = None,
  237. quant_config: Optional[QuantizationConfig] = None,
  238. lora_config: Optional[LoRAConfig] = None,
  239. ):
  240. super().__init__()
  241. self.config = config
  242. lora_vocab = (lora_config.lora_extra_vocab_size *
  243. (lora_config.max_loras or 1)) if lora_config else 0
  244. self.vocab_size = config.vocab_size + lora_vocab
  245. self.org_vocab_size = config.vocab_size
  246. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  247. config.hidden_size)
  248. self.layers = nn.ModuleList([
  249. CohereDecoderLayer(config, cache_config, quant_config=quant_config)
  250. for _ in range(config.num_hidden_layers)
  251. ])
  252. self.norm = LayerNorm(param_shape=(config.hidden_size),
  253. eps=config.layer_norm_eps)
  254. def forward(
  255. self,
  256. input_ids: torch.Tensor,
  257. positions: torch.Tensor,
  258. kv_caches: List[torch.Tensor],
  259. attn_metadata: AttentionMetadata,
  260. ) -> torch.Tensor:
  261. hidden_states = self.embed_tokens(input_ids)
  262. residual = None
  263. for i in range(len(self.layers)):
  264. layer = self.layers[i]
  265. hidden_states, residual = layer(
  266. positions,
  267. hidden_states,
  268. kv_caches[i],
  269. attn_metadata,
  270. residual,
  271. )
  272. hidden_states, _ = self.norm(hidden_states, residual)
  273. return hidden_states
  274. class CohereForCausalLM(nn.Module):
  275. packed_modules_mapping = {
  276. "qkv_proj": [
  277. "q_proj",
  278. "k_proj",
  279. "v_proj",
  280. ],
  281. "gate_up_proj": [
  282. "gate_proj",
  283. "up_proj",
  284. ],
  285. }
  286. # LoRA specific attributes
  287. supported_lora_modules = [
  288. "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
  289. ]
  290. embedding_modules = {"embed_tokens": "input_embeddings"}
  291. embedding_padding_modules = []
  292. def __init__(
  293. self,
  294. config: CohereConfig,
  295. cache_config: Optional[CacheConfig] = None,
  296. quant_config: Optional[QuantizationConfig] = None,
  297. lora_config: Optional[LoRAConfig] = None,
  298. ) -> None:
  299. super().__init__()
  300. self.config = config
  301. self.unpadded_vocab_size = config.vocab_size
  302. if lora_config:
  303. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  304. self.quant_config = quant_config
  305. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  306. config.vocab_size,
  307. scale=config.logit_scale)
  308. self.model = CohereModel(config,
  309. cache_config,
  310. quant_config,
  311. lora_config=lora_config)
  312. self.sampler = Sampler()
  313. @torch.no_grad()
  314. def forward(
  315. self,
  316. input_ids: torch.Tensor,
  317. positions: torch.Tensor,
  318. kv_caches: List[torch.Tensor],
  319. attn_metadata: AttentionMetadata,
  320. intermediate_tensors: Optional[IntermediateTensors] = None,
  321. ) -> torch.Tensor:
  322. hidden_states = self.model(input_ids, positions, kv_caches,
  323. attn_metadata)
  324. return hidden_states
  325. def compute_logits(self, hidden_states: torch.Tensor,
  326. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  327. is_not_lora = hasattr(self.model.embed_tokens, 'weight')
  328. if is_not_lora:
  329. logits = self.logits_processor(self.model.embed_tokens,
  330. hidden_states, sampling_metadata)
  331. else:
  332. logits = self.logits_processor(self.model.embed_tokens.base_layer,
  333. hidden_states, sampling_metadata)
  334. return logits
  335. def sample(
  336. self,
  337. logits: torch.Tensor,
  338. sampling_metadata: SamplingMetadata,
  339. ) -> Optional[SamplerOutput]:
  340. next_tokens = self.sampler(logits, sampling_metadata)
  341. return next_tokens
  342. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  343. stacked_params_mapping = [
  344. # (param_name, shard_name, shard_id)
  345. ("qkv_proj", "q_proj", "q"),
  346. ("qkv_proj", "k_proj", "k"),
  347. ("qkv_proj", "v_proj", "v"),
  348. ("gate_up_proj", "gate_proj", 0),
  349. ("gate_up_proj", "up_proj", 1),
  350. ]
  351. params_dict = dict(self.named_parameters())
  352. loaded_params = set()
  353. for name, loaded_weight in weights:
  354. for param_name, shard_name, shard_id in stacked_params_mapping:
  355. if shard_name not in name:
  356. continue
  357. name = name.replace(shard_name, param_name)
  358. # Skip loading extra bias for GPTQ models.
  359. if name.endswith(".bias") and name not in params_dict:
  360. continue
  361. param = params_dict[name]
  362. weight_loader = param.weight_loader
  363. weight_loader(param, loaded_weight, shard_id)
  364. break
  365. else:
  366. # lm_head is not used in vllm as it is tied with embed_token.
  367. # To prevent errors, skip loading lm_head.weight.
  368. if "lm_head.weight" in name:
  369. continue
  370. # Skip loading extra bias for GPTQ models.
  371. if name.endswith(".bias") and name not in params_dict:
  372. continue
  373. param = params_dict[name]
  374. weight_loader = getattr(param, "weight_loader",
  375. default_weight_loader)
  376. weight_loader(param, loaded_weight)
  377. loaded_params.add(name)