gemma2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # coding=utf-8
  2. # Copyright 2024 The PygmalionAI team.
  3. # Copyright 2024 The vLLM team.
  4. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  5. #
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. from typing import Iterable, List, Optional, Set, Tuple
  19. import torch
  20. from loguru import logger
  21. from torch import nn
  22. from transformers import Gemma2Config
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import CacheConfig, LoRAConfig
  25. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import GeluAndMul
  28. from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
  29. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding)
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.quantization.base_config import QuantizationConfig
  40. from .interfaces import SupportsLoRA
  41. class Gemma2MLP(nn.Module):
  42. def __init__(
  43. self,
  44. hidden_size: int,
  45. intermediate_size: int,
  46. hidden_act: str,
  47. hidden_activation: str,
  48. quant_config: Optional[QuantizationConfig] = None,
  49. ) -> None:
  50. super().__init__()
  51. self.gate_up_proj = MergedColumnParallelLinear(
  52. hidden_size, [intermediate_size] * 2,
  53. bias=False,
  54. quant_config=quant_config)
  55. self.down_proj = RowParallelLinear(intermediate_size,
  56. hidden_size,
  57. bias=False,
  58. quant_config=quant_config)
  59. if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
  60. raise ValueError(
  61. "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
  62. "function. Please set `hidden_act` and `hidden_activation` to "
  63. "`gelu_pytorch_tanh`.")
  64. self.act_fn = GeluAndMul(approximate="tanh")
  65. def forward(self, x: torch.Tensor) -> torch.Tensor:
  66. gate_up, _ = self.gate_up_proj(x)
  67. x = self.act_fn(gate_up)
  68. x, _ = self.down_proj(x)
  69. return x
  70. class Gemma2Attention(nn.Module):
  71. def __init__(self,
  72. layer_idx: int,
  73. config: Gemma2Config,
  74. hidden_size: int,
  75. num_heads: int,
  76. num_kv_heads: int,
  77. head_dim: int,
  78. max_position_embeddings: int,
  79. rope_theta: float,
  80. cache_config: Optional[CacheConfig] = None,
  81. quant_config: Optional[QuantizationConfig] = None,
  82. attn_logits_soft_cap: Optional[float] = None) -> None:
  83. super().__init__()
  84. self.layer_idx = layer_idx
  85. self.config = config
  86. self.hidden_size = hidden_size
  87. tp_size = get_tensor_model_parallel_world_size()
  88. self.total_num_heads = num_heads
  89. assert self.total_num_heads % tp_size == 0
  90. self.num_heads = self.total_num_heads // tp_size
  91. self.total_num_kv_heads = num_kv_heads
  92. if self.total_num_kv_heads >= tp_size:
  93. # Number of KV heads is greater than TP size, so we partition
  94. # the KV heads across multiple tensor parallel GPUs.
  95. assert self.total_num_kv_heads % tp_size == 0
  96. else:
  97. # Number of KV heads is less than TP size, so we replicate
  98. # the KV heads across multiple tensor parallel GPUs.
  99. assert tp_size % self.total_num_kv_heads == 0
  100. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  101. self.head_dim = head_dim
  102. self.q_size = self.num_heads * self.head_dim
  103. self.kv_size = self.num_kv_heads * self.head_dim
  104. self.scaling = config.query_pre_attn_scalar**-0.5
  105. self.rope_theta = rope_theta
  106. self.qkv_proj = QKVParallelLinear(
  107. hidden_size,
  108. self.head_dim,
  109. self.total_num_heads,
  110. self.total_num_kv_heads,
  111. bias=config.attention_bias,
  112. quant_config=quant_config,
  113. )
  114. self.o_proj = RowParallelLinear(
  115. self.total_num_heads * self.head_dim,
  116. hidden_size,
  117. bias=config.attention_bias,
  118. quant_config=quant_config,
  119. )
  120. # TODO: Use the `get_rope` interface.
  121. self.rotary_emb = GemmaRotaryEmbedding(
  122. self.head_dim,
  123. self.head_dim,
  124. max_position_embeddings,
  125. base=self.rope_theta,
  126. is_neox_style=True,
  127. dtype=torch.get_default_dtype(),
  128. )
  129. # FIXME: While Gemma 2 uses sliding window attention for every
  130. # odd layer, Aphrodite currently ignores it and uses global attention
  131. # for all layers.
  132. use_sliding_window = (layer_idx % 2 == 1
  133. and config.sliding_window is not None)
  134. del use_sliding_window # Unused.
  135. self.attn = Attention(self.num_heads,
  136. self.head_dim,
  137. self.scaling,
  138. num_kv_heads=self.num_kv_heads,
  139. cache_config=cache_config,
  140. quant_config=quant_config,
  141. logits_soft_cap=attn_logits_soft_cap)
  142. def forward(
  143. self,
  144. positions: torch.Tensor,
  145. hidden_states: torch.Tensor,
  146. kv_cache: torch.Tensor,
  147. attn_metadata: AttentionMetadata,
  148. ) -> torch.Tensor:
  149. qkv, _ = self.qkv_proj(hidden_states)
  150. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  151. q, k = self.rotary_emb(positions, q, k)
  152. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  153. output, _ = self.o_proj(attn_output)
  154. return output
  155. class Gemma2DecoderLayer(nn.Module):
  156. def __init__(
  157. self,
  158. layer_idx: int,
  159. config: Gemma2Config,
  160. cache_config: Optional[CacheConfig] = None,
  161. quant_config: Optional[QuantizationConfig] = None,
  162. ) -> None:
  163. super().__init__()
  164. self.hidden_size = config.hidden_size
  165. self.self_attn = Gemma2Attention(
  166. layer_idx=layer_idx,
  167. config=config,
  168. hidden_size=self.hidden_size,
  169. num_heads=config.num_attention_heads,
  170. num_kv_heads=config.num_key_value_heads,
  171. head_dim=config.head_dim,
  172. max_position_embeddings=config.max_position_embeddings,
  173. rope_theta=config.rope_theta,
  174. cache_config=cache_config,
  175. quant_config=quant_config,
  176. attn_logits_soft_cap=config.attn_logit_softcapping,
  177. )
  178. self.hidden_size = config.hidden_size
  179. self.mlp = Gemma2MLP(
  180. hidden_size=self.hidden_size,
  181. intermediate_size=config.intermediate_size,
  182. hidden_act=config.hidden_act,
  183. hidden_activation=config.hidden_activation,
  184. quant_config=quant_config,
  185. )
  186. self.input_layernorm = GemmaRMSNorm(config.hidden_size,
  187. eps=config.rms_norm_eps)
  188. self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
  189. eps=config.rms_norm_eps)
  190. self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
  191. eps=config.rms_norm_eps)
  192. self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
  193. eps=config.rms_norm_eps)
  194. def forward(
  195. self,
  196. positions: torch.Tensor,
  197. hidden_states: torch.Tensor,
  198. kv_cache: torch.Tensor,
  199. attn_metadata: AttentionMetadata,
  200. residual: Optional[torch.Tensor],
  201. ) -> Tuple[torch.Tensor, torch.Tensor]:
  202. if residual is None:
  203. residual = hidden_states
  204. hidden_states = self.input_layernorm(hidden_states)
  205. else:
  206. hidden_states, residual = self.input_layernorm(
  207. hidden_states, residual)
  208. hidden_states = self.self_attn(
  209. positions=positions,
  210. hidden_states=hidden_states,
  211. kv_cache=kv_cache,
  212. attn_metadata=attn_metadata,
  213. )
  214. hidden_states = self.post_attention_layernorm(hidden_states)
  215. hidden_states, residual = self.pre_feedforward_layernorm(
  216. hidden_states, residual)
  217. hidden_states = self.mlp(hidden_states)
  218. hidden_states = self.post_feedforward_layernorm(hidden_states)
  219. return hidden_states, residual
  220. class Gemma2Model(nn.Module):
  221. def __init__(
  222. self,
  223. config: Gemma2Config,
  224. cache_config: Optional[CacheConfig] = None,
  225. quant_config: Optional[QuantizationConfig] = None,
  226. ) -> None:
  227. super().__init__()
  228. self.config = config
  229. self.embed_tokens = VocabParallelEmbedding(
  230. config.vocab_size,
  231. config.hidden_size,
  232. )
  233. self.layers = nn.ModuleList([
  234. Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
  235. for layer_idx in range(config.num_hidden_layers)
  236. ])
  237. self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  238. # Normalize the embedding by sqrt(hidden_size)
  239. # The normalizer's data type should be downcasted to the model's
  240. # data type such as bfloat16, not float32.
  241. # See https://github.com/huggingface/transformers/pull/29402
  242. normalizer = self.config.hidden_size**0.5
  243. self.register_buffer("normalizer", torch.tensor(normalizer))
  244. def forward(
  245. self,
  246. input_ids: torch.Tensor,
  247. positions: torch.Tensor,
  248. kv_caches: List[torch.Tensor],
  249. attn_metadata: AttentionMetadata,
  250. ) -> torch.Tensor:
  251. hidden_states = self.embed_tokens(input_ids)
  252. hidden_states *= self.normalizer
  253. residual = None
  254. for i in range(len(self.layers)):
  255. layer = self.layers[i]
  256. hidden_states, residual = layer(
  257. positions,
  258. hidden_states,
  259. kv_caches[i],
  260. attn_metadata,
  261. residual,
  262. )
  263. hidden_states, _ = self.norm(hidden_states, residual)
  264. return hidden_states
  265. class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
  266. packed_modules_mapping = {
  267. "qkv_proj": [
  268. "q_proj",
  269. "k_proj",
  270. "v_proj",
  271. ],
  272. "gate_up_proj": [
  273. "gate_proj",
  274. "up_proj",
  275. ],
  276. }
  277. # LoRA specific attributes
  278. supported_lora_modules = [
  279. "qkv_proj",
  280. "o_proj",
  281. "gate_up_proj",
  282. "down_proj",
  283. ]
  284. # Gemma does not apply LoRA to the embedding layer.
  285. embedding_modules = {}
  286. embedding_padding_modules = []
  287. def __init__(
  288. self,
  289. config: Gemma2Config,
  290. cache_config: Optional[CacheConfig] = None,
  291. quant_config: Optional[QuantizationConfig] = None,
  292. lora_config: Optional[LoRAConfig] = None,
  293. ) -> None:
  294. del lora_config # Unused.
  295. super().__init__()
  296. self.config = config
  297. self.quant_config = quant_config
  298. self.model = Gemma2Model(config, cache_config, quant_config)
  299. self.logits_processor = LogitsProcessor(
  300. config.vocab_size, soft_cap=config.final_logit_softcapping)
  301. self.sampler = Sampler()
  302. def forward(
  303. self,
  304. input_ids: torch.Tensor,
  305. positions: torch.Tensor,
  306. kv_caches: List[torch.Tensor],
  307. attn_metadata: AttentionMetadata,
  308. intermediate_tensors: Optional[IntermediateTensors] = None,
  309. ) -> torch.Tensor:
  310. hidden_states = self.model(input_ids, positions, kv_caches,
  311. attn_metadata)
  312. return hidden_states
  313. def compute_logits(self, hidden_states: torch.Tensor,
  314. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  315. logits = self.logits_processor(self.model.embed_tokens, hidden_states,
  316. sampling_metadata)
  317. return logits
  318. def sample(
  319. self,
  320. logits: torch.Tensor,
  321. sampling_metadata: SamplingMetadata,
  322. ) -> Optional[SamplerOutput]:
  323. next_tokens = self.sampler(logits, sampling_metadata)
  324. return next_tokens
  325. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  326. stacked_params_mapping = [
  327. # (param_name, shard_name, shard_id)
  328. ("qkv_proj", "q_proj", "q"),
  329. ("qkv_proj", "k_proj", "k"),
  330. ("qkv_proj", "v_proj", "v"),
  331. ("gate_up_proj", "gate_proj", 0),
  332. ("gate_up_proj", "up_proj", 1),
  333. ]
  334. params_dict = dict(self.named_parameters())
  335. loaded_params: Set[str] = set()
  336. for name, loaded_weight in weights:
  337. for (param_name, shard_name, shard_id) in stacked_params_mapping:
  338. if shard_name not in name:
  339. continue
  340. name = name.replace(shard_name, param_name)
  341. # Skip loading extra bias for GPTQ models.
  342. if name.endswith(".bias") and name not in params_dict:
  343. continue
  344. param = params_dict[name]
  345. weight_loader = param.weight_loader
  346. weight_loader(param, loaded_weight, shard_id)
  347. break
  348. else:
  349. # lm_head is not used in Aphrodite as it is tied with
  350. # embed_token.
  351. # To prevent errors, skip loading lm_head.weight.
  352. if "lm_head.weight" in name:
  353. continue
  354. # Skip loading extra bias for GPTQ models.
  355. if name.endswith(".bias") and name not in params_dict:
  356. continue
  357. param = params_dict[name]
  358. weight_loader = getattr(param, "weight_loader",
  359. default_weight_loader)
  360. weight_loader(param, loaded_weight)
  361. loaded_params.add(name)
  362. unloaded_params = params_dict.keys() - loaded_params
  363. if unloaded_params:
  364. logger.warning(
  365. "Some weights are not initialized from checkpoints: "
  366. f"{unloaded_params}")