gemma.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. # coding=utf-8
  2. # Copyright 2023 The vLLM team.
  3. # Copyright (c) Google Inc.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Inference-only Gemma model compatible with HuggingFace weights."""
  17. from functools import lru_cache
  18. from typing import Iterable, List, Optional, Tuple
  19. import torch
  20. from loguru import logger
  21. from torch import nn
  22. from transformers import GemmaConfig
  23. from aphrodite.attention import Attention, AttentionMetadata
  24. from aphrodite.common.config import CacheConfig, LoRAConfig
  25. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  26. from aphrodite.distributed import get_tensor_model_parallel_world_size
  27. from aphrodite.modeling.layers.activation import GeluAndMul
  28. from aphrodite.modeling.layers.layernorm import RMSNorm
  29. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.rotary_embedding import get_rope
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  36. VocabParallelEmbedding
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.models.interfaces import SupportsLoRA
  39. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  40. from aphrodite.quantization.base_config import QuantizationConfig
  41. @lru_cache(maxsize=None)
  42. def _get_gemma_act_fn(
  43. hidden_act: Optional[str],
  44. hidden_activation: Optional[str],
  45. ) -> nn.Module:
  46. if hidden_activation is None:
  47. if hidden_act is not None:
  48. logger.warning(
  49. "Gemma's activation function was incorrectly set to exact GeLU "
  50. "in the config JSON file when it was initially released. "
  51. "Changing the activation function to approximate GeLU "
  52. "(`gelu_pytorch_tanh`). If you want to use the legacy "
  53. f"`{hidden_act}`, edit the config JSON to set "
  54. f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
  55. "See https://github.com/huggingface/transformers/pull/29402 "
  56. "for more details.")
  57. return GeluAndMul(approximate="tanh")
  58. elif hidden_activation == "gelu_pytorch_tanh":
  59. return GeluAndMul(approximate="tanh")
  60. elif hidden_activation == "gelu":
  61. return GeluAndMul(approximate="none")
  62. else:
  63. raise ValueError(f"Activation function {hidden_act} is not "
  64. "supported for Gemma models.")
  65. class GemmaMLP(nn.Module):
  66. def __init__(
  67. self,
  68. hidden_size: int,
  69. intermediate_size: int,
  70. hidden_act: Optional[str] = None,
  71. hidden_activation: Optional[str] = None,
  72. quant_config: Optional[QuantizationConfig] = None,
  73. ) -> None:
  74. super().__init__()
  75. self.gate_up_proj = MergedColumnParallelLinear(
  76. hidden_size, [intermediate_size] * 2,
  77. bias=False,
  78. quant_config=quant_config)
  79. self.down_proj = RowParallelLinear(intermediate_size,
  80. hidden_size,
  81. bias=False,
  82. quant_config=quant_config)
  83. self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
  84. def forward(self, x):
  85. gate_up, _ = self.gate_up_proj(x)
  86. x = self.act_fn(gate_up)
  87. x, _ = self.down_proj(x)
  88. return x
  89. class GemmaAttention(nn.Module):
  90. def __init__(self,
  91. hidden_size: int,
  92. num_heads: int,
  93. num_kv_heads: int,
  94. head_dim: int,
  95. max_position_embeddings: int = 8192,
  96. rope_theta: float = 10000,
  97. cache_config: Optional[CacheConfig] = None,
  98. quant_config: Optional[QuantizationConfig] = None) -> None:
  99. super().__init__()
  100. self.hidden_size = hidden_size
  101. tp_size = get_tensor_model_parallel_world_size()
  102. self.total_num_heads = num_heads
  103. assert self.total_num_heads % tp_size == 0
  104. self.num_heads = self.total_num_heads // tp_size
  105. self.total_num_kv_heads = num_kv_heads
  106. if self.total_num_kv_heads >= tp_size:
  107. # Number of KV heads is greater than TP size, so we partition
  108. # the KV heads across multiple tensor parallel GPUs.
  109. assert self.total_num_kv_heads % tp_size == 0
  110. else:
  111. # Number of KV heads is less than TP size, so we replicate
  112. # the KV heads across multiple tensor parallel GPUs.
  113. assert tp_size % self.total_num_kv_heads == 0
  114. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  115. self.head_dim = head_dim
  116. self.q_size = self.num_heads * self.head_dim
  117. self.kv_size = self.num_kv_heads * self.head_dim
  118. self.scaling = self.head_dim**-0.5
  119. self.rope_theta = rope_theta
  120. self.qkv_proj = QKVParallelLinear(
  121. hidden_size,
  122. self.head_dim,
  123. self.total_num_heads,
  124. self.total_num_kv_heads,
  125. bias=False,
  126. quant_config=quant_config,
  127. )
  128. self.o_proj = RowParallelLinear(
  129. self.total_num_heads * self.head_dim,
  130. hidden_size,
  131. bias=False,
  132. quant_config=quant_config,
  133. )
  134. self.rotary_emb = get_rope(
  135. self.head_dim,
  136. rotary_dim=self.head_dim,
  137. max_position=max_position_embeddings,
  138. base=self.rope_theta,
  139. is_neox_style=True,
  140. )
  141. self.attn = Attention(self.num_heads,
  142. self.head_dim,
  143. self.scaling,
  144. num_kv_heads=self.num_kv_heads,
  145. cache_config=cache_config,
  146. quant_config=quant_config)
  147. def forward(
  148. self,
  149. positions: torch.Tensor,
  150. hidden_states: torch.Tensor,
  151. kv_cache: torch.Tensor,
  152. attn_metadata: AttentionMetadata,
  153. ) -> torch.Tensor:
  154. qkv, _ = self.qkv_proj(hidden_states)
  155. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  156. q, k = self.rotary_emb(positions, q, k)
  157. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  158. output, _ = self.o_proj(attn_output)
  159. return output
  160. class GemmaDecoderLayer(nn.Module):
  161. def __init__(
  162. self,
  163. config: GemmaConfig,
  164. cache_config: Optional[CacheConfig] = None,
  165. quant_config: Optional[QuantizationConfig] = None,
  166. ) -> None:
  167. super().__init__()
  168. self.hidden_size = config.hidden_size
  169. self.self_attn = GemmaAttention(
  170. hidden_size=self.hidden_size,
  171. num_heads=config.num_attention_heads,
  172. num_kv_heads=config.num_key_value_heads,
  173. head_dim=config.head_dim,
  174. max_position_embeddings=config.max_position_embeddings,
  175. rope_theta=config.rope_theta,
  176. cache_config=cache_config,
  177. quant_config=quant_config,
  178. )
  179. self.mlp = GemmaMLP(
  180. hidden_size=self.hidden_size,
  181. intermediate_size=config.intermediate_size,
  182. hidden_act=config.hidden_act,
  183. hidden_activation=getattr(config, "hidden_activation", None),
  184. quant_config=quant_config,
  185. )
  186. self.input_layernorm = RMSNorm(config.hidden_size,
  187. eps=config.rms_norm_eps)
  188. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  189. eps=config.rms_norm_eps)
  190. def forward(
  191. self,
  192. positions: torch.Tensor,
  193. hidden_states: torch.Tensor,
  194. kv_cache: torch.Tensor,
  195. attn_metadata: AttentionMetadata,
  196. residual: Optional[torch.Tensor],
  197. ) -> Tuple[torch.Tensor, torch.Tensor]:
  198. # Self Attention
  199. if residual is None:
  200. residual = hidden_states
  201. hidden_states = self.input_layernorm(hidden_states)
  202. else:
  203. hidden_states, residual = self.input_layernorm(
  204. hidden_states, residual)
  205. hidden_states = self.self_attn(
  206. positions=positions,
  207. hidden_states=hidden_states,
  208. kv_cache=kv_cache,
  209. attn_metadata=attn_metadata,
  210. )
  211. # Fully Connected
  212. hidden_states, residual = self.post_attention_layernorm(
  213. hidden_states, residual)
  214. hidden_states = self.mlp(hidden_states)
  215. return hidden_states, residual
  216. class GemmaModel(nn.Module):
  217. def __init__(
  218. self,
  219. config: GemmaConfig,
  220. cache_config: Optional[CacheConfig] = None,
  221. quant_config: Optional[QuantizationConfig] = None,
  222. ) -> None:
  223. super().__init__()
  224. self.config = config
  225. self.embed_tokens = VocabParallelEmbedding(
  226. config.vocab_size,
  227. config.hidden_size,
  228. )
  229. self.layers = nn.ModuleList([
  230. GemmaDecoderLayer(config, cache_config, quant_config)
  231. for _ in range(config.num_hidden_layers)
  232. ])
  233. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  234. # Normalize the embedding by sqrt(hidden_size)
  235. # The normalizer's data type should be downcasted to the model's
  236. # data type such as bfloat16, not float32.
  237. # See https://github.com/huggingface/transformers/pull/29402
  238. normalizer = self.config.hidden_size**0.5
  239. self.register_buffer("normalizer", torch.tensor(normalizer))
  240. def forward(
  241. self,
  242. input_ids: torch.Tensor,
  243. positions: torch.Tensor,
  244. kv_caches: List[torch.Tensor],
  245. attn_metadata: AttentionMetadata,
  246. ) -> torch.Tensor:
  247. hidden_states = self.embed_tokens(input_ids)
  248. hidden_states *= self.normalizer
  249. residual = None
  250. for i in range(len(self.layers)):
  251. layer = self.layers[i]
  252. hidden_states, residual = layer(
  253. positions,
  254. hidden_states,
  255. kv_caches[i],
  256. attn_metadata,
  257. residual,
  258. )
  259. hidden_states, _ = self.norm(hidden_states, residual)
  260. return hidden_states
  261. class GemmaForCausalLM(nn.Module, SupportsLoRA):
  262. packed_modules_mapping = {
  263. "qkv_proj": [
  264. "q_proj",
  265. "k_proj",
  266. "v_proj",
  267. ],
  268. "gate_up_proj": [
  269. "gate_proj",
  270. "up_proj",
  271. ],
  272. }
  273. # LoRA specific attributes
  274. supported_lora_modules = [
  275. "qkv_proj",
  276. "o_proj",
  277. "gate_up_proj",
  278. "down_proj",
  279. ]
  280. # Gemma does not apply LoRA to the embedding layer.
  281. embedding_modules = {}
  282. embedding_padding_modules = []
  283. def __init__(
  284. self,
  285. config: GemmaConfig,
  286. cache_config: Optional[CacheConfig] = None,
  287. quant_config: Optional[QuantizationConfig] = None,
  288. lora_config: Optional[LoRAConfig] = None,
  289. ) -> None:
  290. super().__init__()
  291. self.config = config
  292. self.lora_config = lora_config
  293. self.quant_config = quant_config
  294. self.model = GemmaModel(config, cache_config, quant_config)
  295. self.logits_processor = LogitsProcessor(config.vocab_size)
  296. self.sampler = Sampler()
  297. @torch.no_grad()
  298. def forward(
  299. self,
  300. input_ids: torch.Tensor,
  301. positions: torch.Tensor,
  302. kv_caches: List[torch.Tensor],
  303. attn_metadata: AttentionMetadata,
  304. intermediate_tensors: Optional[IntermediateTensors] = None,
  305. ) -> torch.Tensor:
  306. hidden_states = self.model(input_ids, positions, kv_caches,
  307. attn_metadata)
  308. return hidden_states
  309. def compute_logits(self, hidden_states: torch.Tensor,
  310. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  311. logits = self.logits_processor(self.model.embed_tokens, hidden_states,
  312. sampling_metadata)
  313. return logits
  314. def sample(
  315. self,
  316. logits: torch.Tensor,
  317. sampling_metadata: SamplingMetadata,
  318. ) -> Optional[SamplerOutput]:
  319. next_tokens = self.sampler(logits, sampling_metadata)
  320. return next_tokens
  321. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  322. stacked_params_mapping = [
  323. # (param_name, shard_name, shard_id)
  324. ("qkv_proj", "q_proj", "q"),
  325. ("qkv_proj", "k_proj", "k"),
  326. ("qkv_proj", "v_proj", "v"),
  327. ("gate_up_proj", "gate_proj", 0),
  328. ("gate_up_proj", "up_proj", 1),
  329. ]
  330. params_dict = dict(self.named_parameters())
  331. loaded_params = set()
  332. for name, loaded_weight in weights:
  333. for (param_name, shard_name, shard_id) in stacked_params_mapping:
  334. if shard_name not in name:
  335. continue
  336. name = name.replace(shard_name, param_name)
  337. # Skip loading extra bias for GPTQ models.
  338. if name.endswith(".bias") and name not in params_dict:
  339. continue
  340. param = params_dict[name]
  341. weight_loader = param.weight_loader
  342. weight_loader(param, loaded_weight, shard_id)
  343. break
  344. else:
  345. # lm_head is not used in Aphro as it is tied with embed_token.
  346. # To prevent errors, skip loading lm_head.weight.
  347. if "lm_head.weight" in name:
  348. continue
  349. # Skip loading extra bias for GPTQ models.
  350. if name.endswith(".bias") and name not in params_dict:
  351. continue
  352. # GemmaRMSNorm is different from Llama's in that it multiplies
  353. # (1 + weight) to the output, instead of just weight.
  354. if "norm.weight" in name:
  355. loaded_weight += 1.0
  356. param = params_dict[name]
  357. weight_loader = getattr(param, "weight_loader",
  358. default_weight_loader)
  359. weight_loader(param, loaded_weight)
  360. loaded_params.add(name)
  361. unloaded_params = params_dict.keys() - loaded_params
  362. if unloaded_params:
  363. raise RuntimeError(
  364. "Some weights are not initialized from checkpoints: "
  365. f"{unloaded_params}")