gemma.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. # coding=utf-8
  2. # Copyright 2023 The PygmalionAI team.
  3. # Copyright 2023 The vLLM team.
  4. # Copyright (c) Google Inc.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """Inference-only Gemma model compatible with HuggingFace weights."""
  18. from typing import List, Optional, Tuple
  19. import torch
  20. from torch import nn
  21. from transformers import GemmaConfig
  22. from aphrodite.modeling.metadata import InputMetadata
  23. from aphrodite.modeling.layers.activation import GeluAndMul
  24. from aphrodite.modeling.layers.attention import PagedAttention
  25. from aphrodite.modeling.layers.layernorm import RMSNorm
  26. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  27. MergedColumnParallelLinear,
  28. QKVParallelLinear,
  29. RowParallelLinear,
  30. ColumnParallelLinear)
  31. from aphrodite.modeling.layers.rotary_embedding import get_rope
  32. from aphrodite.modeling.layers.sampler import Sampler
  33. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  34. VocabParallelEmbedding, ParallelLMHead)
  35. from aphrodite.modeling.megatron.parallel_state import (
  36. get_tensor_model_parallel_world_size)
  37. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  38. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  39. hf_model_weights_iterator)
  40. from aphrodite.common.sequence import SamplerOutput
  41. KVCache = Tuple[torch.Tensor, torch.Tensor]
  42. class GemmaMLP(nn.Module):
  43. def __init__(
  44. self,
  45. hidden_size: int,
  46. intermediate_size: int,
  47. linear_method: Optional[LinearMethodBase] = None,
  48. ) -> None:
  49. super().__init__()
  50. if linear_method is not None and not linear_method.quant_config.merge_weight(
  51. ):
  52. self.merge_weight = False
  53. self.gate_proj = ColumnParallelLinear(hidden_size,
  54. intermediate_size,
  55. bias=False,
  56. linear_method=linear_method)
  57. self.up_proj = ColumnParallelLinear(hidden_size,
  58. intermediate_size,
  59. bias=False,
  60. linear_method=linear_method)
  61. else:
  62. self.merge_weight = True
  63. self.gate_up_proj = MergedColumnParallelLinear(
  64. hidden_size, [intermediate_size] * 2,
  65. bias=False,
  66. linear_method=linear_method)
  67. self.down_proj = RowParallelLinear(intermediate_size,
  68. hidden_size,
  69. bias=False,
  70. linear_method=linear_method)
  71. self.act_fn = GeluAndMul()
  72. def forward(self, x):
  73. if self.merge_weight:
  74. gate_up, _ = self.gate_up_proj(x)
  75. else:
  76. up, _ = self.up_proj(x)
  77. gate, _ = self.gate_proj(x)
  78. gate_up = torch.cat([gate, up], dim=-1)
  79. x = self.act_fn(gate_up)
  80. x, _ = self.down_proj(x)
  81. return x
  82. class GemmaAttention(nn.Module):
  83. def __init__(self,
  84. hidden_size: int,
  85. num_heads: int,
  86. num_kv_heads: int,
  87. head_dim: int,
  88. max_position_embeddings: int = 8192,
  89. rope_theta: float = 10000,
  90. linear_method: Optional[LinearMethodBase] = None) -> None:
  91. super().__init__()
  92. self.hidden_size = hidden_size
  93. tp_size = get_tensor_model_parallel_world_size()
  94. self.total_num_heads = num_heads
  95. assert self.total_num_heads % tp_size == 0
  96. self.num_heads = self.total_num_heads // tp_size
  97. self.total_num_kv_heads = num_kv_heads
  98. if self.total_num_kv_heads >= tp_size:
  99. # Number of KV heads is greater than TP size, so we partition
  100. # the KV heads across multiple tensor parallel GPUs.
  101. assert self.total_num_kv_heads % tp_size == 0
  102. else:
  103. # Number of KV heads is less than TP size, so we replicate
  104. # the KV heads across multiple tensor parallel GPUs.
  105. assert tp_size % self.total_num_kv_heads == 0
  106. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  107. self.head_dim = head_dim
  108. self.q_size = self.num_heads * self.head_dim
  109. self.kv_size = self.num_kv_heads * self.head_dim
  110. self.scaling = self.head_dim**-0.5
  111. self.rope_theta = rope_theta
  112. if linear_method is not None and not linear_method.quant_config.merge_weight(
  113. ):
  114. self.merge_weight = False
  115. self.q_proj = ColumnParallelLinear(hidden_size,
  116. self.q_size,
  117. bias=False,
  118. linear_method=linear_method)
  119. self.k_proj = ColumnParallelLinear(hidden_size,
  120. self.kv_size,
  121. bias=False,
  122. linear_method=linear_method)
  123. self.v_proj = ColumnParallelLinear(hidden_size,
  124. self.kv_size,
  125. bias=False,
  126. linear_method=linear_method)
  127. else:
  128. self.merge_weight = True
  129. self.qkv_proj = QKVParallelLinear(
  130. hidden_size,
  131. self.head_dim,
  132. self.total_num_heads,
  133. self.total_num_kv_heads,
  134. bias=False,
  135. linear_method=linear_method,
  136. )
  137. self.o_proj = RowParallelLinear(
  138. self.total_num_heads * self.head_dim,
  139. hidden_size,
  140. bias=False,
  141. linear_method=linear_method,
  142. )
  143. self.rotary_emb = get_rope(
  144. self.head_dim,
  145. rotary_dim=self.head_dim,
  146. max_position=max_position_embeddings,
  147. base=self.rope_theta,
  148. is_neox_style=True,
  149. )
  150. self.attn = PagedAttention(self.num_heads,
  151. self.head_dim,
  152. self.scaling,
  153. num_kv_heads=self.num_kv_heads)
  154. def forward(
  155. self,
  156. positions: torch.Tensor,
  157. hidden_states: torch.Tensor,
  158. kv_cache: KVCache,
  159. input_metadata: InputMetadata,
  160. ) -> torch.Tensor:
  161. if self.merge_weight:
  162. qkv, _ = self.qkv_proj(hidden_states)
  163. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  164. dim=-1)
  165. else:
  166. q, _ = self.q_proj(hidden_states)
  167. k, _ = self.k_proj(hidden_states)
  168. v, _ = self.v_proj(hidden_states)
  169. q, k = self.rotary_emb(positions, q, k)
  170. k_cache, v_cache = kv_cache
  171. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  172. output, _ = self.o_proj(attn_output)
  173. return output
  174. class GemmaDecoderLayer(nn.Module):
  175. def __init__(
  176. self,
  177. config: GemmaConfig,
  178. linear_method: Optional[LinearMethodBase] = None,
  179. ) -> None:
  180. super().__init__()
  181. self.hidden_size = config.hidden_size
  182. self.self_attn = GemmaAttention(
  183. hidden_size=self.hidden_size,
  184. num_heads=config.num_attention_heads,
  185. num_kv_heads=config.num_key_value_heads,
  186. head_dim=config.head_dim,
  187. max_position_embeddings=config.max_position_embeddings,
  188. rope_theta=config.rope_theta,
  189. linear_method=linear_method,
  190. )
  191. self.mlp = GemmaMLP(
  192. hidden_size=self.hidden_size,
  193. intermediate_size=config.intermediate_size,
  194. linear_method=linear_method,
  195. )
  196. self.input_layernorm = RMSNorm(config.hidden_size,
  197. eps=config.rms_norm_eps)
  198. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  199. eps=config.rms_norm_eps)
  200. def forward(
  201. self,
  202. positions: torch.Tensor,
  203. hidden_states: torch.Tensor,
  204. kv_cache: KVCache,
  205. input_metadata: InputMetadata,
  206. residual: Optional[torch.Tensor],
  207. ) -> Tuple[torch.Tensor, torch.Tensor]:
  208. # Self Attention
  209. if residual is None:
  210. residual = hidden_states
  211. hidden_states = self.input_layernorm(hidden_states)
  212. else:
  213. hidden_states, residual = self.input_layernorm(
  214. hidden_states, residual)
  215. hidden_states = self.self_attn(
  216. positions=positions,
  217. hidden_states=hidden_states,
  218. kv_cache=kv_cache,
  219. input_metadata=input_metadata,
  220. )
  221. # Fully Connected
  222. hidden_states, residual = self.post_attention_layernorm(
  223. hidden_states, residual)
  224. hidden_states = self.mlp(hidden_states)
  225. return hidden_states, residual
  226. class GemmaModel(nn.Module):
  227. def __init__(
  228. self,
  229. config: GemmaConfig,
  230. linear_method: Optional[LinearMethodBase] = None,
  231. ) -> None:
  232. super().__init__()
  233. self.config = config
  234. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  235. config.hidden_size,
  236. linear_method=linear_method)
  237. self.layers = nn.ModuleList([
  238. GemmaDecoderLayer(config, linear_method)
  239. for _ in range(config.num_hidden_layers)
  240. ])
  241. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  242. def forward(
  243. self,
  244. input_ids: torch.Tensor,
  245. positions: torch.Tensor,
  246. kv_caches: List[KVCache],
  247. input_metadata: InputMetadata,
  248. ) -> torch.Tensor:
  249. hidden_states = self.embed_tokens(input_ids)
  250. # Normalize the embedding by sqrt(hidden_size)
  251. hidden_states *= self.config.hidden_size**0.5
  252. residual = None
  253. for i in range(len(self.layers)):
  254. layer = self.layers[i]
  255. hidden_states, residual = layer(
  256. positions,
  257. hidden_states,
  258. kv_caches[i],
  259. input_metadata,
  260. residual,
  261. )
  262. hidden_states, _ = self.norm(hidden_states, residual)
  263. return hidden_states
  264. class GemmaForCausalLM(nn.Module):
  265. def __init__(
  266. self,
  267. config: GemmaConfig,
  268. linear_method: Optional[LinearMethodBase] = None,
  269. ) -> None:
  270. super().__init__()
  271. self.config = config
  272. self.linear_method = linear_method
  273. self.model = GemmaModel(config, linear_method)
  274. self.lm_head = ParallelLMHead(config.vocab_size,
  275. config.hidden_size,
  276. linear_method=linear_method)
  277. self.sampler = Sampler(config.vocab_size)
  278. @torch.no_grad()
  279. def forward(
  280. self,
  281. input_ids: torch.Tensor,
  282. positions: torch.Tensor,
  283. kv_caches: List[KVCache],
  284. input_metadata: InputMetadata,
  285. ) -> torch.Tensor:
  286. hidden_states = self.model(input_ids, positions, kv_caches,
  287. input_metadata)
  288. return hidden_states
  289. def sample(
  290. self,
  291. hidden_states: torch.Tensor,
  292. sampling_metadata: SamplingMetadata,
  293. ) -> Optional[SamplerOutput]:
  294. next_tokens = self.sampler(self.lm_head(hidden_states),
  295. sampling_metadata)
  296. return next_tokens
  297. def load_weights(self,
  298. model_name_or_path: str,
  299. cache_dir: Optional[str] = None,
  300. load_format: str = "auto",
  301. revision: Optional[str] = None):
  302. stacked_params_mapping = [
  303. # (param_name, shard_name, shard_id)
  304. ("qkv_proj", "q_proj", "q"),
  305. ("qkv_proj", "k_proj", "k"),
  306. ("qkv_proj", "v_proj", "v"),
  307. ("gate_up_proj", "gate_proj", 0),
  308. ("gate_up_proj", "up_proj", 1),
  309. ]
  310. if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
  311. ):
  312. stacked_params_mapping = []
  313. params_dict = dict(self.named_parameters())
  314. loaded_params = set()
  315. for name, loaded_weight in hf_model_weights_iterator(
  316. model_name_or_path, cache_dir, load_format, revision,
  317. self.config):
  318. if "rotary_emb.inv_freq" in name:
  319. continue
  320. if "embed_tokens.weight" in name:
  321. # Copy word embedding to lm_head
  322. loaded_params.add("lm_head.weight")
  323. lm_head_param = params_dict["lm_head.weight"]
  324. weight_loader = getattr(lm_head_param, "weight_loader",
  325. default_weight_loader)
  326. weight_loader(lm_head_param, loaded_weight)
  327. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  328. if weight_name not in name:
  329. continue
  330. name = name.replace(weight_name, param_name)
  331. # Skip loading extra bias for GPTQ models.
  332. if name.endswith(".bias") and name not in params_dict:
  333. continue
  334. param = params_dict[name]
  335. weight_loader = param.weight_loader
  336. weight_loader(param, loaded_weight, shard_id)
  337. break
  338. else:
  339. # Skip loading extra layer for lora models.
  340. if "lm_head" in name:
  341. continue
  342. # GemmaRMSNorm is different from Llama's in that it multiplies
  343. # (1 + weight) to the output, instead of just weight.
  344. if "norm.weight" in name:
  345. loaded_weight += 1.0
  346. param = params_dict[name]
  347. weight_loader = getattr(param, "weight_loader",
  348. default_weight_loader)
  349. weight_loader(param, loaded_weight)
  350. loaded_params.add(name)
  351. unloaded_params = params_dict.keys() - loaded_params
  352. if unloaded_params:
  353. raise RuntimeError(
  354. "Some weights are not initialized from checkpoints: "
  355. f"{unloaded_params}")