gemma.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # coding=utf-8
  2. # Copyright 2023 The PygmalionAI team.
  3. # Copyright 2023 The vLLM team.
  4. # Copyright (c) Google Inc.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """Inference-only Gemma model compatible with HuggingFace weights."""
  18. from typing import List, Optional, Tuple
  19. import torch
  20. from torch import nn
  21. from transformers import GemmaConfig
  22. from aphrodite.modeling.metadata import InputMetadata
  23. from aphrodite.modeling.layers.activation import GeluAndMul
  24. from aphrodite.modeling.layers.attention import PagedAttention
  25. from aphrodite.modeling.layers.layernorm import RMSNorm
  26. from aphrodite.modeling.layers.linear import (
  27. LinearMethodBase,
  28. MergedColumnParallelLinear,
  29. QKVParallelLinear,
  30. RowParallelLinear,
  31. ColumnParallelLinear,
  32. )
  33. from aphrodite.modeling.layers.rotary_embedding import get_rope
  34. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding,
  37. ParallelLMHead,
  38. )
  39. from aphrodite.modeling.megatron.parallel_state import (
  40. get_tensor_model_parallel_world_size, )
  41. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  42. from aphrodite.modeling.hf_downloader import (
  43. default_weight_loader,
  44. hf_model_weights_iterator,
  45. )
  46. from aphrodite.common.sequence import SamplerOutput
  # Per-layer attention cache: GemmaAttention.forward unpacks it as
  # (key_cache, value_cache) and hands both tensors to PagedAttention.
  KVCache = Tuple[torch.Tensor, torch.Tensor]
  class GemmaMLP(nn.Module):
      """Gated-GELU feed-forward block used by every Gemma decoder layer.

      The input is projected to ``intermediate_size`` twice (gate and up),
      combined by ``GeluAndMul``, and projected back to ``hidden_size`` by a
      row-parallel ``down_proj``.

      If ``linear_method`` is given and its quant config reports that weights
      must not be merged, gate and up are two separate ``ColumnParallelLinear``
      layers; otherwise they share one ``MergedColumnParallelLinear`` called
      ``gate_up_proj``.
      """

      def __init__(
          self,
          hidden_size: int,
          intermediate_size: int,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          # Fuse gate+up unless a quantization method explicitly opts out.
          self.merge_weight = bool(
              linear_method is None
              or linear_method.quant_config.merge_weight())
          if self.merge_weight:
              self.gate_up_proj = MergedColumnParallelLinear(
                  hidden_size,
                  [intermediate_size, intermediate_size],
                  bias=False,
                  linear_method=linear_method,
              )
          else:
              column_kwargs = dict(bias=False, linear_method=linear_method)
              self.gate_proj = ColumnParallelLinear(hidden_size,
                                                    intermediate_size,
                                                    **column_kwargs)
              self.up_proj = ColumnParallelLinear(hidden_size,
                                                  intermediate_size,
                                                  **column_kwargs)
          self.down_proj = RowParallelLinear(
              intermediate_size,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          self.act_fn = GeluAndMul()

      def forward(self, x):
          """Apply the gated MLP to ``x`` and return the projected output."""
          if not self.merge_weight:
              up, _ = self.up_proj(x)
              gate, _ = self.gate_proj(x)
              # Same [gate, up] layout as the fused layer (gate is shard 0,
              # up is shard 1 when loading weights).
              gate_up = torch.cat([gate, up], dim=-1)
          else:
              gate_up, _ = self.gate_up_proj(x)
          activated = self.act_fn(gate_up)
          out, _ = self.down_proj(activated)
          return out
  class GemmaAttention(nn.Module):
      """Rotary self-attention for Gemma, sharded over tensor-parallel ranks.

      Query heads are split evenly across ranks. KV heads are partitioned when
      there are at least as many of them as ranks, and replicated otherwise
      (each rank then holds a single KV head).

      Args:
          hidden_size: Model width (input and output size of this block).
          num_heads: Total number of query heads across all ranks.
          num_kv_heads: Total number of key/value heads across all ranks.
          head_dim: Per-head dimension.
          max_position_embeddings: Largest position the rotary table covers.
          rope_theta: Rotary embedding base.
          linear_method: Optional quantization method for the projections.

      Raises:
          NotImplementedError: if the quantization method disables weight
              merging while there are fewer KV heads than TP ranks (separate
              column-parallel k/v layers cannot replicate KV heads).
      """

      def __init__(
          self,
          hidden_size: int,
          num_heads: int,
          num_kv_heads: int,
          head_dim: int,
          max_position_embeddings: int = 8192,
          rope_theta: float = 10000,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
          tp_size = get_tensor_model_parallel_world_size()
          self.total_num_heads = num_heads
          assert self.total_num_heads % tp_size == 0
          self.num_heads = self.total_num_heads // tp_size
          self.total_num_kv_heads = num_kv_heads
          if self.total_num_kv_heads >= tp_size:
              # Number of KV heads is greater than TP size, so we partition
              # the KV heads across multiple tensor parallel GPUs.
              assert self.total_num_kv_heads % tp_size == 0
          else:
              # Number of KV heads is less than TP size, so we replicate
              # the KV heads across multiple tensor parallel GPUs.
              assert tp_size % self.total_num_kv_heads == 0
          self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
          self.head_dim = head_dim
          # Per-rank widths of the q / k / v activations.
          self.q_size = self.num_heads * self.head_dim
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5
          self.rope_theta = rope_theta
          if (linear_method is not None
                  and not linear_method.quant_config.merge_weight()):
              self.merge_weight = False
              if self.total_num_kv_heads < tp_size:
                  raise NotImplementedError(
                      "Unmerged q/k/v projections cannot replicate KV heads: "
                      f"{self.total_num_kv_heads} KV heads < tensor parallel "
                      f"size {tp_size}.")
              # Like o_proj and the fused QKVParallelLinear path, these
              # column-parallel layers take the *total* output width and
              # shard it across ranks themselves. Passing the per-rank
              # q_size / kv_size would shrink every shard by tp_size again.
              self.q_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method,
              )
              self.k_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_kv_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method,
              )
              self.v_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_kv_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method,
              )
          else:
              self.merge_weight = True
              self.qkv_proj = QKVParallelLinear(
                  hidden_size,
                  self.head_dim,
                  self.total_num_heads,
                  self.total_num_kv_heads,
                  bias=False,
                  linear_method=linear_method,
              )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.head_dim,
              max_position=max_position_embeddings,
              base=self.rope_theta,
              is_neox_style=True,
          )
          self.attn = PagedAttention(
              self.num_heads,
              self.head_dim,
              self.scaling,
              num_kv_heads=self.num_kv_heads,
          )

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Project to q/k/v, apply rotary embeddings, attend, project out."""
          if self.merge_weight:
              qkv, _ = self.qkv_proj(hidden_states)
              q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                  dim=-1)
          else:
              q, _ = self.q_proj(hidden_states)
              k, _ = self.k_proj(hidden_states)
              v, _ = self.v_proj(hidden_states)
          q, k = self.rotary_emb(positions, q, k)
          k_cache, v_cache = kv_cache
          attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
          output, _ = self.o_proj(attn_output)
          return output
  class GemmaDecoderLayer(nn.Module):
      """A single Gemma transformer layer.

      Structure: RMSNorm -> self-attention -> RMSNorm -> gated MLP. The
      residual stream is threaded through the norms: whenever a residual is
      available, the norm receives ``(hidden_states, residual)`` and returns
      an updated pair.
      """

      def __init__(
          self,
          config: GemmaConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          width = config.hidden_size
          self.hidden_size = width
          self.self_attn = GemmaAttention(
              hidden_size=width,
              num_heads=config.num_attention_heads,
              num_kv_heads=config.num_key_value_heads,
              head_dim=config.head_dim,
              max_position_embeddings=config.max_position_embeddings,
              rope_theta=config.rope_theta,
              linear_method=linear_method,
          )
          self.mlp = GemmaMLP(
              hidden_size=width,
              intermediate_size=config.intermediate_size,
              linear_method=linear_method,
          )
          eps = config.rms_norm_eps
          self.input_layernorm = RMSNorm(width, eps=eps)
          self.post_attention_layernorm = RMSNorm(width, eps=eps)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
          residual: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Run one layer and return ``(mlp_output, residual)``."""
          if residual is not None:
              normed, residual = self.input_layernorm(hidden_states, residual)
          else:
              # No residual yet (GemmaModel passes None to the first layer):
              # the incoming activations start the residual stream.
              residual = hidden_states
              normed = self.input_layernorm(hidden_states)
          attn_out = self.self_attn(
              positions=positions,
              hidden_states=normed,
              kv_cache=kv_cache,
              input_metadata=input_metadata,
          )
          normed, residual = self.post_attention_layernorm(attn_out, residual)
          return self.mlp(normed), residual
  class GemmaModel(nn.Module):
      """Gemma decoder stack without the LM head.

      Embeds the input tokens, scales them by ``sqrt(hidden_size)``, runs them
      through ``config.num_hidden_layers`` decoder layers, and applies a final
      RMSNorm.
      """

      def __init__(
          self,
          config: GemmaConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size,
                                                     linear_method=linear_method)
          self.layers = nn.ModuleList([
              GemmaDecoderLayer(config, linear_method)
              for _ in range(config.num_hidden_layers)
          ])
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Return the final-normed hidden states for ``input_ids``.

          ``kv_caches`` is indexed by layer position (one entry per layer).
          """
          hidden_states = self.embed_tokens(input_ids)
          # Gemma multiplies the embeddings by sqrt(hidden_size). The
          # reference implementation first rounds this constant to the
          # activation dtype (e.g. 55.4256 -> 55.5 in bfloat16 for
          # hidden_size=3072). Rounding the same way keeps activations in
          # agreement with it; the raw float32 value drifts slightly.
          # See https://github.com/huggingface/transformers/pull/29402
          normalizer = torch.tensor(self.config.hidden_size**0.5,
                                    dtype=hidden_states.dtype).item()
          hidden_states *= normalizer
          residual = None
          for layer_idx, layer in enumerate(self.layers):
              hidden_states, residual = layer(
                  positions,
                  hidden_states,
                  kv_caches[layer_idx],
                  input_metadata,
                  residual,
              )
          hidden_states, _ = self.norm(hidden_states, residual)
          return hidden_states
  class GemmaForCausalLM(nn.Module):
      """Gemma decoder plus LM head and sampler.

      While loading weights, the checkpoint's word embedding is also copied
      into ``lm_head``.
      """

      def __init__(
          self,
          config: GemmaConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.model = GemmaModel(config, linear_method)
          self.lm_head = ParallelLMHead(config.vocab_size,
                                        config.hidden_size,
                                        linear_method=linear_method)
          self.sampler = Sampler(config.vocab_size)
          self.quant_sampler = QuantSampler(config.vocab_size)

      def _weights_unmerged(self) -> bool:
          """Return True when the quant config disables weight merging.

          This is the same condition under which the attention and MLP
          blocks build separate q/k/v and gate/up projections.
          """
          return (self.linear_method is not None
                  and not self.linear_method.quant_config.merge_weight())

      @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Run the decoder stack and return its final hidden states."""
          hidden_states = self.model(input_ids, positions, kv_caches,
                                     input_metadata)
          return hidden_states

      def sample(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Sample next tokens from ``hidden_states``.

          With unmerged weights, ``lm_head`` is called directly and its output
          goes to ``QuantSampler``. Otherwise ``Sampler`` receives the
          ``lm_head`` weight together with the hidden states.
          """
          if self._weights_unmerged():
              next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                               sampling_metadata)
          else:
              next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                         sampling_metadata)
          return next_tokens

      def load_weights(
          self,
          model_name_or_path: str,
          cache_dir: Optional[str] = None,
          load_format: str = "auto",
          revision: Optional[str] = None,
      ):
          """Load checkpoint tensors into this model's parameters.

          Unless weights are unmerged, checkpoint q/k/v and gate/up tensors
          are loaded as shards of the fused ``qkv_proj`` / ``gate_up_proj``
          parameters. The word embedding is also copied into ``lm_head``, and
          every ``norm.weight`` is stored as ``1 + w`` (see comment below).
          Extra GPTQ ``.bias`` tensors with no matching parameter are skipped.

          Raises:
              RuntimeError: if any model parameter was not found in the
                  checkpoint.
          """
          stacked_params_mapping = []
          if not self._weights_unmerged():
              stacked_params_mapping = [
                  # (param_name, shard_name, shard_id)
                  ("qkv_proj", "q_proj", "q"),
                  ("qkv_proj", "k_proj", "k"),
                  ("qkv_proj", "v_proj", "v"),
                  ("gate_up_proj", "gate_proj", 0),
                  ("gate_up_proj", "up_proj", 1),
              ]
          params_dict = dict(self.named_parameters())
          loaded_params = set()
          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path, cache_dir, load_format, revision,
                  self.config):
              if "rotary_emb.inv_freq" in name:
                  continue
              if "embed_tokens" in name:
                  # Copy word embedding to lm_head.
                  head_name = name.replace("model.embed_tokens", "lm_head")
                  if head_name in params_dict:
                      loaded_params.add(head_name)
                      lm_head_param = params_dict[head_name]
                      weight_loader = getattr(lm_head_param, "weight_loader",
                                              default_weight_loader)
                      weight_loader(lm_head_param, loaded_weight)
              # Rename a checkpoint shard (e.g. q_proj) to its fused parameter
              # (qkv_proj). Stop at the first match: the renamed "qkv_proj"
              # still contains "v_proj", so further matches would rename it
              # a second time.
              shard_id = None
              for param_name, weight_name, stacked_shard_id in (
                      stacked_params_mapping):
                  if weight_name in name:
                      name = name.replace(weight_name, param_name)
                      shard_id = stacked_shard_id
                      break
              # Skip loading extra bias for GPTQ models (fused or not).
              if name.endswith(".bias") and name not in params_dict:
                  continue
              if shard_id is not None:
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
              else:
                  # Skip loading extra layer for lora models.
                  if "lm_head" in name and name not in params_dict:
                      continue
                  # GemmaRMSNorm is different from Llama's in that it
                  # multiplies (1 + weight) to the output, instead of just
                  # weight. Add out-of-place so the tensor yielded by the
                  # iterator is never mutated.
                  if "norm.weight" in name:
                      loaded_weight = loaded_weight + 1.0
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
                  weight_loader(param, loaded_weight)
              loaded_params.add(name)
          unloaded_params = params_dict.keys() - loaded_params
          if unloaded_params:
              raise RuntimeError(
                  "Some weights are not initialized from checkpoints: "
                  f"{unloaded_params}")