cohere.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # coding=utf-8
  2. # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. # This file is based on the LLama model definition file in transformers
  21. """PyTorch Cohere model."""
  22. from typing import List, Optional, Tuple
  23. import torch
  24. import torch.utils.checkpoint
  25. from torch import nn
  26. from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
  27. from transformers import CohereConfig
  28. from aphrodite.modeling.metadata import InputMetadata
  29. from aphrodite.modeling.layers.activation import SiluAndMul
  30. from aphrodite.modeling.layers.attention import PagedAttention as Attention
  31. from aphrodite.modeling.layers.linear import (
  32. LinearMethodBase,
  33. MergedColumnParallelLinear,
  34. QKVParallelLinear,
  35. RowParallelLinear,
  36. )
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. VocabParallelEmbedding)
  41. from aphrodite.modeling.megatron.parallel_state import (
  42. get_tensor_model_parallel_world_size, )
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.modeling.hf_downloader import (
  45. default_weight_loader,
  46. hf_model_weights_iterator,
  47. )
  48. from aphrodite.common.sequence import SamplerOutput
  49. # limitations under the License.
  50. """ Cohere model configuration"""
  51. KVCache = Tuple[torch.Tensor, torch.Tensor]
  52. class LayerNorm(nn.Module):
  53. def __init__(self, hidden_size, eps=1e-5, bias=False):
  54. super().__init__()
  55. self.weight = nn.Parameter(torch.ones(hidden_size))
  56. self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
  57. self.variance_epsilon = eps
  58. def forward(self, hidden_states, residuals=None):
  59. input_dtype = hidden_states.dtype
  60. hidden_states = hidden_states.to(torch.float32)
  61. mean = hidden_states.mean(-1, keepdim=True)
  62. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  63. hidden_states = (hidden_states -
  64. mean) * torch.rsqrt(variance + self.variance_epsilon)
  65. hidden_states = self.weight.to(torch.float32) * hidden_states
  66. if self.bias is not None:
  67. hidden_states = hidden_states + self.bias.to(torch.float32)
  68. return hidden_states.to(input_dtype), residuals
  69. ALL_LAYERNORM_LAYERS.append(LayerNorm)
  70. # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
  71. class CohereMLP(nn.Module):
  72. def __init__(
  73. self,
  74. config,
  75. linear_method: Optional[LinearMethodBase] = None,
  76. ):
  77. super().__init__()
  78. self.config = config
  79. self.hidden_size = config.hidden_size
  80. self.intermediate_size = config.intermediate_size
  81. self.gate_up_proj = MergedColumnParallelLinear(
  82. self.hidden_size,
  83. [self.intermediate_size] * 2,
  84. bias=False,
  85. linear_method=linear_method,
  86. )
  87. self.down_proj = RowParallelLinear(
  88. self.intermediate_size,
  89. self.hidden_size,
  90. bias=False,
  91. linear_method=linear_method,
  92. )
  93. self.act_fn = SiluAndMul()
  94. def forward(self, x):
  95. gate_up, _ = self.gate_up_proj(x)
  96. x = self.act_fn(gate_up)
  97. x, _ = self.down_proj(x)
  98. return x
  99. class CohereAttention(nn.Module):
  100. def __init__(
  101. self,
  102. config: CohereConfig,
  103. linear_method: Optional[LinearMethodBase] = None,
  104. ):
  105. super().__init__()
  106. tp_size = get_tensor_model_parallel_world_size()
  107. self.config = config
  108. self.attention_dropout = config.attention_dropout
  109. self.hidden_size = config.hidden_size
  110. self.total_num_heads = config.num_attention_heads
  111. self.num_heads = self.total_num_heads // tp_size
  112. self.head_dim = self.hidden_size // self.total_num_heads
  113. self.total_num_kv_heads = config.num_key_value_heads
  114. if self.total_num_kv_heads >= tp_size:
  115. # Number of KV heads is greater than TP size, so we partition
  116. # the KV heads across multiple tensor parallel GPUs.
  117. assert self.total_num_kv_heads % tp_size == 0
  118. else:
  119. # Number of KV heads is less than TP size, so we replicate
  120. # the KV heads across multiple tensor parallel GPUs.
  121. assert tp_size % self.total_num_kv_heads == 0
  122. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  123. self.q_size = self.num_heads * self.head_dim
  124. self.kv_size = self.num_kv_heads * self.head_dim
  125. self.scaling = self.head_dim**-0.5
  126. self.max_position_embeddings = config.max_position_embeddings
  127. self.rope_theta = config.rope_theta
  128. self.rope_scaling = getattr(config, "rope_scaling", None)
  129. self.is_causal = True
  130. self.qkv_proj = QKVParallelLinear(
  131. self.hidden_size,
  132. self.head_dim,
  133. self.total_num_heads,
  134. self.total_num_kv_heads,
  135. bias=False,
  136. linear_method=linear_method,
  137. )
  138. self.o_proj = RowParallelLinear(
  139. self.total_num_heads * self.head_dim,
  140. self.hidden_size,
  141. bias=False,
  142. linear_method=linear_method,
  143. )
  144. self.rotary_emb = get_rope(
  145. self.head_dim,
  146. rotary_dim=self.head_dim,
  147. max_position=self.max_position_embeddings,
  148. base=self.rope_theta,
  149. rope_scaling=self.rope_scaling,
  150. is_neox_style=False,
  151. )
  152. self.attn = Attention(
  153. self.num_heads,
  154. self.head_dim,
  155. self.scaling,
  156. num_kv_heads=self.num_kv_heads,
  157. )
  158. def forward(
  159. self,
  160. positions: torch.Tensor,
  161. hidden_states: torch.Tensor,
  162. kv_cache: KVCache,
  163. input_metadata: InputMetadata,
  164. ) -> torch.Tensor:
  165. qkv, _ = self.qkv_proj(hidden_states)
  166. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  167. q, k = self.rotary_emb(positions, q, k)
  168. k_cache, v_cache = kv_cache
  169. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  170. output, _ = self.o_proj(attn_output)
  171. return output
  172. class CohereDecoderLayer(nn.Module):
  173. def __init__(self,
  174. config: CohereConfig,
  175. linear_method: Optional[LinearMethodBase] = None):
  176. super().__init__()
  177. self.hidden_size = config.hidden_size
  178. self.self_attn = CohereAttention(config, linear_method=linear_method)
  179. self.mlp = CohereMLP(config, linear_method=linear_method)
  180. self.input_layernorm = LayerNorm(config.hidden_size,
  181. eps=config.layer_norm_eps)
  182. def forward(
  183. self,
  184. positions: torch.Tensor,
  185. hidden_states: torch.Tensor,
  186. kv_cache: KVCache,
  187. input_metadata: InputMetadata,
  188. residual: Optional[torch.Tensor],
  189. ) -> Tuple[torch.Tensor, torch.Tensor]:
  190. # Self Attention
  191. residual = hidden_states
  192. hidden_states, residual = self.input_layernorm(hidden_states, residual)
  193. hidden_states_attention = self.self_attn(
  194. positions=positions,
  195. hidden_states=hidden_states,
  196. kv_cache=kv_cache,
  197. input_metadata=input_metadata,
  198. )
  199. hidden_states_mlp = self.mlp(hidden_states)
  200. # Add everything together
  201. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  202. return hidden_states, residual
  203. class CohereModel(nn.Module):
  204. """
  205. Transformer decoder consisting of *config.num_hidden_layers* layers.
  206. Each layer is a [`CohereDecoderLayer`]
  207. Args:
  208. config: CohereConfig
  209. """
  210. def __init__(
  211. self,
  212. config: CohereConfig,
  213. linear_method: Optional[LinearMethodBase] = None,
  214. ):
  215. super().__init__()
  216. self.config = config
  217. self.vocab_size = config.vocab_size
  218. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  219. config.hidden_size)
  220. self.layers = nn.ModuleList([
  221. CohereDecoderLayer(config, linear_method=linear_method)
  222. for _ in range(config.num_hidden_layers)
  223. ])
  224. self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  225. def forward(
  226. self,
  227. input_ids: torch.Tensor,
  228. positions: torch.Tensor,
  229. kv_caches: List[KVCache],
  230. input_metadata: InputMetadata,
  231. ) -> torch.Tensor:
  232. hidden_states = self.embed_tokens(input_ids)
  233. residual = None
  234. for i in range(len(self.layers)):
  235. layer = self.layers[i]
  236. hidden_states, residual = layer(
  237. positions,
  238. hidden_states,
  239. kv_caches[i],
  240. input_metadata,
  241. residual,
  242. )
  243. hidden_states, _ = self.norm(hidden_states, residual)
  244. return hidden_states
  245. class CohereForCausalLM(nn.Module):
  246. def __init__(
  247. self,
  248. config: CohereConfig,
  249. linear_method: Optional[LinearMethodBase] = None,
  250. ) -> None:
  251. super().__init__()
  252. self.config = config
  253. self.unpadded_vocab_size = config.vocab_size
  254. self.linear_method = linear_method
  255. self.model = CohereModel(config, linear_method)
  256. self.sampler = Sampler(config.vocab_size)
  257. @torch.no_grad()
  258. def forward(
  259. self,
  260. input_ids: torch.Tensor,
  261. positions: torch.Tensor,
  262. kv_caches: List[KVCache],
  263. input_metadata: InputMetadata,
  264. ) -> torch.Tensor:
  265. hidden_states = self.model(input_ids, positions, kv_caches,
  266. input_metadata)
  267. return hidden_states
  268. def sample(
  269. self,
  270. hidden_states: torch.Tensor,
  271. sampling_metadata: SamplingMetadata,
  272. ) -> Optional[SamplerOutput]:
  273. next_tokens = self.sampler(self.model.embed_tokens.weight,
  274. hidden_states, sampling_metadata)
  275. return next_tokens
  276. def load_weights(
  277. self,
  278. model_name_or_path: str,
  279. cache_dir: Optional[str] = None,
  280. load_format: str = "auto",
  281. revision: Optional[str] = None,
  282. ):
  283. stacked_params_mapping = [
  284. # (param_name, shard_name, shard_id)
  285. ("qkv_proj", "q_proj", "q"),
  286. ("qkv_proj", "k_proj", "k"),
  287. ("qkv_proj", "v_proj", "v"),
  288. ("gate_up_proj", "gate_proj", 0),
  289. ("gate_up_proj", "up_proj", 1),
  290. ]
  291. params_dict = dict(self.named_parameters())
  292. loaded_params = set()
  293. for name, loaded_weight in hf_model_weights_iterator(
  294. model_name_or_path, cache_dir, load_format, revision):
  295. for param_name, shard_name, shard_id in stacked_params_mapping:
  296. if shard_name not in name:
  297. continue
  298. name = name.replace(shard_name, param_name)
  299. param = params_dict[name]
  300. weight_loader = param.weight_loader
  301. weight_loader(param, loaded_weight, shard_id)
  302. break
  303. else:
  304. param = params_dict[name]
  305. weight_loader = getattr(param, "weight_loader",
  306. default_weight_loader)
  307. weight_loader(param, loaded_weight)
  308. loaded_params.add(name)