gemma.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. # coding=utf-8
  2. # Copyright 2023 The vLLM team.
  3. # Copyright (c) Google Inc.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Inference-only Gemma model compatible with HuggingFace weights."""
  17. from functools import lru_cache
  18. from typing import Iterable, List, Optional, Set, Tuple
  19. import torch
  20. from loguru import logger
  21. from torch import nn
  22. from transformers import GemmaConfig
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import CacheConfig, LoRAConfig
  25. from aphrodite.common.sequence import IntermediateTensors
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import GeluAndMul
  28. from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
  29. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
  34. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding)
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.quantization.base_config import QuantizationConfig
  40. from .interfaces import SupportsLoRA
  41. @lru_cache(maxsize=None)
  42. def _get_gemma_act_fn(
  43. hidden_act: Optional[str],
  44. hidden_activation: Optional[str],
  45. ) -> nn.Module:
  46. if hidden_activation is None:
  47. if hidden_act is not None:
  48. logger.warning(
  49. "Gemma's activation function was incorrectly set to exact GeLU "
  50. "in the config JSON file when it was initially released. "
  51. "Changing the activation function to approximate GeLU "
  52. "(`gelu_pytorch_tanh`). If you want to use the legacy "
  53. f"`{hidden_act}`, edit the config JSON to set "
  54. f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
  55. "See https://github.com/huggingface/transformers/pull/29402 "
  56. "for more details.")
  57. return GeluAndMul(approximate="tanh")
  58. elif hidden_activation == "gelu_pytorch_tanh":
  59. return GeluAndMul(approximate="tanh")
  60. elif hidden_activation == "gelu":
  61. return GeluAndMul(approximate="none")
  62. else:
  63. raise ValueError(f"Activation function {hidden_act} is not "
  64. "supported for Gemma models.")
  65. class GemmaMLP(nn.Module):
  66. def __init__(
  67. self,
  68. hidden_size: int,
  69. intermediate_size: int,
  70. hidden_act: Optional[str] = None,
  71. hidden_activation: Optional[str] = None,
  72. quant_config: Optional[QuantizationConfig] = None,
  73. ) -> None:
  74. super().__init__()
  75. self.gate_up_proj = MergedColumnParallelLinear(
  76. hidden_size, [intermediate_size] * 2,
  77. bias=False,
  78. quant_config=quant_config)
  79. self.down_proj = RowParallelLinear(intermediate_size,
  80. hidden_size,
  81. bias=False,
  82. quant_config=quant_config)
  83. self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
  84. def forward(self, x):
  85. gate_up, _ = self.gate_up_proj(x)
  86. x = self.act_fn(gate_up)
  87. x, _ = self.down_proj(x)
  88. return x
  89. class GemmaAttention(nn.Module):
  90. def __init__(self,
  91. hidden_size: int,
  92. num_heads: int,
  93. num_kv_heads: int,
  94. head_dim: int,
  95. max_position_embeddings: int = 8192,
  96. rope_theta: float = 10000,
  97. cache_config: Optional[CacheConfig] = None,
  98. quant_config: Optional[QuantizationConfig] = None) -> None:
  99. super().__init__()
  100. self.hidden_size = hidden_size
  101. tp_size = get_tensor_model_parallel_world_size()
  102. self.total_num_heads = num_heads
  103. assert self.total_num_heads % tp_size == 0
  104. self.num_heads = self.total_num_heads // tp_size
  105. self.total_num_kv_heads = num_kv_heads
  106. if self.total_num_kv_heads >= tp_size:
  107. # Number of KV heads is greater than TP size, so we partition
  108. # the KV heads across multiple tensor parallel GPUs.
  109. assert self.total_num_kv_heads % tp_size == 0
  110. else:
  111. # Number of KV heads is less than TP size, so we replicate
  112. # the KV heads across multiple tensor parallel GPUs.
  113. assert tp_size % self.total_num_kv_heads == 0
  114. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  115. self.head_dim = head_dim
  116. self.q_size = self.num_heads * self.head_dim
  117. self.kv_size = self.num_kv_heads * self.head_dim
  118. self.scaling = self.head_dim**-0.5
  119. self.rope_theta = rope_theta
  120. self.qkv_proj = QKVParallelLinear(
  121. hidden_size,
  122. self.head_dim,
  123. self.total_num_heads,
  124. self.total_num_kv_heads,
  125. bias=False,
  126. quant_config=quant_config,
  127. )
  128. self.o_proj = RowParallelLinear(
  129. self.total_num_heads * self.head_dim,
  130. hidden_size,
  131. bias=False,
  132. quant_config=quant_config,
  133. )
  134. # TODO: Use the `get_rope` interface.
  135. self.rotary_emb = GemmaRotaryEmbedding(
  136. self.head_dim,
  137. rotary_dim=self.head_dim,
  138. max_position_embeddings=max_position_embeddings,
  139. base=self.rope_theta,
  140. is_neox_style=True,
  141. dtype=torch.get_default_dtype(),
  142. )
  143. self.attn = Attention(self.num_heads,
  144. self.head_dim,
  145. self.scaling,
  146. num_kv_heads=self.num_kv_heads,
  147. cache_config=cache_config,
  148. quant_config=quant_config)
  149. def forward(
  150. self,
  151. positions: torch.Tensor,
  152. hidden_states: torch.Tensor,
  153. kv_cache: torch.Tensor,
  154. attn_metadata: AttentionMetadata,
  155. ) -> torch.Tensor:
  156. qkv, _ = self.qkv_proj(hidden_states)
  157. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  158. q, k = self.rotary_emb(positions, q, k)
  159. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  160. output, _ = self.o_proj(attn_output)
  161. return output
  162. class GemmaDecoderLayer(nn.Module):
  163. def __init__(
  164. self,
  165. config: GemmaConfig,
  166. cache_config: Optional[CacheConfig] = None,
  167. quant_config: Optional[QuantizationConfig] = None,
  168. ) -> None:
  169. super().__init__()
  170. self.hidden_size = config.hidden_size
  171. self.self_attn = GemmaAttention(
  172. hidden_size=self.hidden_size,
  173. num_heads=config.num_attention_heads,
  174. num_kv_heads=config.num_key_value_heads,
  175. head_dim=config.head_dim,
  176. max_position_embeddings=config.max_position_embeddings,
  177. rope_theta=config.rope_theta,
  178. cache_config=cache_config,
  179. quant_config=quant_config,
  180. )
  181. self.mlp = GemmaMLP(
  182. hidden_size=self.hidden_size,
  183. intermediate_size=config.intermediate_size,
  184. hidden_act=config.hidden_act,
  185. hidden_activation=getattr(config, "hidden_activation", None),
  186. quant_config=quant_config,
  187. )
  188. self.input_layernorm = GemmaRMSNorm(config.hidden_size,
  189. eps=config.rms_norm_eps)
  190. self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
  191. eps=config.rms_norm_eps)
  192. def forward(
  193. self,
  194. positions: torch.Tensor,
  195. hidden_states: torch.Tensor,
  196. kv_cache: torch.Tensor,
  197. attn_metadata: AttentionMetadata,
  198. residual: Optional[torch.Tensor],
  199. ) -> Tuple[torch.Tensor, torch.Tensor]:
  200. # Self Attention
  201. if residual is None:
  202. residual = hidden_states
  203. hidden_states = self.input_layernorm(hidden_states)
  204. else:
  205. hidden_states, residual = self.input_layernorm(
  206. hidden_states, residual)
  207. hidden_states = self.self_attn(
  208. positions=positions,
  209. hidden_states=hidden_states,
  210. kv_cache=kv_cache,
  211. attn_metadata=attn_metadata,
  212. )
  213. # Fully Connected
  214. hidden_states, residual = self.post_attention_layernorm(
  215. hidden_states, residual)
  216. hidden_states = self.mlp(hidden_states)
  217. return hidden_states, residual
  218. class GemmaModel(nn.Module):
  219. def __init__(
  220. self,
  221. config: GemmaConfig,
  222. cache_config: Optional[CacheConfig] = None,
  223. quant_config: Optional[QuantizationConfig] = None,
  224. ) -> None:
  225. super().__init__()
  226. self.config = config
  227. self.embed_tokens = VocabParallelEmbedding(
  228. config.vocab_size,
  229. config.hidden_size,
  230. )
  231. self.layers = nn.ModuleList([
  232. GemmaDecoderLayer(config, cache_config, quant_config)
  233. for _ in range(config.num_hidden_layers)
  234. ])
  235. self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  236. # Normalize the embedding by sqrt(hidden_size)
  237. # The normalizer's data type should be downcasted to the model's
  238. # data type such as bfloat16, not float32.
  239. # See https://github.com/huggingface/transformers/pull/29402
  240. normalizer = self.config.hidden_size**0.5
  241. self.register_buffer("normalizer", torch.tensor(normalizer))
  242. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  243. return self.embed_tokens(input_ids)
  244. def forward(
  245. self,
  246. input_ids: torch.Tensor,
  247. positions: torch.Tensor,
  248. kv_caches: List[torch.Tensor],
  249. attn_metadata: AttentionMetadata,
  250. intermediate_tensors: Optional[IntermediateTensors] = None,
  251. inputs_embeds: Optional[torch.Tensor] = None,
  252. ) -> torch.Tensor:
  253. if inputs_embeds is not None:
  254. hidden_states = inputs_embeds
  255. else:
  256. hidden_states = self.get_input_embeddings(input_ids)
  257. hidden_states *= self.normalizer
  258. residual = None
  259. for i in range(len(self.layers)):
  260. layer = self.layers[i]
  261. hidden_states, residual = layer(
  262. positions,
  263. hidden_states,
  264. kv_caches[i],
  265. attn_metadata,
  266. residual,
  267. )
  268. hidden_states, _ = self.norm(hidden_states, residual)
  269. return hidden_states
  270. class GemmaForCausalLM(nn.Module, SupportsLoRA):
  271. packed_modules_mapping = {
  272. "qkv_proj": [
  273. "q_proj",
  274. "k_proj",
  275. "v_proj",
  276. ],
  277. "gate_up_proj": [
  278. "gate_proj",
  279. "up_proj",
  280. ],
  281. }
  282. # LoRA specific attributes
  283. supported_lora_modules = [
  284. "qkv_proj",
  285. "o_proj",
  286. "gate_up_proj",
  287. "down_proj",
  288. ]
  289. # Gemma does not apply LoRA to the embedding layer.
  290. embedding_modules = {}
  291. embedding_padding_modules = []
  292. def __init__(
  293. self,
  294. config: GemmaConfig,
  295. cache_config: Optional[CacheConfig] = None,
  296. quant_config: Optional[QuantizationConfig] = None,
  297. lora_config: Optional[LoRAConfig] = None,
  298. ) -> None:
  299. super().__init__()
  300. self.config = config
  301. self.lora_config = lora_config
  302. self.quant_config = quant_config
  303. self.model = GemmaModel(config, cache_config, quant_config)
  304. self.logits_processor = LogitsProcessor(config.vocab_size)
  305. self.sampler = Sampler()
  306. def forward(
  307. self,
  308. input_ids: torch.Tensor,
  309. positions: torch.Tensor,
  310. kv_caches: List[torch.Tensor],
  311. attn_metadata: AttentionMetadata,
  312. intermediate_tensors: Optional[IntermediateTensors] = None,
  313. ) -> torch.Tensor:
  314. hidden_states = self.model(input_ids, positions, kv_caches,
  315. attn_metadata)
  316. return hidden_states
  317. def compute_logits(
  318. self,
  319. hidden_states: torch.Tensor,
  320. sampling_metadata: SamplingMetadata,
  321. ) -> Optional[torch.Tensor]:
  322. logits = self.logits_processor(self.model.embed_tokens, hidden_states,
  323. sampling_metadata)
  324. return logits
  325. def sample(
  326. self,
  327. logits: torch.Tensor,
  328. sampling_metadata: SamplingMetadata,
  329. ) -> Optional[SamplerOutput]:
  330. next_tokens = self.sampler(logits, sampling_metadata)
  331. return next_tokens
  332. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  333. stacked_params_mapping = [
  334. # (param_name, shard_name, shard_id)
  335. ("qkv_proj", "q_proj", "q"),
  336. ("qkv_proj", "k_proj", "k"),
  337. ("qkv_proj", "v_proj", "v"),
  338. ("gate_up_proj", "gate_proj", 0),
  339. ("gate_up_proj", "up_proj", 1),
  340. ]
  341. params_dict = dict(self.named_parameters())
  342. loaded_params: Set[str] = set()
  343. for name, loaded_weight in weights:
  344. for (param_name, shard_name, shard_id) in stacked_params_mapping:
  345. if shard_name not in name:
  346. continue
  347. name = name.replace(shard_name, param_name)
  348. # Skip loading extra bias for GPTQ models.
  349. if name.endswith(".bias") and name not in params_dict:
  350. continue
  351. param = params_dict[name]
  352. weight_loader = param.weight_loader
  353. weight_loader(param, loaded_weight, shard_id)
  354. break
  355. else:
  356. # lm_head is not used in aphrodite as it is tied with
  357. # embed_token. To prevent errors, skip loading lm_head.weight.
  358. if "lm_head.weight" in name:
  359. continue
  360. # Skip loading extra bias for GPTQ models.
  361. if name.endswith(".bias") and name not in params_dict:
  362. continue
  363. param = params_dict[name]
  364. weight_loader = getattr(param, "weight_loader",
  365. default_weight_loader)
  366. weight_loader(param, loaded_weight)
  367. loaded_params.add(name)
  368. unloaded_params = params_dict.keys() - loaded_params
  369. if unloaded_params:
  370. logger.warning(
  371. "Some weights are not initialized from checkpoints: "
  372. f"{unloaded_params}")