# coding=utf-8
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GemmaConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import GeluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
    ColumnParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput
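

# NOTE: Gemma configs expose the activation under two names: older
# checkpoints only set `hidden_act`, newer ones set `hidden_activation`.
# The helper below prefers `hidden_activation` when present, maps
# "gelu_pytorch_tanh" to the tanh-approximate GeLU, and otherwise falls
# back to the exact GeLU. The result is cached because the same activation
# instance can be shared across all layers.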
@lru_cache(maxsize=None)
def _get_gemma_act_fn(
    hidden_act: Optional[str],
    hidden_activation: Optional[str],
) -> nn.Module:
    if hidden_activation is None:
        if hidden_act is not None:
            hidden_activation = hidden_act
        return GeluAndMul(approximate="none")
    elif hidden_activation == "gelu_pytorch_tanh":
        return GeluAndMul(approximate="tanh")
    elif hidden_activation == "gelu":
        return GeluAndMul(approximate="none")
    else:
        raise ValueError(f"Activation function {hidden_activation} is not "
                         "supported for Gemma models.")
class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: Optional[str] = None,
        hidden_activation: Optional[str] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.gate_proj = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
            self.up_proj = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
                linear_method=linear_method,
            )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

    def forward(self, x):
        if self.merge_weight:
            gate_up, _ = self.gate_up_proj(x)
        else:
            up, _ = self.up_proj(x)
            gate, _ = self.gate_proj(x)
            gate_up = torch.cat([gate, up], dim=-1)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
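

# GemmaAttention performs multi-head attention with rotary position
# embeddings and grouped-query attention (num_kv_heads may be smaller than
# num_heads). Query/key/value projections are fused into one QKVParallelLinear
# unless the quantization config forbids merged weights. Heads are split
# across tensor-parallel ranks: for example, with 8 query heads, 1 KV head,
# and tp_size=2, each rank computes 4 query heads and keeps a replicated copy
# of the single KV head.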
class GemmaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int = 8192,
        rope_theta: float = 10000,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(
                hidden_size,
                self.total_num_heads * self.head_dim,
                bias=False,
                linear_method=linear_method,
            )
            self.k_proj = ColumnParallelLinear(
                hidden_size,
                self.total_num_kv_heads * self.head_dim,
                bias=False,
                linear_method=linear_method,
            )
            self.v_proj = ColumnParallelLinear(
                hidden_size,
                self.total_num_kv_heads * self.head_dim,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        is_neox_style = (True if linear_method is None
                         or linear_method.quant_config.rope_style() is None
                         else linear_method.quant_config.rope_style())
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
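

# GemmaDecoderLayer wires one transformer block in pre-norm form:
# RMSNorm -> self-attention -> RMSNorm -> MLP. The residual stream is threaded
# through the layer norms: when called with a residual, RMSNorm returns both
# the normalized activations and the updated residual, so the explicit
# residual adds are fused into the norm.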
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            linear_method=linear_method,
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=getattr(config, "hidden_activation", None),
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
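

# GemmaModel stacks the token embedding, `num_hidden_layers` decoder layers,
# and a final RMSNorm. Unlike most Llama-style models, Gemma scales the token
# embeddings by sqrt(hidden_size) before the first layer (e.g. for
# hidden_size=2048 the factor is 2048**0.5 ~= 45.25). The factor is kept as a
# registered buffer so it is cast to the model dtype (e.g. bfloat16) together
# with the weights.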
class GemmaModel(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size,
                                                   linear_method=linear_method)
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Normalize the embedding by sqrt(hidden_size).
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        hidden_states *= self.normalizer
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
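

# GemmaForCausalLM is the top-level module used by the engine: it wraps
# GemmaModel, a ParallelLMHead for the output projection, a LogitsProcessor,
# and a Sampler. Gemma ties the LM head to the input embedding, so
# load_weights() below copies `model.embed_tokens` into `lm_head` rather than
# reading a separate output matrix from the checkpoint.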
class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = GemmaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.tokenizer_vocab_size)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
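
    # Weight loading remaps per-tensor checkpoint names onto the fused
    # parameters created above. For example, a checkpoint tensor named
    # "model.layers.0.self_attn.q_proj.weight" is loaded into the "q" shard of
    # "model.layers.0.self_attn.qkv_proj.weight", and "gate_proj"/"up_proj"
    # become shards 0 and 1 of "gate_up_proj". When the quantization config
    # disallows merged weights, the mapping is cleared and names load
    # one-to-one.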
    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "rotary_emb.inv_freq" in name:
                continue
            if "embed_tokens" in name:
                # Copy word embedding to lm_head.
                head_name = name.replace("model.embed_tokens", "lm_head")
                if head_name in params_dict:
                    loaded_params.add(head_name)
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra layer for lora models.
                if "lm_head" in name and name not in params_dict:
                    continue
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                if "norm.weight" in name:
                    loaded_weight += 1.0
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")