# coding=utf-8
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is based on the Llama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Iterable, List, Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CohereConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    """Bias-free layer norm, computed in float32 for numerical stability."""
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
                                                         variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)
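
# A minimal sanity-check sketch (illustrative, not part of the original
# module): layer_norm_func should agree with PyTorch's built-in bias-free
# layer norm.
#
#     x = torch.randn(2, 8)
#     w = torch.ones(8)
#     ref = torch.nn.functional.layer_norm(x, (8,), weight=w, eps=1e-5)
#     assert torch.allclose(layer_norm_func(x, w, 1e-5), ref, atol=1e-5)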

class LayerNorm(nn.Module):

    def __init__(self, param_shape=None, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.variance_epsilon = eps
        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states, residuals=None):
        hidden_states = layer_norm_func(hidden_states, self.weight,
                                        self.variance_epsilon)
        # The residual is returned unchanged so this layer matches the
        # (hidden_states, residual) interface used by the decoder layers.
        return hidden_states, residuals

# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
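
# Note on the fused MLP above (conventions assumed from SiluAndMul and
# MergedColumnParallelLinear): `gate_up` holds the gate and up projections
# concatenated, shape [num_tokens, 2 * intermediate_size // tp_size], and
# SiluAndMul splits it in half to compute the SwiGLU activation, roughly:
#
#     gate, up = gate_up.chunk(2, dim=-1)
#     x = torch.nn.functional.silu(gate) * up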

class CohereAttention(nn.Module):

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least as large as the TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = getattr(
            config, "model_max_length", None) or getattr(
                config, "max_position_embeddings", 8192)
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        if self.use_qk_norm:
            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)
            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)

    def _apply_qk_norm(self, q, k):
        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q, _ = self.q_norm(q)
        k, _ = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k
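
    # Shape walkthrough for _apply_qk_norm (illustrative): q arrives flat as
    # [num_tokens, num_heads * head_dim] and is viewed as
    # [num_tokens, num_heads, head_dim], so layer_norm_func normalizes each
    # head over its own head_dim slice (with a per-head learned scale) before
    # q/k are flattened back for the rotary embedding.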

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
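
# Note: when use_qk_norm is enabled, q/k are normalized *before* the rotary
# embedding is applied. is_neox_style=False selects the GPT-J-style
# interleaved rotary layout (rotating adjacent dimension pairs) rather than
# the GPT-NeoX half-split layout; presumably this matches the layout the
# Cohere checkpoints were trained with.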

class CohereDecoderLayer(nn.Module):

    def __init__(self,
                 config: CohereConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config,
                                         cache_config,
                                         quant_config=quant_config)
        self.mlp = CohereMLP(config, quant_config=quant_config)
        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states_attention = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states_mlp = self.mlp(hidden_states)
        # Add everything together
        hidden_states = residual + hidden_states_attention + hidden_states_mlp
        return hidden_states, residual
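
# Unlike the sequential Llama block, where each sublayer has its own norm and
# residual add (x = x + Attn(LN1(x)); x = x + MLP(LN2(x))), the Cohere decoder
# layer is a "parallel" block: a single layer norm feeds both branches and
# their outputs are summed with the residual in one step:
#
#     h = LN(x)
#     x = x + Attn(h) + MLP(h)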

class CohereModel(nn.Module):

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            CohereDecoderLayer(config, cache_config, quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = LayerNorm(param_shape=(config.hidden_size),
                              eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

class CohereForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
    ]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = []

    def __init__(
        self,
        config: CohereConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(config,
                                 cache_config,
                                 quant_config,
                                 lora_config=lora_config)
        self.sampler = Sampler()
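
    # Note (assumed from LogitsProcessor's `scale` argument): Cohere models
    # scale the output logits by config.logit_scale, so sampling sees roughly
    #
    #     logits = config.logit_scale * (hidden_states @ W_embed.T)
    #
    # where W_embed is the tied input embedding matrix used as the LM head.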

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
        if is_not_lora:
            logits = self.logits_processor(self.model.embed_tokens,
                                           hidden_states, sampling_metadata)
        else:
            logits = self.logits_processor(self.model.embed_tokens.base_layer,
                                           hidden_states, sampling_metadata)
        return logits
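
    # The hasattr check in compute_logits distinguishes a plain
    # VocabParallelEmbedding (which exposes `.weight` directly) from its LoRA
    # wrapper, where the underlying embedding sits at `.base_layer`. Either
    # way, the tied input embedding serves as the LM head (see load_weights
    # below, which skips `lm_head.weight` for the same reason).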

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used since it is tied with embed_tokens.
                # To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
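
    # Remapping example (illustrative): a checkpoint weight named
    # "model.layers.0.self_attn.q_proj.weight" is renamed to
    # "model.layers.0.self_attn.qkv_proj.weight" and loaded into the fused
    # QKV parameter at shard "q"; gate_proj/up_proj weights similarly land in
    # slots 0 and 1 of the fused gate_up_proj parameter.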