gemma2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. # coding=utf-8
  2. # Copyright 2024 The PygmalionAI team.
  3. # Copyright 2024 The vLLM team.
  4. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  5. #
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. from typing import Iterable, List, Optional, Set, Tuple
  19. import torch
  20. from loguru import logger
  21. from torch import nn
  22. from transformers import Gemma2Config
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import CacheConfig, LoRAConfig
  25. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  26. from aphrodite.common.utils import progress_bar
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import GeluAndMul
  29. from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
  30. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
  35. from aphrodite.modeling.layers.sampler import Sampler
  36. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  37. VocabParallelEmbedding)
  38. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  39. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  40. from aphrodite.quantization.base_config import QuantizationConfig
  41. from .interfaces import SupportsLoRA
  42. class Gemma2MLP(nn.Module):
  43. def __init__(
  44. self,
  45. hidden_size: int,
  46. intermediate_size: int,
  47. hidden_act: str,
  48. hidden_activation: str,
  49. quant_config: Optional[QuantizationConfig] = None,
  50. ) -> None:
  51. super().__init__()
  52. self.gate_up_proj = MergedColumnParallelLinear(
  53. hidden_size, [intermediate_size] * 2,
  54. bias=False,
  55. quant_config=quant_config)
  56. self.down_proj = RowParallelLinear(intermediate_size,
  57. hidden_size,
  58. bias=False,
  59. quant_config=quant_config)
  60. if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
  61. raise ValueError(
  62. "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
  63. "function. Please set `hidden_act` and `hidden_activation` to "
  64. "`gelu_pytorch_tanh`.")
  65. self.act_fn = GeluAndMul(approximate="tanh")
  66. def forward(self, x: torch.Tensor) -> torch.Tensor:
  67. gate_up, _ = self.gate_up_proj(x)
  68. x = self.act_fn(gate_up)
  69. x, _ = self.down_proj(x)
  70. return x
  71. class Gemma2Attention(nn.Module):
  72. def __init__(self,
  73. layer_idx: int,
  74. config: Gemma2Config,
  75. hidden_size: int,
  76. num_heads: int,
  77. num_kv_heads: int,
  78. head_dim: int,
  79. max_position_embeddings: int,
  80. rope_theta: float,
  81. cache_config: Optional[CacheConfig] = None,
  82. quant_config: Optional[QuantizationConfig] = None,
  83. attn_logits_soft_cap: Optional[float] = None) -> None:
  84. super().__init__()
  85. self.layer_idx = layer_idx
  86. self.config = config
  87. self.hidden_size = hidden_size
  88. tp_size = get_tensor_model_parallel_world_size()
  89. self.total_num_heads = num_heads
  90. assert self.total_num_heads % tp_size == 0
  91. self.num_heads = self.total_num_heads // tp_size
  92. self.total_num_kv_heads = num_kv_heads
  93. if self.total_num_kv_heads >= tp_size:
  94. # Number of KV heads is greater than TP size, so we partition
  95. # the KV heads across multiple tensor parallel GPUs.
  96. assert self.total_num_kv_heads % tp_size == 0
  97. else:
  98. # Number of KV heads is less than TP size, so we replicate
  99. # the KV heads across multiple tensor parallel GPUs.
  100. assert tp_size % self.total_num_kv_heads == 0
  101. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  102. self.head_dim = head_dim
  103. self.q_size = self.num_heads * self.head_dim
  104. self.kv_size = self.num_kv_heads * self.head_dim
  105. self.scaling = config.query_pre_attn_scalar**-0.5
  106. self.rope_theta = rope_theta
  107. self.qkv_proj = QKVParallelLinear(
  108. hidden_size,
  109. self.head_dim,
  110. self.total_num_heads,
  111. self.total_num_kv_heads,
  112. bias=config.attention_bias,
  113. quant_config=quant_config,
  114. )
  115. self.o_proj = RowParallelLinear(
  116. self.total_num_heads * self.head_dim,
  117. hidden_size,
  118. bias=config.attention_bias,
  119. quant_config=quant_config,
  120. )
  121. # TODO: Use the `get_rope` interface.
  122. self.rotary_emb = GemmaRotaryEmbedding(
  123. self.head_dim,
  124. self.head_dim,
  125. max_position_embeddings,
  126. base=self.rope_theta,
  127. is_neox_style=True,
  128. dtype=torch.get_default_dtype(),
  129. )
  130. # FIXME: While Gemma 2 uses sliding window attention for every
  131. # odd layer, Aphrodite currently ignores it and uses global attention
  132. # for all layers.
  133. use_sliding_window = (layer_idx % 2 == 1
  134. and config.sliding_window is not None)
  135. del use_sliding_window # Unused.
  136. self.attn = Attention(self.num_heads,
  137. self.head_dim,
  138. self.scaling,
  139. num_kv_heads=self.num_kv_heads,
  140. cache_config=cache_config,
  141. quant_config=quant_config,
  142. logits_soft_cap=attn_logits_soft_cap)
  143. def forward(
  144. self,
  145. positions: torch.Tensor,
  146. hidden_states: torch.Tensor,
  147. kv_cache: torch.Tensor,
  148. attn_metadata: AttentionMetadata,
  149. ) -> torch.Tensor:
  150. qkv, _ = self.qkv_proj(hidden_states)
  151. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  152. q, k = self.rotary_emb(positions, q, k)
  153. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  154. output, _ = self.o_proj(attn_output)
  155. return output
  156. class Gemma2DecoderLayer(nn.Module):
  157. def __init__(
  158. self,
  159. layer_idx: int,
  160. config: Gemma2Config,
  161. cache_config: Optional[CacheConfig] = None,
  162. quant_config: Optional[QuantizationConfig] = None,
  163. ) -> None:
  164. super().__init__()
  165. self.hidden_size = config.hidden_size
  166. self.self_attn = Gemma2Attention(
  167. layer_idx=layer_idx,
  168. config=config,
  169. hidden_size=self.hidden_size,
  170. num_heads=config.num_attention_heads,
  171. num_kv_heads=config.num_key_value_heads,
  172. head_dim=config.head_dim,
  173. max_position_embeddings=config.max_position_embeddings,
  174. rope_theta=config.rope_theta,
  175. cache_config=cache_config,
  176. quant_config=quant_config,
  177. attn_logits_soft_cap=config.attn_logit_softcapping,
  178. )
  179. self.hidden_size = config.hidden_size
  180. self.mlp = Gemma2MLP(
  181. hidden_size=self.hidden_size,
  182. intermediate_size=config.intermediate_size,
  183. hidden_act=config.hidden_act,
  184. hidden_activation=config.hidden_activation,
  185. quant_config=quant_config,
  186. )
  187. self.input_layernorm = GemmaRMSNorm(config.hidden_size,
  188. eps=config.rms_norm_eps)
  189. self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
  190. eps=config.rms_norm_eps)
  191. self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
  192. eps=config.rms_norm_eps)
  193. self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
  194. eps=config.rms_norm_eps)
  195. def forward(
  196. self,
  197. positions: torch.Tensor,
  198. hidden_states: torch.Tensor,
  199. kv_cache: torch.Tensor,
  200. attn_metadata: AttentionMetadata,
  201. residual: Optional[torch.Tensor],
  202. ) -> Tuple[torch.Tensor, torch.Tensor]:
  203. if residual is None:
  204. residual = hidden_states
  205. hidden_states = self.input_layernorm(hidden_states)
  206. else:
  207. hidden_states, residual = self.input_layernorm(
  208. hidden_states, residual)
  209. hidden_states = self.self_attn(
  210. positions=positions,
  211. hidden_states=hidden_states,
  212. kv_cache=kv_cache,
  213. attn_metadata=attn_metadata,
  214. )
  215. hidden_states = self.post_attention_layernorm(hidden_states)
  216. hidden_states, residual = self.pre_feedforward_layernorm(
  217. hidden_states, residual)
  218. hidden_states = self.mlp(hidden_states)
  219. hidden_states = self.post_feedforward_layernorm(hidden_states)
  220. return hidden_states, residual
  221. class Gemma2Model(nn.Module):
  222. def __init__(
  223. self,
  224. config: Gemma2Config,
  225. cache_config: Optional[CacheConfig] = None,
  226. quant_config: Optional[QuantizationConfig] = None,
  227. ) -> None:
  228. super().__init__()
  229. self.config = config
  230. self.embed_tokens = VocabParallelEmbedding(
  231. config.vocab_size,
  232. config.hidden_size,
  233. )
  234. self.layers = nn.ModuleList([
  235. Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
  236. for layer_idx in range(config.num_hidden_layers)
  237. ])
  238. self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  239. # Normalize the embedding by sqrt(hidden_size)
  240. # The normalizer's data type should be downcasted to the model's
  241. # data type such as bfloat16, not float32.
  242. # See https://github.com/huggingface/transformers/pull/29402
  243. normalizer = self.config.hidden_size**0.5
  244. self.register_buffer("normalizer", torch.tensor(normalizer))
  245. def forward(
  246. self,
  247. input_ids: torch.Tensor,
  248. positions: torch.Tensor,
  249. kv_caches: List[torch.Tensor],
  250. attn_metadata: AttentionMetadata,
  251. ) -> torch.Tensor:
  252. hidden_states = self.embed_tokens(input_ids)
  253. hidden_states *= self.normalizer
  254. residual = None
  255. for i in range(len(self.layers)):
  256. layer = self.layers[i]
  257. hidden_states, residual = layer(
  258. positions,
  259. hidden_states,
  260. kv_caches[i],
  261. attn_metadata,
  262. residual,
  263. )
  264. hidden_states, _ = self.norm(hidden_states, residual)
  265. return hidden_states
  266. class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
  267. packed_modules_mapping = {
  268. "qkv_proj": [
  269. "q_proj",
  270. "k_proj",
  271. "v_proj",
  272. ],
  273. "gate_up_proj": [
  274. "gate_proj",
  275. "up_proj",
  276. ],
  277. }
  278. # LoRA specific attributes
  279. supported_lora_modules = [
  280. "qkv_proj",
  281. "o_proj",
  282. "gate_up_proj",
  283. "down_proj",
  284. ]
  285. # Gemma does not apply LoRA to the embedding layer.
  286. embedding_modules = {}
  287. embedding_padding_modules = []
  288. def __init__(
  289. self,
  290. config: Gemma2Config,
  291. cache_config: Optional[CacheConfig] = None,
  292. quant_config: Optional[QuantizationConfig] = None,
  293. lora_config: Optional[LoRAConfig] = None,
  294. ) -> None:
  295. del lora_config # Unused.
  296. super().__init__()
  297. self.config = config
  298. self.quant_config = quant_config
  299. self.model = Gemma2Model(config, cache_config, quant_config)
  300. self.logits_processor = LogitsProcessor(
  301. config.vocab_size, soft_cap=config.final_logit_softcapping)
  302. self.sampler = Sampler()
  303. def forward(
  304. self,
  305. input_ids: torch.Tensor,
  306. positions: torch.Tensor,
  307. kv_caches: List[torch.Tensor],
  308. attn_metadata: AttentionMetadata,
  309. intermediate_tensors: Optional[IntermediateTensors] = None,
  310. ) -> torch.Tensor:
  311. hidden_states = self.model(input_ids, positions, kv_caches,
  312. attn_metadata)
  313. return hidden_states
  314. def compute_logits(
  315. self,
  316. hidden_states: torch.Tensor,
  317. sampling_metadata: SamplingMetadata,
  318. ) -> Optional[torch.Tensor]:
  319. logits = self.logits_processor(self.model.embed_tokens, hidden_states,
  320. sampling_metadata)
  321. return logits
  322. def sample(
  323. self,
  324. logits: torch.Tensor,
  325. sampling_metadata: SamplingMetadata,
  326. ) -> Optional[SamplerOutput]:
  327. next_tokens = self.sampler(logits, sampling_metadata)
  328. return next_tokens
  329. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  330. stacked_params_mapping = [
  331. # (param_name, shard_name, shard_id)
  332. ("qkv_proj", "q_proj", "q"),
  333. ("qkv_proj", "k_proj", "k"),
  334. ("qkv_proj", "v_proj", "v"),
  335. ("gate_up_proj", "gate_proj", 0),
  336. ("gate_up_proj", "up_proj", 1),
  337. ]
  338. params_dict = dict(self.named_parameters())
  339. loaded_params: Set[str] = set()
  340. weights_list = list(weights)
  341. for name, loaded_weight in progress_bar(weights_list,
  342. desc="Loading modules..."):
  343. for (param_name, shard_name, shard_id) in stacked_params_mapping:
  344. if shard_name not in name:
  345. continue
  346. name = name.replace(shard_name, param_name)
  347. # Skip loading extra bias for GPTQ models.
  348. if name.endswith(".bias") and name not in params_dict:
  349. continue
  350. param = params_dict[name]
  351. weight_loader = param.weight_loader
  352. weight_loader(param, loaded_weight, shard_id)
  353. break
  354. else:
  355. # lm_head is not used in Aphrodite as it is tied with
  356. # embed_token.
  357. # To prevent errors, skip loading lm_head.weight.
  358. if "lm_head.weight" in name:
  359. continue
  360. # Skip loading extra bias for GPTQ models.
  361. if name.endswith(".bias") and name not in params_dict:
  362. continue
  363. param = params_dict[name]
  364. weight_loader = getattr(param, "weight_loader",
  365. default_weight_loader)
  366. weight_loader(param, loaded_weight)
  367. loaded_params.add(name)
  368. unloaded_params = params_dict.keys() - loaded_params
  369. if unloaded_params:
  370. logger.warning(
  371. "Some weights are not initialized from checkpoints: "
  372. f"{unloaded_params}")