gemma2.py

# coding=utf-8
# Copyright 2024 The PygmalionAI team.
# Copyright 2024 The vLLM team.
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn
from transformers import Gemma2Config

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import GeluAndMul
from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsLoRA
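

# Gemma2MLP below is the gated feed-forward block used by every decoder
# layer: `gate_up_proj` fuses the gate and up projections into one
# column-parallel matmul, GeluAndMul applies a tanh-approximated GeLU to the
# gate half and multiplies it with the up half, and `down_proj` maps the
# result back to the hidden size. Gemma 2 only supports the
# `gelu_pytorch_tanh` activation, hence the explicit config check below.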
class Gemma2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
            raise ValueError(
                "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`.")
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
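

# Gemma2Attention below implements grouped-query attention with rotary
# position embeddings. The query/key/value projections are fused into a
# single tensor-parallel `qkv_proj`; KV heads are partitioned across TP
# ranks when possible and replicated otherwise. As an illustration (numbers
# chosen for the example, not tied to any particular checkpoint): with
# num_heads=16, num_kv_heads=8, head_dim=256 and tp_size=2, each rank holds
# 8 query heads (q_size=2048) and 4 KV heads (kv_size=1024). Note that the
# attention scale comes from `config.query_pre_attn_scalar**-0.5` rather
# than the usual `head_dim**-0.5`.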
class Gemma2Attention(nn.Module):

    def __init__(self,
                 layer_idx: int,
                 config: Gemma2Config,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int,
                 rope_theta: float,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        # TODO: Use the `get_rope` interface.
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            self.head_dim,
            max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
            dtype=torch.get_default_dtype(),
        )

        # FIXME: While Gemma 2 uses sliding window attention for every
        # odd layer, Aphrodite currently ignores it and uses global attention
        # for all layers.
        use_sliding_window = (layer_idx % 2 == 1
                              and config.sliding_window is not None)
        del use_sliding_window  # Unused.
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
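

# Gemma2DecoderLayer below follows Gemma 2's "sandwich" normalization:
# RMSNorm is applied before and after the attention block, and again before
# and after the MLP. The two-argument GemmaRMSNorm calls fuse the residual
# add with normalization, returning both the normalized activations and the
# updated residual stream.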
class Gemma2DecoderLayer(nn.Module):

    def __init__(
        self,
        layer_idx: int,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma2Attention(
            layer_idx=layer_idx,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.hidden_size = config.hidden_size
        self.mlp = Gemma2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
                                                       eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        return hidden_states, residual
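

# Gemma2Model below stacks the decoder layers on top of a vocab-parallel
# token embedding. Gemma scales the embedding output by sqrt(hidden_size);
# the normalizer is registered as a buffer so that it is cast to the model
# dtype (e.g. bfloat16) rather than kept in float32 (see the transformers PR
# referenced in __init__).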
class Gemma2Model(nn.Module):

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        hidden_states *= self.normalizer

        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
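

# Gemma2ForCausalLM below exposes the model to the engine: `forward` returns
# hidden states, `compute_logits` projects them through the tied token
# embedding (Gemma 2 has no separate lm_head) with final logit soft-capping
# applied by LogitsProcessor(soft_cap=...), and `sample` draws the next
# tokens. The class-level packed_modules_mapping/supported_lora_modules
# attributes target LoRA at the fused projections, while `load_weights`
# folds per-projection checkpoint tensors (q/k/v, gate/up) into the fused
# qkv_proj and gate_up_proj parameters.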
class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: Gemma2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma2Model(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(
            config.vocab_size, soft_cap=config.final_logit_softcapping)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in Aphrodite as it is tied with
                # embed_tokens. To prevent errors, skip loading
                # lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")