# gemma.py
  1. # coding=utf-8
  2. # Copyright 2023 The vLLM team.
  3. # Copyright (c) Google Inc.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Inference-only Gemma model compatible with HuggingFace weights."""
  17. from functools import lru_cache
  18. from typing import Iterable, List, Optional, Set, Tuple
  19. import torch
  20. from loguru import logger
  21. from torch import nn
  22. from transformers import GemmaConfig
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import CacheConfig, LoRAConfig
  25. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  26. from aphrodite.common.utils import progress_bar
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import GeluAndMul
  29. from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
  30. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.rotary_embedding import GemmaRotaryEmbedding
  35. from aphrodite.modeling.layers.sampler import Sampler
  36. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  37. VocabParallelEmbedding)
  38. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  39. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  40. from aphrodite.quantization.base_config import QuantizationConfig
  41. from .interfaces import SupportsLoRA
  42. @lru_cache(maxsize=None)
  43. def _get_gemma_act_fn(
  44. hidden_act: Optional[str],
  45. hidden_activation: Optional[str],
  46. ) -> nn.Module:
  47. if hidden_activation is None:
  48. if hidden_act is not None:
  49. logger.warning(
  50. "Gemma's activation function was incorrectly set to exact GeLU "
  51. "in the config JSON file when it was initially released. "
  52. "Changing the activation function to approximate GeLU "
  53. "(`gelu_pytorch_tanh`). If you want to use the legacy "
  54. f"`{hidden_act}`, edit the config JSON to set "
  55. f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
  56. "See https://github.com/huggingface/transformers/pull/29402 "
  57. "for more details.")
  58. return GeluAndMul(approximate="tanh")
  59. elif hidden_activation == "gelu_pytorch_tanh":
  60. return GeluAndMul(approximate="tanh")
  61. elif hidden_activation == "gelu":
  62. return GeluAndMul(approximate="none")
  63. else:
  64. raise ValueError(f"Activation function {hidden_act} is not "
  65. "supported for Gemma models.")
  66. class GemmaMLP(nn.Module):
  67. def __init__(
  68. self,
  69. hidden_size: int,
  70. intermediate_size: int,
  71. hidden_act: Optional[str] = None,
  72. hidden_activation: Optional[str] = None,
  73. quant_config: Optional[QuantizationConfig] = None,
  74. ) -> None:
  75. super().__init__()
  76. self.gate_up_proj = MergedColumnParallelLinear(
  77. hidden_size, [intermediate_size] * 2,
  78. bias=False,
  79. quant_config=quant_config)
  80. self.down_proj = RowParallelLinear(intermediate_size,
  81. hidden_size,
  82. bias=False,
  83. quant_config=quant_config)
  84. self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
  85. def forward(self, x):
  86. gate_up, _ = self.gate_up_proj(x)
  87. x = self.act_fn(gate_up)
  88. x, _ = self.down_proj(x)
  89. return x
  90. class GemmaAttention(nn.Module):
  91. def __init__(self,
  92. hidden_size: int,
  93. num_heads: int,
  94. num_kv_heads: int,
  95. head_dim: int,
  96. max_position_embeddings: int = 8192,
  97. rope_theta: float = 10000,
  98. cache_config: Optional[CacheConfig] = None,
  99. quant_config: Optional[QuantizationConfig] = None) -> None:
  100. super().__init__()
  101. self.hidden_size = hidden_size
  102. tp_size = get_tensor_model_parallel_world_size()
  103. self.total_num_heads = num_heads
  104. assert self.total_num_heads % tp_size == 0
  105. self.num_heads = self.total_num_heads // tp_size
  106. self.total_num_kv_heads = num_kv_heads
  107. if self.total_num_kv_heads >= tp_size:
  108. # Number of KV heads is greater than TP size, so we partition
  109. # the KV heads across multiple tensor parallel GPUs.
  110. assert self.total_num_kv_heads % tp_size == 0
  111. else:
  112. # Number of KV heads is less than TP size, so we replicate
  113. # the KV heads across multiple tensor parallel GPUs.
  114. assert tp_size % self.total_num_kv_heads == 0
  115. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  116. self.head_dim = head_dim
  117. self.q_size = self.num_heads * self.head_dim
  118. self.kv_size = self.num_kv_heads * self.head_dim
  119. self.scaling = self.head_dim**-0.5
  120. self.rope_theta = rope_theta
  121. self.qkv_proj = QKVParallelLinear(
  122. hidden_size,
  123. self.head_dim,
  124. self.total_num_heads,
  125. self.total_num_kv_heads,
  126. bias=False,
  127. quant_config=quant_config,
  128. )
  129. self.o_proj = RowParallelLinear(
  130. self.total_num_heads * self.head_dim,
  131. hidden_size,
  132. bias=False,
  133. quant_config=quant_config,
  134. )
  135. # TODO: Use the `get_rope` interface.
  136. self.rotary_emb = GemmaRotaryEmbedding(
  137. self.head_dim,
  138. rotary_dim=self.head_dim,
  139. max_position_embeddings=max_position_embeddings,
  140. base=self.rope_theta,
  141. is_neox_style=True,
  142. dtype=torch.get_default_dtype(),
  143. )
  144. self.attn = Attention(self.num_heads,
  145. self.head_dim,
  146. self.scaling,
  147. num_kv_heads=self.num_kv_heads,
  148. cache_config=cache_config,
  149. quant_config=quant_config)
  150. def forward(
  151. self,
  152. positions: torch.Tensor,
  153. hidden_states: torch.Tensor,
  154. kv_cache: torch.Tensor,
  155. attn_metadata: AttentionMetadata,
  156. ) -> torch.Tensor:
  157. qkv, _ = self.qkv_proj(hidden_states)
  158. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  159. q, k = self.rotary_emb(positions, q, k)
  160. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  161. output, _ = self.o_proj(attn_output)
  162. return output
  163. class GemmaDecoderLayer(nn.Module):
  164. def __init__(
  165. self,
  166. config: GemmaConfig,
  167. cache_config: Optional[CacheConfig] = None,
  168. quant_config: Optional[QuantizationConfig] = None,
  169. ) -> None:
  170. super().__init__()
  171. self.hidden_size = config.hidden_size
  172. self.self_attn = GemmaAttention(
  173. hidden_size=self.hidden_size,
  174. num_heads=config.num_attention_heads,
  175. num_kv_heads=config.num_key_value_heads,
  176. head_dim=config.head_dim,
  177. max_position_embeddings=config.max_position_embeddings,
  178. rope_theta=config.rope_theta,
  179. cache_config=cache_config,
  180. quant_config=quant_config,
  181. )
  182. self.mlp = GemmaMLP(
  183. hidden_size=self.hidden_size,
  184. intermediate_size=config.intermediate_size,
  185. hidden_act=config.hidden_act,
  186. hidden_activation=getattr(config, "hidden_activation", None),
  187. quant_config=quant_config,
  188. )
  189. self.input_layernorm = GemmaRMSNorm(config.hidden_size,
  190. eps=config.rms_norm_eps)
  191. self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
  192. eps=config.rms_norm_eps)
  193. def forward(
  194. self,
  195. positions: torch.Tensor,
  196. hidden_states: torch.Tensor,
  197. kv_cache: torch.Tensor,
  198. attn_metadata: AttentionMetadata,
  199. residual: Optional[torch.Tensor],
  200. ) -> Tuple[torch.Tensor, torch.Tensor]:
  201. # Self Attention
  202. if residual is None:
  203. residual = hidden_states
  204. hidden_states = self.input_layernorm(hidden_states)
  205. else:
  206. hidden_states, residual = self.input_layernorm(
  207. hidden_states, residual)
  208. hidden_states = self.self_attn(
  209. positions=positions,
  210. hidden_states=hidden_states,
  211. kv_cache=kv_cache,
  212. attn_metadata=attn_metadata,
  213. )
  214. # Fully Connected
  215. hidden_states, residual = self.post_attention_layernorm(
  216. hidden_states, residual)
  217. hidden_states = self.mlp(hidden_states)
  218. return hidden_states, residual
  219. class GemmaModel(nn.Module):
  220. def __init__(
  221. self,
  222. config: GemmaConfig,
  223. cache_config: Optional[CacheConfig] = None,
  224. quant_config: Optional[QuantizationConfig] = None,
  225. ) -> None:
  226. super().__init__()
  227. self.config = config
  228. self.embed_tokens = VocabParallelEmbedding(
  229. config.vocab_size,
  230. config.hidden_size,
  231. )
  232. self.layers = nn.ModuleList([
  233. GemmaDecoderLayer(config, cache_config, quant_config)
  234. for _ in range(config.num_hidden_layers)
  235. ])
  236. self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  237. # Normalize the embedding by sqrt(hidden_size)
  238. # The normalizer's data type should be downcasted to the model's
  239. # data type such as bfloat16, not float32.
  240. # See https://github.com/huggingface/transformers/pull/29402
  241. normalizer = self.config.hidden_size**0.5
  242. self.register_buffer("normalizer", torch.tensor(normalizer))
  243. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  244. return self.embed_tokens(input_ids)
  245. def forward(
  246. self,
  247. input_ids: torch.Tensor,
  248. positions: torch.Tensor,
  249. kv_caches: List[torch.Tensor],
  250. attn_metadata: AttentionMetadata,
  251. intermediate_tensors: Optional[IntermediateTensors] = None,
  252. inputs_embeds: Optional[torch.Tensor] = None,
  253. ) -> torch.Tensor:
  254. if inputs_embeds is not None:
  255. hidden_states = inputs_embeds
  256. else:
  257. hidden_states = self.get_input_embeddings(input_ids)
  258. hidden_states *= self.normalizer
  259. residual = None
  260. for i in range(len(self.layers)):
  261. layer = self.layers[i]
  262. hidden_states, residual = layer(
  263. positions,
  264. hidden_states,
  265. kv_caches[i],
  266. attn_metadata,
  267. residual,
  268. )
  269. hidden_states, _ = self.norm(hidden_states, residual)
  270. return hidden_states
class GemmaForCausalLM(nn.Module, SupportsLoRA):
    """Gemma for causal language modeling.

    Wraps :class:`GemmaModel` with logits computation and sampling. The
    output head is tied to the input embedding, so logits are computed
    against ``model.embed_tokens`` rather than a separate ``lm_head``.
    """

    # Maps each fused parameter to the per-projection checkpoint names
    # that are packed into it (used by LoRA and weight loading).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: GemmaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        # Kept for the LoRA runtime; not consumed directly in this class.
        self.lora_config = lora_config
        self.quant_config = quant_config
        self.model = GemmaModel(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Run the transformer stack and return the final hidden states."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits.

        Uses the tied input embedding as the output projection.
        """
        logits = self.logits_processor(self.model.embed_tokens, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the computed logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model's parameters.

        Checkpoint shards named q/k/v_proj and gate/up_proj are fused
        into the packed qkv_proj / gate_up_proj parameters; all other
        weights load one-to-one. ``lm_head.weight`` is skipped because
        it is tied to the input embedding. Warns about any parameter
        that received no checkpoint weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            # First, try to match a shard of a fused (stacked) parameter.
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                # Fused params carry a shard-aware loader; pass shard_id.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: plain one-to-one load.
                # lm_head is not used in aphrodite as it is tied with
                # embed_token. To prevent errors, skip loading lm_head.weight.
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")