gemma.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. # coding=utf-8
  2. # Copyright 2023 The vLLM team.
  3. # Copyright (c) Google Inc.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Inference-only Gemma model compatible with HuggingFace weights."""
  17. from functools import lru_cache
  18. from typing import Iterable, List, Optional, Tuple
  19. import torch
  20. from loguru import logger
  21. from torch import nn
  22. from transformers import GemmaConfig
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import LoRAConfig
  25. from aphrodite.common.sequence import SamplerOutput
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import GeluAndMul
  28. from aphrodite.modeling.layers.layernorm import RMSNorm
  29. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.rotary_embedding import get_rope
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  36. VocabParallelEmbedding
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.quantization.base_config import QuantizationConfig
  @lru_cache(maxsize=None)
  def _get_gemma_act_fn(
      hidden_act: Optional[str],
      hidden_activation: Optional[str],
  ) -> nn.Module:
      """Return the gated activation module used by the Gemma MLP.

      Results are memoized per ``(hidden_act, hidden_activation)`` pair, so
      the same module instance is returned for repeated calls and the legacy
      warning is emitted at most once per pair.

      Args:
          hidden_act: Legacy config field. It never selects the activation
              here; a non-``None`` value only triggers a warning when
              ``hidden_activation`` is unset.
          hidden_activation: Name of the activation. ``None`` and
              ``"gelu_pytorch_tanh"`` select the tanh approximation of GeLU,
              ``"gelu"`` selects exact GeLU.

      Returns:
          A ``GeluAndMul`` module configured with the chosen approximation.

      Raises:
          ValueError: If ``hidden_activation`` is set to an unsupported name.
      """
      if hidden_activation is None:
          if hidden_act is not None:
              # ``logger`` is loguru's logger, which formats messages with
              # ``str.format``; ``%s`` placeholders would be printed verbatim.
              logger.warning(
                  "Gemma's activation function was incorrectly set to exact "
                  "GeLU in the config JSON file when it was initially "
                  "released. Changing the activation function to "
                  "approximate GeLU (`gelu_pytorch_tanh`). If you want to "
                  "use the legacy `{}`, edit the config JSON to set "
                  "`hidden_activation={}` instead of `hidden_act`. "
                  "See https://github.com/huggingface/transformers/pull/29402 "
                  "for more details.", hidden_act, hidden_act)
          return GeluAndMul(approximate="tanh")
      if hidden_activation == "gelu_pytorch_tanh":
          return GeluAndMul(approximate="tanh")
      if hidden_activation == "gelu":
          return GeluAndMul(approximate="none")
      # Report the value that was actually rejected (``hidden_activation``),
      # not the legacy ``hidden_act`` field.
      raise ValueError(f"Activation function {hidden_activation} is not "
                       "supported for Gemma models.")
  class GemmaMLP(nn.Module):
      """Gated feed-forward block of a Gemma decoder layer.

      The input goes through a merged gate/up projection, the activation
      chosen by ``_get_gemma_act_fn``, and finally the down projection.
      """

      def __init__(
          self,
          hidden_size: int,
          intermediate_size: int,
          hidden_act: Optional[str] = None,
          hidden_activation: Optional[str] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          """Build the projections and pick the activation module.

          Args:
              hidden_size: Input width of the merged projection and output
                  width of the down projection.
              intermediate_size: Width of each of the two merged outputs and
                  input width of the down projection.
              hidden_act: Forwarded to ``_get_gemma_act_fn``.
              hidden_activation: Forwarded to ``_get_gemma_act_fn``.
              quant_config: Forwarded unchanged to both linear layers.
          """
          super().__init__()
          # Gate and up projections share one layer with two equal outputs.
          merged_output_sizes = [intermediate_size, intermediate_size]
          self.gate_up_proj = MergedColumnParallelLinear(
              hidden_size,
              merged_output_sizes,
              bias=False,
              quant_config=quant_config,
          )
          self.down_proj = RowParallelLinear(
              intermediate_size,
              hidden_size,
              bias=False,
              quant_config=quant_config,
          )
          self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

      def forward(self, x):
          """Apply gate/up projection, activation and down projection to ``x``."""
          # Both linear layers return a pair; only the first item is used.
          merged, _ = self.gate_up_proj(x)
          activated = self.act_fn(merged)
          projected, _ = self.down_proj(activated)
          return projected
  class GemmaAttention(nn.Module):
      """Self-attention with a fused QKV projection and rotary embeddings."""

      def __init__(self,
                   hidden_size: int,
                   num_heads: int,
                   num_kv_heads: int,
                   head_dim: int,
                   max_position_embeddings: int = 8192,
                   rope_theta: float = 10000,
                   quant_config: Optional[QuantizationConfig] = None) -> None:
          """Set up per-rank head counts, projections, RoPE and attention.

          Args:
              hidden_size: Input width of the QKV projection and output
                  width of the output projection.
              num_heads: Total number of query heads; must be divisible by
                  the tensor-parallel world size.
              num_kv_heads: Total number of key/value heads.
              head_dim: Size of each attention head.
              max_position_embeddings: Passed to ``get_rope`` as
                  ``max_position``.
              rope_theta: Passed to ``get_rope`` as ``base``.
              quant_config: Forwarded unchanged to both linear layers.
          """
          super().__init__()
          self.hidden_size = hidden_size
          tp_size = get_tensor_model_parallel_world_size()

          # Query heads are divided evenly between tensor-parallel ranks.
          self.total_num_heads = num_heads
          assert self.total_num_heads % tp_size == 0
          self.num_heads = self.total_num_heads // tp_size

          self.total_num_kv_heads = num_kv_heads
          if self.total_num_kv_heads < tp_size:
              # Fewer KV heads than ranks: the KV heads are replicated
              # across multiple tensor parallel GPUs.
              assert tp_size % self.total_num_kv_heads == 0
          else:
              # At least as many KV heads as ranks: the KV heads are
              # partitioned across multiple tensor parallel GPUs.
              assert self.total_num_kv_heads % tp_size == 0
          self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

          self.head_dim = head_dim
          self.q_size = self.num_heads * self.head_dim
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5
          self.rope_theta = rope_theta

          self.qkv_proj = QKVParallelLinear(
              hidden_size,
              self.head_dim,
              self.total_num_heads,
              self.total_num_kv_heads,
              bias=False,
              quant_config=quant_config,
          )
          attn_width = self.total_num_heads * self.head_dim
          self.o_proj = RowParallelLinear(
              attn_width,
              hidden_size,
              bias=False,
              quant_config=quant_config,
          )
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.head_dim,
              max_position=max_position_embeddings,
              base=self.rope_theta,
              is_neox_style=True,
          )
          self.attn = Attention(
              self.num_heads,
              self.head_dim,
              self.scaling,
              num_kv_heads=self.num_kv_heads,
          )

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Project to Q/K/V, apply RoPE, attend, and project the result."""
          fused_qkv, _ = self.qkv_proj(hidden_states)
          # The fused projection is laid out as [Q | K | V] on the last dim.
          split_sizes = [self.q_size, self.kv_size, self.kv_size]
          query, key, value = fused_qkv.split(split_sizes, dim=-1)
          query, key = self.rotary_emb(positions, query, key)
          context = self.attn(query, key, value, kv_cache, attn_metadata)
          projected, _ = self.o_proj(context)
          return projected
  class GemmaDecoderLayer(nn.Module):
      """One Gemma decoder layer: attention and MLP, each preceded by RMSNorm.

      The residual stream is not added here at the end of the layer; it is
      handed to the norm layers and returned next to the MLP output.
      """

      def __init__(
          self,
          config: GemmaConfig,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          """Create the attention, MLP and norm sub-modules from ``config``."""
          super().__init__()
          self.hidden_size = config.hidden_size
          self.self_attn = GemmaAttention(
              hidden_size=self.hidden_size,
              num_heads=config.num_attention_heads,
              num_kv_heads=config.num_key_value_heads,
              head_dim=config.head_dim,
              max_position_embeddings=config.max_position_embeddings,
              rope_theta=config.rope_theta,
              quant_config=quant_config,
          )
          # ``hidden_activation`` may be missing from the config object, so
          # it is looked up with a ``None`` fallback.
          hidden_activation = getattr(config, "hidden_activation", None)
          self.mlp = GemmaMLP(
              hidden_size=self.hidden_size,
              intermediate_size=config.intermediate_size,
              hidden_act=config.hidden_act,
              hidden_activation=hidden_activation,
              quant_config=quant_config,
          )
          self.input_layernorm = RMSNorm(
              config.hidden_size,
              eps=config.rms_norm_eps,
          )
          self.post_attention_layernorm = RMSNorm(
              config.hidden_size,
              eps=config.rms_norm_eps,
          )

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
          residual: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Run the layer and return ``(hidden_states, residual)``.

          When ``residual`` is ``None`` the incoming ``hidden_states``
          becomes the residual; otherwise both tensors are passed to the
          input norm, which returns the updated pair.
          """
          # Self Attention
          if residual is not None:
              hidden_states, residual = self.input_layernorm(
                  hidden_states, residual)
          else:
              residual = hidden_states
              hidden_states = self.input_layernorm(hidden_states)
          attn_out = self.self_attn(
              positions=positions,
              hidden_states=hidden_states,
              kv_cache=kv_cache,
              attn_metadata=attn_metadata,
          )

          # Fully Connected
          normed, residual = self.post_attention_layernorm(attn_out, residual)
          mlp_out = self.mlp(normed)
          return mlp_out, residual
  class GemmaModel(nn.Module):
      """Gemma decoder stack: scaled token embedding, layers, final norm."""

      def __init__(
          self,
          config: GemmaConfig,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          """Create the embedding, decoder layers, final norm and scale buffer."""
          super().__init__()
          self.config = config
          self.embed_tokens = VocabParallelEmbedding(
              config.vocab_size,
              config.hidden_size,
          )
          decoder_layers = [
              GemmaDecoderLayer(config, quant_config)
              for _ in range(config.num_hidden_layers)
          ]
          self.layers = nn.ModuleList(decoder_layers)
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

          # Normalize the embedding by sqrt(hidden_size)
          # The normalizer's data type should be downcasted to the model's
          # data type such as bfloat16, not float32.
          # See https://github.com/huggingface/transformers/pull/29402
          embed_scale = self.config.hidden_size**0.5
          self.register_buffer("normalizer", torch.tensor(embed_scale))

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids``, run every decoder layer, apply the final norm.

          ``kv_caches`` is indexed by layer position, so it needs one entry
          per decoder layer.
          """
          hidden_states = self.embed_tokens(input_ids)
          # In-place scaling of the embedding output by the buffer above.
          hidden_states *= self.normalizer

          residual = None
          for layer_idx, layer in enumerate(self.layers):
              hidden_states, residual = layer(
                  positions,
                  hidden_states,
                  kv_caches[layer_idx],
                  attn_metadata,
                  residual,
              )
          # The final norm also receives the residual; only its first
          # output is returned.
          hidden_states, _ = self.norm(hidden_states, residual)
          return hidden_states
  class GemmaForCausalLM(nn.Module):
      """Gemma causal language model: decoder stack, logits and sampling.

      Logits are computed against the input embedding weight (see
      ``compute_logits``); there is no separate ``lm_head`` parameter.
      """

      # Maps each fused module to the checkpoint sub-modules it packs.
      packed_modules_mapping = {
          "qkv_proj": [
              "q_proj",
              "k_proj",
              "v_proj",
          ],
          "gate_up_proj": [
              "gate_proj",
              "up_proj",
          ],
      }

      # LoRA specific attributes
      supported_lora_modules = [
          "qkv_proj",
          "o_proj",
          "gate_up_proj",
          "down_proj",
      ]
      # Gemma does not apply LoRA to the embedding layer.
      embedding_modules = {}
      embedding_padding_modules = []

      def __init__(
          self,
          config: GemmaConfig,
          quant_config: Optional[QuantizationConfig] = None,
          lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          """Build the decoder stack, logits processor and sampler.

          Args:
              config: Model configuration.
              quant_config: Forwarded to ``GemmaModel``.
              lora_config: Accepted for interface compatibility; unused.
          """
          del lora_config  # Unused.
          super().__init__()
          self.config = config
          self.quant_config = quant_config
          self.model = GemmaModel(config, quant_config)
          self.logits_processor = LogitsProcessor(config.vocab_size)
          self.sampler = Sampler()

      @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run the decoder stack without gradients; return hidden states."""
          hidden_states = self.model(input_ids, positions, kv_caches,
                                     attn_metadata)
          return hidden_states

      def compute_logits(self, hidden_states: torch.Tensor,
                         sampling_metadata: SamplingMetadata) -> torch.Tensor:
          """Compute logits by passing the embedding weight to the processor."""
          logits = self.logits_processor(self.model.embed_tokens.weight,
                                         hidden_states, sampling_metadata)
          return logits

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Run the sampler on ``logits`` and return its output."""
          next_tokens = self.sampler(logits, sampling_metadata)
          return next_tokens

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Load checkpoint tensors into this model's parameters.

          Separate q/k/v and gate/up checkpoint tensors are routed into the
          fused ``qkv_proj`` / ``gate_up_proj`` parameters. ``lm_head.weight``
          and bias entries without a matching parameter are skipped. Norm
          weights are stored with ``1.0`` added. The tensors yielded by
          ``weights`` are never modified in place.

          Args:
              weights: Iterable of ``(name, tensor)`` pairs.

          Raises:
              RuntimeError: If a model parameter received no checkpoint
                  weight.
          """
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
          loaded_params = set()
          for name, loaded_weight in weights:
              for (param_name, shard_name, shard_id) in stacked_params_mapping:
                  if shard_name not in name:
                      continue
                  name = name.replace(shard_name, param_name)
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
                  # Reached only when no stacked shard was loaded above.
                  # lm_head is not used in vllm as it is tied with embed_token.
                  # To prevent errors, skip loading lm_head.weight.
                  if "lm_head.weight" in name:
                      continue
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
                  # GemmaRMSNorm is different from Llama's in that it multiplies
                  # (1 + weight) to the output, instead of just weight.
                  if "norm.weight" in name:
                      # Out-of-place add: the checkpoint tensor belongs to
                      # the caller and must not be changed, otherwise loading
                      # the same tensor again would add the offset twice.
                      loaded_weight = loaded_weight + 1.0
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
                  weight_loader(param, loaded_weight)
              loaded_params.add(name)
          unloaded_params = params_dict.keys() - loaded_params
          if unloaded_params:
              raise RuntimeError(
                  "Some weights are not initialized from checkpoints: "
                  f"{unloaded_params}")