cohere.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # coding=utf-8
  2. # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. # This file is based on the LLama model definition file in transformers
  21. """PyTorch Cohere model."""
  22. from typing import List, Optional, Tuple
  23. import torch
  24. import torch.utils.checkpoint
  25. from torch import nn
  26. from torch.nn.parameter import Parameter
  27. from transformers import CohereConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.modeling.layers.activation import SiluAndMul
  30. from aphrodite.modeling.layers.linear import (
  31. ColumnParallelLinear,
  32. LinearMethodBase,
  33. MergedColumnParallelLinear,
  34. QKVParallelLinear,
  35. RowParallelLinear,
  36. )
  37. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  41. ParallelLMHead,
  42. ParallelTWEHead,
  43. VocabParallelEmbedding,
  44. )
  45. from aphrodite.distributed import (
  46. get_tensor_model_parallel_rank,
  47. get_tensor_model_parallel_world_size,
  48. )
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. from aphrodite.modeling.utils import set_weight_attrs
  51. from aphrodite.modeling.hf_downloader import (
  52. default_weight_loader,
  53. hf_model_weights_iterator,
  54. )
  55. from aphrodite.common.sequence import SamplerOutput
  56. @torch.compile
  57. def layer_norm_func(hidden_states, weight, variance_epsilon):
  58. input_dtype = hidden_states.dtype
  59. hidden_states = hidden_states.to(torch.float32)
  60. mean = hidden_states.mean(-1, keepdim=True)
  61. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  62. hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
  63. variance_epsilon)
  64. hidden_states = weight.to(torch.float32) * hidden_states
  65. return hidden_states.to(input_dtype)
  66. class LayerNorm(nn.Module):
  67. def __init__(self, hidden_size, eps=1e-5, bias=False):
  68. super().__init__()
  69. self.weight = nn.Parameter(torch.ones(hidden_size))
  70. self.variance_epsilon = eps
  71. set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
  72. def forward(self, hidden_states, residuals=None):
  73. hidden_states = layer_norm_func(hidden_states, self.weight,
  74. self.variance_epsilon)
  75. return hidden_states, residuals
  76. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  77. tp_rank = get_tensor_model_parallel_rank()
  78. shard_dim = 0 if param.dim() != 1 else None
  79. param_data = param.data
  80. if shard_dim is not None:
  81. shard_size = param_data.shape[shard_dim]
  82. start_idx = tp_rank * shard_size
  83. loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
  84. shard_size)
  85. assert param_data.shape == loaded_weight.shape
  86. param_data.copy_(loaded_weight)
  87. # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
  88. class CohereMLP(nn.Module):
  89. def __init__(
  90. self,
  91. config,
  92. linear_method: Optional[LinearMethodBase] = None,
  93. ):
  94. super().__init__()
  95. self.config = config
  96. self.hidden_size = config.hidden_size
  97. self.intermediate_size = config.intermediate_size
  98. if (linear_method is not None
  99. and not linear_method.quant_config.merge_weight()):
  100. self.merge_weight = False
  101. self.gate_proj = ColumnParallelLinear(
  102. self.hidden_size,
  103. self.intermediate_size,
  104. bias=False,
  105. linear_method=linear_method,
  106. )
  107. self.up_proj = ColumnParallelLinear(
  108. self.hidden_size,
  109. self.intermediate_size,
  110. bias=False,
  111. linear_method=linear_method,
  112. )
  113. else:
  114. self.merge_weight = True
  115. self.gate_up_proj = MergedColumnParallelLinear(
  116. self.hidden_size,
  117. [self.intermediate_size] * 2,
  118. bias=False,
  119. linear_method=linear_method,
  120. )
  121. self.down_proj = RowParallelLinear(
  122. self.intermediate_size,
  123. self.hidden_size,
  124. bias=False,
  125. linear_method=linear_method,
  126. )
  127. self.act_fn = SiluAndMul()
  128. def forward(self, x):
  129. if self.merge_weight:
  130. gate_up, _ = self.gate_up_proj(x)
  131. else:
  132. up, _ = self.up_proj(x)
  133. gate, _ = self.gate_proj(x)
  134. gate_up = torch.cat([gate, up], dim=-1)
  135. x = self.act_fn(gate_up)
  136. x, _ = self.down_proj(x)
  137. return x
  138. class CohereAttention(nn.Module):
  139. def __init__(
  140. self,
  141. config: CohereConfig,
  142. linear_method: Optional[LinearMethodBase] = None,
  143. ):
  144. super().__init__()
  145. tp_size = get_tensor_model_parallel_world_size()
  146. self.config = config
  147. self.attention_dropout = config.attention_dropout
  148. self.hidden_size = config.hidden_size
  149. self.total_num_heads = config.num_attention_heads
  150. self.num_heads = self.total_num_heads // tp_size
  151. self.head_dim = self.hidden_size // self.total_num_heads
  152. self.total_num_kv_heads = config.num_key_value_heads
  153. if self.total_num_kv_heads >= tp_size:
  154. # Number of KV heads is greater than TP size, so we partition
  155. # the KV heads across multiple tensor parallel GPUs.
  156. assert self.total_num_kv_heads % tp_size == 0
  157. else:
  158. # Number of KV heads is less than TP size, so we replicate
  159. # the KV heads across multiple tensor parallel GPUs.
  160. assert tp_size % self.total_num_kv_heads == 0
  161. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  162. self.q_size = self.num_heads * self.head_dim
  163. self.kv_size = self.num_kv_heads * self.head_dim
  164. self.scaling = self.head_dim**-0.5
  165. self.max_position_embeddings = getattr(
  166. config, "model_max_length", None) or getattr(
  167. config, "max_position_embeddings", 8192)
  168. self.rope_theta = config.rope_theta
  169. self.rope_scaling = getattr(config, "rope_scaling", None)
  170. self.use_qk_norm = getattr(config, "use_qk_norm", False)
  171. if (linear_method is not None
  172. and not linear_method.quant_config.merge_weight()):
  173. self.merge_weight = False
  174. self.q_proj = ColumnParallelLinear(
  175. self.hidden_size,
  176. self.total_num_heads * self.head_dim,
  177. bias=False,
  178. linear_method=linear_method,
  179. )
  180. self.k_proj = ColumnParallelLinear(
  181. self.hidden_size,
  182. self.total_num_kv_heads * self.head_dim,
  183. bias=False,
  184. linear_method=linear_method,
  185. )
  186. self.v_proj = ColumnParallelLinear(
  187. self.hidden_size,
  188. self.total_num_kv_heads * self.head_dim,
  189. bias=False,
  190. linear_method=linear_method,
  191. )
  192. else:
  193. self.merge_weight = True
  194. self.qkv_proj = QKVParallelLinear(
  195. self.hidden_size,
  196. self.head_dim,
  197. self.total_num_heads,
  198. self.total_num_kv_heads,
  199. bias=False,
  200. linear_method=linear_method,
  201. )
  202. self.o_proj = RowParallelLinear(
  203. self.total_num_heads * self.head_dim,
  204. self.hidden_size,
  205. bias=False,
  206. linear_method=linear_method,
  207. )
  208. self.rotary_emb = get_rope(
  209. self.head_dim,
  210. rotary_dim=self.head_dim,
  211. max_position=self.max_position_embeddings,
  212. base=self.rope_theta,
  213. rope_scaling=self.rope_scaling,
  214. is_neox_style=False,
  215. )
  216. self.attn = Attention(
  217. self.num_heads,
  218. self.head_dim,
  219. self.scaling,
  220. num_kv_heads=self.num_kv_heads,
  221. )
  222. if self.use_qk_norm:
  223. self.q_norm = LayerNorm(hidden_size=(self.num_heads,
  224. self.head_dim),
  225. eps=config.layer_norm_eps)
  226. self.k_norm = LayerNorm(hidden_size=(self.num_kv_heads,
  227. self.head_dim),
  228. eps=config.layer_norm_eps)
  229. def _apply_qk_norm(self, q, k):
  230. q = q.view(*q.shape[:-1], -1, self.head_dim)
  231. k = k.view(*k.shape[:-1], -1, self.head_dim)
  232. q, _ = self.q_norm(q)
  233. k, _ = self.k_norm(k)
  234. q = q.view(*q.shape[:-2], -1)
  235. k = k.view(*k.shape[:-2], -1)
  236. return q, k
  237. def forward(
  238. self,
  239. positions: torch.Tensor,
  240. hidden_states: torch.Tensor,
  241. kv_cache: torch.Tensor,
  242. attn_metadata: AttentionMetadata,
  243. ) -> torch.Tensor:
  244. if self.merge_weight:
  245. qkv, _ = self.qkv_proj(hidden_states)
  246. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  247. dim=-1)
  248. else:
  249. q, _ = self.q_proj(hidden_states)
  250. k, _ = self.k_proj(hidden_states)
  251. v, _ = self.v_proj(hidden_states)
  252. if self.use_qk_norm:
  253. q, k = self._apply_qk_norm(q, k)
  254. q, k = self.rotary_emb(positions, q, k)
  255. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  256. output, _ = self.o_proj(attn_output)
  257. return output
  258. class CohereDecoderLayer(nn.Module):
  259. def __init__(
  260. self,
  261. config: CohereConfig,
  262. linear_method: Optional[LinearMethodBase] = None,
  263. ):
  264. super().__init__()
  265. self.hidden_size = config.hidden_size
  266. self.self_attn = CohereAttention(config, linear_method=linear_method)
  267. self.mlp = CohereMLP(config, linear_method=linear_method)
  268. self.input_layernorm = LayerNorm(config.hidden_size,
  269. eps=config.layer_norm_eps)
  270. def forward(
  271. self,
  272. positions: torch.Tensor,
  273. hidden_states: torch.Tensor,
  274. kv_cache: torch.Tensor,
  275. attn_metadata: AttentionMetadata,
  276. residual: Optional[torch.Tensor],
  277. ) -> Tuple[torch.Tensor, torch.Tensor]:
  278. # Self Attention
  279. residual = hidden_states
  280. hidden_states, residual = self.input_layernorm(hidden_states, residual)
  281. hidden_states_attention = self.self_attn(
  282. positions=positions,
  283. hidden_states=hidden_states,
  284. kv_cache=kv_cache,
  285. attn_metadata=attn_metadata,
  286. )
  287. hidden_states_mlp = self.mlp(hidden_states)
  288. # Add everything together
  289. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  290. return hidden_states, residual
  291. class CohereModel(nn.Module):
  292. def __init__(
  293. self,
  294. config: CohereConfig,
  295. linear_method: Optional[LinearMethodBase] = None,
  296. ):
  297. super().__init__()
  298. self.config = config
  299. self.vocab_size = config.vocab_size
  300. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  301. config.hidden_size,
  302. linear_method=linear_method)
  303. self.layers = nn.ModuleList([
  304. CohereDecoderLayer(config, linear_method=linear_method)
  305. for _ in range(config.num_hidden_layers)
  306. ])
  307. self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  308. def forward(
  309. self,
  310. input_ids: torch.Tensor,
  311. positions: torch.Tensor,
  312. kv_caches: List[torch.Tensor],
  313. attn_metadata: AttentionMetadata,
  314. ) -> torch.Tensor:
  315. hidden_states = self.embed_tokens(input_ids)
  316. residual = None
  317. for i in range(len(self.layers)):
  318. layer = self.layers[i]
  319. hidden_states, residual = layer(
  320. positions,
  321. hidden_states,
  322. kv_caches[i],
  323. attn_metadata,
  324. residual,
  325. )
  326. hidden_states, _ = self.norm(hidden_states, residual)
  327. return hidden_states
  328. class CohereForCausalLM(nn.Module):
  329. def __init__(
  330. self,
  331. config: CohereConfig,
  332. linear_method: Optional[LinearMethodBase] = None,
  333. ) -> None:
  334. super().__init__()
  335. self.config = config
  336. self.linear_method = linear_method
  337. self.model = CohereModel(config, linear_method)
  338. if self.config.tie_word_embeddings:
  339. self.lm_head = ParallelTWEHead(self.model.embed_tokens)
  340. else:
  341. self.lm_head = ParallelLMHead(config.vocab_size,
  342. config.hidden_size,
  343. linear_method=linear_method)
  344. self.logits_processor = LogitsProcessor(config.vocab_size,
  345. config.tokenizer_vocab_size,
  346. scale=config.logit_scale)
  347. self.sampler = Sampler()
  348. @torch.no_grad()
  349. def forward(
  350. self,
  351. input_ids: torch.Tensor,
  352. positions: torch.Tensor,
  353. kv_caches: List[torch.Tensor],
  354. attn_metadata: AttentionMetadata,
  355. ) -> torch.Tensor:
  356. hidden_states = self.model(input_ids, positions, kv_caches,
  357. attn_metadata)
  358. return hidden_states
  359. def compute_logits(self, hidden_states: torch.Tensor,
  360. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  361. logits = self.logits_processor(self.lm_head, hidden_states,
  362. sampling_metadata)
  363. return logits
  364. def sample(
  365. self,
  366. logits: torch.Tensor,
  367. sampling_metadata: SamplingMetadata,
  368. ) -> Optional[SamplerOutput]:
  369. next_tokens = self.sampler(logits, sampling_metadata)
  370. return next_tokens
  371. def load_weights(
  372. self,
  373. model_name_or_path: str,
  374. cache_dir: Optional[str] = None,
  375. load_format: str = "auto",
  376. revision: Optional[str] = None,
  377. ):
  378. stacked_params_mapping = [
  379. # (param_name, shard_name, shard_id)
  380. ("qkv_proj", "q_proj", "q"),
  381. ("qkv_proj", "k_proj", "k"),
  382. ("qkv_proj", "v_proj", "v"),
  383. ("gate_up_proj", "gate_proj", 0),
  384. ("gate_up_proj", "up_proj", 1),
  385. ]
  386. if (self.linear_method is not None
  387. and not self.linear_method.quant_config.merge_weight()):
  388. stacked_params_mapping = []
  389. params_dict = dict(self.named_parameters())
  390. loaded_params = set()
  391. for name, loaded_weight in hf_model_weights_iterator(
  392. model_name_or_path, cache_dir, load_format, revision,
  393. self.config):
  394. for param_name, shard_name, shard_id in stacked_params_mapping:
  395. if shard_name not in name:
  396. continue
  397. name = name.replace(shard_name, param_name)
  398. # Skip loading extra bias for GPTQ models.
  399. if name.endswith(".bias") and name not in params_dict:
  400. continue
  401. param = params_dict[name]
  402. weight_loader = param.weight_loader
  403. weight_loader(param, loaded_weight, shard_id)
  404. break
  405. else:
  406. # lm_head is not used as it is tied with embed_token.
  407. # To prevent errors, skip loading lm_head.
  408. if self.config.tie_word_embeddings and "lm_head" in name:
  409. continue
  410. # Skip loading extra bias for GPTQ models.
  411. if name.endswith(".bias") and name not in params_dict:
  412. continue
  413. param = params_dict[name]
  414. weight_loader = getattr(param, "weight_loader",
  415. default_weight_loader)
  416. weight_loader(param, loaded_weight)
  417. loaded_params.add(name)