gemma.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # coding=utf-8
  2. # Copyright 2023 The PygmalionAI team.
  3. # Copyright 2023 The vLLM team.
  4. # Copyright (c) Google Inc.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """Inference-only Gemma model compatible with HuggingFace weights."""
  18. from functools import lru_cache
  19. from typing import Iterable, List, Optional, Tuple
  20. import torch
  21. from loguru import logger
  22. from torch import nn
  23. from transformers import GemmaConfig
  24. from aphrodite.attention import Attention, AttentionMetadata
  25. from aphrodite.common.config import LoRAConfig
  26. from aphrodite.common.sequence import SamplerOutput
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import GeluAndMul
  29. from aphrodite.modeling.layers.layernorm import RMSNorm
  30. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  31. MergedColumnParallelLinear,
  32. QKVParallelLinear,
  33. RowParallelLinear)
  34. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  35. from aphrodite.modeling.layers.rotary_embedding import get_rope
  36. from aphrodite.modeling.layers.sampler import Sampler
  37. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  38. VocabParallelEmbedding
  39. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  40. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  41. @lru_cache(maxsize=None)
  42. def _get_gemma_act_fn(
  43. hidden_act: Optional[str],
  44. hidden_activation: Optional[str],
  45. ) -> nn.Module:
  46. if hidden_activation is None:
  47. if hidden_act is not None:
  48. logger.warning(
  49. "Gemma's activation function was incorrectly set to exact GeLU "
  50. "in the config JSON file when it was initially released. "
  51. "Changing the activation function to approximate GeLU "
  52. "(`gelu_pytorch_tanh`). If you want to use the legacy "
  53. f"`{hidden_act}`, edit the config JSON to set "
  54. f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
  55. "See https://github.com/huggingface/transformers/pull/29402 "
  56. "for more details.")
  57. return GeluAndMul(approximate="tanh")
  58. elif hidden_activation == "gelu_pytorch_tanh":
  59. return GeluAndMul(approximate="tanh")
  60. elif hidden_activation == "gelu":
  61. return GeluAndMul(approximate="none")
  62. else:
  63. raise ValueError(f"Activation function {hidden_act} is not "
  64. "supported for Gemma models.")
  65. class GemmaMLP(nn.Module):
  66. def __init__(
  67. self,
  68. hidden_size: int,
  69. intermediate_size: int,
  70. hidden_act: Optional[str] = None,
  71. hidden_activation: Optional[str] = None,
  72. linear_method: Optional[LinearMethodBase] = None,
  73. ) -> None:
  74. super().__init__()
  75. self.gate_up_proj = MergedColumnParallelLinear(
  76. hidden_size, [intermediate_size] * 2,
  77. bias=False,
  78. linear_method=linear_method)
  79. self.down_proj = RowParallelLinear(intermediate_size,
  80. hidden_size,
  81. bias=False,
  82. linear_method=linear_method)
  83. self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
  84. def forward(self, x):
  85. gate_up, _ = self.gate_up_proj(x)
  86. x = self.act_fn(gate_up)
  87. x, _ = self.down_proj(x)
  88. return x
  89. class GemmaAttention(nn.Module):
  90. def __init__(self,
  91. hidden_size: int,
  92. num_heads: int,
  93. num_kv_heads: int,
  94. head_dim: int,
  95. max_position_embeddings: int = 8192,
  96. rope_theta: float = 10000,
  97. linear_method: Optional[LinearMethodBase] = None) -> None:
  98. super().__init__()
  99. self.hidden_size = hidden_size
  100. tp_size = get_tensor_model_parallel_world_size()
  101. self.total_num_heads = num_heads
  102. assert self.total_num_heads % tp_size == 0
  103. self.num_heads = self.total_num_heads // tp_size
  104. self.total_num_kv_heads = num_kv_heads
  105. if self.total_num_kv_heads >= tp_size:
  106. # Number of KV heads is greater than TP size, so we partition
  107. # the KV heads across multiple tensor parallel GPUs.
  108. assert self.total_num_kv_heads % tp_size == 0
  109. else:
  110. # Number of KV heads is less than TP size, so we replicate
  111. # the KV heads across multiple tensor parallel GPUs.
  112. assert tp_size % self.total_num_kv_heads == 0
  113. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  114. self.head_dim = head_dim
  115. self.q_size = self.num_heads * self.head_dim
  116. self.kv_size = self.num_kv_heads * self.head_dim
  117. self.scaling = self.head_dim**-0.5
  118. self.rope_theta = rope_theta
  119. self.qkv_proj = QKVParallelLinear(
  120. hidden_size,
  121. self.head_dim,
  122. self.total_num_heads,
  123. self.total_num_kv_heads,
  124. bias=False,
  125. linear_method=linear_method,
  126. )
  127. self.o_proj = RowParallelLinear(
  128. self.total_num_heads * self.head_dim,
  129. hidden_size,
  130. bias=False,
  131. linear_method=linear_method,
  132. )
  133. self.rotary_emb = get_rope(
  134. self.head_dim,
  135. rotary_dim=self.head_dim,
  136. max_position=max_position_embeddings,
  137. base=self.rope_theta,
  138. is_neox_style=True,
  139. )
  140. self.attn = Attention(self.num_heads,
  141. self.head_dim,
  142. self.scaling,
  143. num_kv_heads=self.num_kv_heads)
  144. def forward(
  145. self,
  146. positions: torch.Tensor,
  147. hidden_states: torch.Tensor,
  148. kv_cache: torch.Tensor,
  149. attn_metadata: AttentionMetadata,
  150. ) -> torch.Tensor:
  151. qkv, _ = self.qkv_proj(hidden_states)
  152. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  153. q, k = self.rotary_emb(positions, q, k)
  154. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  155. output, _ = self.o_proj(attn_output)
  156. return output
  157. class GemmaDecoderLayer(nn.Module):
  158. def __init__(
  159. self,
  160. config: GemmaConfig,
  161. linear_method: Optional[LinearMethodBase] = None,
  162. ) -> None:
  163. super().__init__()
  164. self.hidden_size = config.hidden_size
  165. self.self_attn = GemmaAttention(
  166. hidden_size=self.hidden_size,
  167. num_heads=config.num_attention_heads,
  168. num_kv_heads=config.num_key_value_heads,
  169. head_dim=config.head_dim,
  170. max_position_embeddings=config.max_position_embeddings,
  171. rope_theta=config.rope_theta,
  172. linear_method=linear_method,
  173. )
  174. self.mlp = GemmaMLP(
  175. hidden_size=self.hidden_size,
  176. intermediate_size=config.intermediate_size,
  177. hidden_act=config.hidden_act,
  178. hidden_activation=getattr(config, "hidden_activation", None),
  179. linear_method=linear_method,
  180. )
  181. self.input_layernorm = RMSNorm(config.hidden_size,
  182. eps=config.rms_norm_eps)
  183. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  184. eps=config.rms_norm_eps)
  185. def forward(
  186. self,
  187. positions: torch.Tensor,
  188. hidden_states: torch.Tensor,
  189. kv_cache: torch.Tensor,
  190. attn_metadata: AttentionMetadata,
  191. residual: Optional[torch.Tensor],
  192. ) -> Tuple[torch.Tensor, torch.Tensor]:
  193. # Self Attention
  194. if residual is None:
  195. residual = hidden_states
  196. hidden_states = self.input_layernorm(hidden_states)
  197. else:
  198. hidden_states, residual = self.input_layernorm(
  199. hidden_states, residual)
  200. hidden_states = self.self_attn(
  201. positions=positions,
  202. hidden_states=hidden_states,
  203. kv_cache=kv_cache,
  204. attn_metadata=attn_metadata,
  205. )
  206. # Fully Connected
  207. hidden_states, residual = self.post_attention_layernorm(
  208. hidden_states, residual)
  209. hidden_states = self.mlp(hidden_states)
  210. return hidden_states, residual
  211. class GemmaModel(nn.Module):
  212. def __init__(
  213. self,
  214. config: GemmaConfig,
  215. linear_method: Optional[LinearMethodBase] = None,
  216. ) -> None:
  217. super().__init__()
  218. self.config = config
  219. self.embed_tokens = VocabParallelEmbedding(
  220. config.vocab_size,
  221. config.hidden_size,
  222. )
  223. self.layers = nn.ModuleList([
  224. GemmaDecoderLayer(config, linear_method)
  225. for _ in range(config.num_hidden_layers)
  226. ])
  227. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  228. # Normalize the embedding by sqrt(hidden_size)
  229. # The normalizer's data type should be downcasted to the model's
  230. # data type such as bfloat16, not float32.
  231. # See https://github.com/huggingface/transformers/pull/29402
  232. normalizer = self.config.hidden_size**0.5
  233. self.register_buffer("normalizer", torch.tensor(normalizer))
  234. def forward(
  235. self,
  236. input_ids: torch.Tensor,
  237. positions: torch.Tensor,
  238. kv_caches: List[torch.Tensor],
  239. attn_metadata: AttentionMetadata,
  240. ) -> torch.Tensor:
  241. hidden_states = self.embed_tokens(input_ids)
  242. hidden_states *= self.normalizer
  243. residual = None
  244. for i in range(len(self.layers)):
  245. layer = self.layers[i]
  246. hidden_states, residual = layer(
  247. positions,
  248. hidden_states,
  249. kv_caches[i],
  250. attn_metadata,
  251. residual,
  252. )
  253. hidden_states, _ = self.norm(hidden_states, residual)
  254. return hidden_states
  255. class GemmaForCausalLM(nn.Module):
  256. packed_modules_mapping = {
  257. "qkv_proj": [
  258. "q_proj",
  259. "k_proj",
  260. "v_proj",
  261. ],
  262. "gate_up_proj": [
  263. "gate_proj",
  264. "up_proj",
  265. ],
  266. }
  267. # LoRA specific attributes
  268. supported_lora_modules = [
  269. "qkv_proj",
  270. "o_proj",
  271. "gate_up_proj",
  272. "down_proj",
  273. ]
  274. # Gemma does not apply LoRA to the embedding layer.
  275. embedding_modules = {}
  276. embedding_padding_modules = []
  277. def __init__(
  278. self,
  279. config: GemmaConfig,
  280. linear_method: Optional[LinearMethodBase] = None,
  281. lora_config: Optional[LoRAConfig] = None,
  282. ) -> None:
  283. del lora_config # Unused.
  284. super().__init__()
  285. self.config = config
  286. self.linear_method = linear_method
  287. self.model = GemmaModel(config, linear_method)
  288. self.logits_processor = LogitsProcessor(config.vocab_size)
  289. self.sampler = Sampler()
  290. @torch.no_grad()
  291. def forward(
  292. self,
  293. input_ids: torch.Tensor,
  294. positions: torch.Tensor,
  295. kv_caches: List[torch.Tensor],
  296. attn_metadata: AttentionMetadata,
  297. ) -> torch.Tensor:
  298. hidden_states = self.model(input_ids, positions, kv_caches,
  299. attn_metadata)
  300. return hidden_states
  301. def compute_logits(self, hidden_states: torch.Tensor,
  302. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  303. logits = self.logits_processor(self.model.embed_tokens.weight,
  304. hidden_states, sampling_metadata)
  305. return logits
  306. def sample(
  307. self,
  308. logits: torch.Tensor,
  309. sampling_metadata: SamplingMetadata,
  310. ) -> Optional[SamplerOutput]:
  311. next_tokens = self.sampler(logits, sampling_metadata)
  312. return next_tokens
  313. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  314. stacked_params_mapping = [
  315. # (param_name, shard_name, shard_id)
  316. ("qkv_proj", "q_proj", "q"),
  317. ("qkv_proj", "k_proj", "k"),
  318. ("qkv_proj", "v_proj", "v"),
  319. ("gate_up_proj", "gate_proj", 0),
  320. ("gate_up_proj", "up_proj", 1),
  321. ]
  322. params_dict = dict(self.named_parameters())
  323. loaded_params = set()
  324. for name, loaded_weight in weights:
  325. for (param_name, shard_name, shard_id) in stacked_params_mapping:
  326. if shard_name not in name:
  327. continue
  328. name = name.replace(shard_name, param_name)
  329. # Skip loading extra bias for GPTQ models.
  330. if name.endswith(".bias") and name not in params_dict:
  331. continue
  332. param = params_dict[name]
  333. weight_loader = param.weight_loader
  334. weight_loader(param, loaded_weight, shard_id)
  335. break
  336. else:
  337. # lm_head is not used in Aphrodite as it is tied with
  338. # embed_token. To prevent errors, skip loading lm_head.weight.
  339. if "lm_head.weight" in name:
  340. continue
  341. # Skip loading extra bias for GPTQ models.
  342. if name.endswith(".bias") and name not in params_dict:
  343. continue
  344. # GemmaRMSNorm is different from Llama's in that it multiplies
  345. # (1 + weight) to the output, instead of just weight.
  346. if "norm.weight" in name:
  347. loaded_weight += 1.0
  348. param = params_dict[name]
  349. weight_loader = getattr(param, "weight_loader",
  350. default_weight_loader)
  351. weight_loader(param, loaded_weight)
  352. loaded_params.add(name)
  353. unloaded_params = params_dict.keys() - loaded_params
  354. if unloaded_params:
  355. raise RuntimeError(
  356. "Some weights are not initialized from checkpoints: "
  357. f"{unloaded_params}")