# coding=utf-8
# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is based on the Llama model definition file in transformers.
"""PyTorch Cohere model."""

from typing import List, Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (
    ColumnParallelLinear,
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, )
from aphrodite.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput


@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
                                                         variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)


class LayerNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-5, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

    def forward(self, hidden_states, residuals=None):
        hidden_states = layer_norm_func(hidden_states, self.weight,
                                        self.variance_epsilon)
        return hidden_states, residuals

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_dim = 0 if param.dim() != 1 else None
        param_data = param.data
        if shard_dim is not None:
            shard_size = param_data.shape[shard_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
                                                 shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
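

# Illustrative sketch (not used by the model): layer_norm_func above is a
# bias-free LayerNorm -- it subtracts the mean, unlike RMSNorm -- computed in
# float32 and cast back to the input dtype. The helper below is a hypothetical
# sanity check of the intended equivalence with F.layer_norm; tolerances are
# loose because of the float32 round trip, and the first call triggers
# torch.compile tracing of layer_norm_func.
def _layer_norm_reference_check(hidden_size: int = 64,
                                eps: float = 1e-5) -> bool:
    import torch.nn.functional as F
    x = torch.randn(2, 3, hidden_size)
    weight = torch.randn(hidden_size)
    ours = layer_norm_func(x, weight, eps)
    ref = F.layer_norm(x, (hidden_size, ), weight=weight, bias=None, eps=eps)
    return torch.allclose(ours, ref, atol=1e-4, rtol=1e-4)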


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.gate_proj = ColumnParallelLinear(
                self.hidden_size,
                self.intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
            self.up_proj = ColumnParallelLinear(
                self.hidden_size,
                self.intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.gate_up_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.intermediate_size] * 2,
                bias=False,
                linear_method=linear_method,
            )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        if self.merge_weight:
            gate_up, _ = self.gate_up_proj(x)
        else:
            up, _ = self.up_proj(x)
            gate, _ = self.gate_proj(x)
            gate_up = torch.cat([gate, up], dim=-1)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
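

# Illustrative sketch (not part of the model): the concatenation order above
# matters because SiluAndMul is expected to split its input in half along the
# last dimension and compute silu(first_half) * second_half, i.e. a SwiGLU
# gate. The hypothetical reference below restates that contract in plain
# PyTorch; the fused SiluAndMul kernel itself lives in
# aphrodite.modeling.layers.activation.
def _swiglu_reference(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up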


class CohereAttention(nn.Module):

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = getattr(
            config, "model_max_length", None) or getattr(
                config, "max_position_embeddings", 8192)
        self.rope_theta = config.rope_theta
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.total_num_heads * self.head_dim,
                bias=False,
                linear_method=linear_method,
            )
            self.k_proj = ColumnParallelLinear(
                self.hidden_size,
                self.total_num_kv_heads * self.head_dim,
                bias=False,
                linear_method=linear_method,
            )
            self.v_proj = ColumnParallelLinear(
                self.hidden_size,
                self.total_num_kv_heads * self.head_dim,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                self.hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )
        if self.use_qk_norm:
            self.q_norm = LayerNorm(hidden_size=(self.num_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)
            self.k_norm = LayerNorm(hidden_size=(self.num_kv_heads,
                                                 self.head_dim),
                                    eps=config.layer_norm_eps)

    def _apply_qk_norm(self, q, k):
        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q, _ = self.q_norm(q)
        k, _ = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        if self.use_qk_norm:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
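

# Illustrative sketch (not part of the model): with use_qk_norm enabled,
# _apply_qk_norm reshapes q from [num_tokens, num_heads * head_dim] to
# [num_tokens, num_heads, head_dim], so layer_norm_func normalizes over
# head_dim independently for each head, while the (num_heads, head_dim)
# LayerNorm weight gives each head its own learned scale. The hypothetical
# helper below restates that shape contract in plain PyTorch.
def _per_head_norm_reference(q: torch.Tensor, weight: torch.Tensor,
                             head_dim: int, eps: float) -> torch.Tensor:
    # weight is expected to have shape (num_heads, head_dim).
    q = q.view(*q.shape[:-1], -1, head_dim)
    q = layer_norm_func(q, weight, eps)
    return q.view(*q.shape[:-2], -1)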


class TieWordEmbeddingHead(nn.Module):

    def __init__(self):
        super().__init__()
        self.embedding = None

    def forward(self, hidden_states):
        return torch.matmul(hidden_states, self.embedding.t())
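

# Illustrative sketch (not part of the model): Cohere ties the output
# projection to the input embedding matrix, so TieWordEmbeddingHead simply
# multiplies hidden states by the transposed embedding table that
# load_weights() assigns to `embedding`. The config's logit_scale is applied
# separately by LogitsProcessor. The hypothetical function below shows the
# combined computation for a single rank, ignoring vocab parallelism.
def _tied_logits_reference(hidden_states: torch.Tensor,
                           embedding: torch.Tensor,
                           logit_scale: float) -> torch.Tensor:
    return logit_scale * torch.matmul(hidden_states, embedding.t())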


class CohereDecoderLayer(nn.Module):

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config, linear_method=linear_method)
        self.mlp = CohereMLP(config, linear_method=linear_method)
        self.input_layernorm = LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states, residual = self.input_layernorm(hidden_states,
                                                       residual)
        hidden_states_attention = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states_mlp = self.mlp(hidden_states)
        # Add everything together
        hidden_states = residual + hidden_states_attention + hidden_states_mlp
        return hidden_states, residual
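

# Illustrative sketch (not part of the model): unlike Llama's sequential
# pre-norm block, the Cohere decoder layer is a "parallel" block -- a single
# LayerNorm output feeds both the attention and MLP branches, and both are
# added back to the unnormalized residual. The hypothetical helper below
# shows the structure with generic callables (the model's LayerNorm actually
# returns a (hidden_states, residual) tuple and attention needs extra inputs).
def _parallel_block_reference(x, norm, attn_branch, mlp_branch):
    h = norm(x)
    return x + attn_branch(h) + mlp_branch(h)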


class CohereModel(nn.Module):

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size,
                                                   linear_method=linear_method)
        self.layers = nn.ModuleList([
            CohereDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class CohereForCausalLM(nn.Module):

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.lm_head = TieWordEmbeddingHead()
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                scale=config.logit_scale)
        self.model = CohereModel(config, linear_method)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used as it is tied with embed_token.
                # To prevent errors, skip loading lm_head.
                if "lm_head" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        self.lm_head.embedding = self.model.embed_tokens.weight
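

# Illustrative usage sketch (not part of this module, and assuming aphrodite's
# usual vLLM-style entry points): this class is selected from the checkpoint's
# "CohereForCausalLM" architecture tag, e.g. for CohereForAI/c4ai-command-r-v01,
# so end users typically never instantiate it directly:
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="CohereForAI/c4ai-command-r-v01")
#     outputs = llm.generate(["Hello,"], SamplingParams(max_tokens=16))
#
# Note that load_weights() never materializes a separate lm_head weight; it
# points lm_head.embedding at model.embed_tokens.weight, so the tied output
# projection reuses the (vocab-parallel) input embedding shards.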