chatglm.py

# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional

import torch
from torch import nn
from torch.nn import LayerNorm

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import LoRAConfig
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs import ChatGLMConfig


class GLMAttention(nn.Module):

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output
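
# Shape sanity check for the QKV split above (added commentary, not original
# code): a sketch assuming the published chatglm2/3-6b config values
# (hidden_size=4096, num_attention_heads=32, multi_query_group_num=2,
# kv_channels=128) and a single-GPU run (tp_size=1):
#
#   head_dim = 4096 // 32 = 128
#   q_size   = 32 * 128   = 4096
#   kv_size  =  2 * 128   =  256
#
# so `qkv` carries 4096 + 256 + 256 = 4608 features per token, matching
# `qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)`.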


class GLMMLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.add_bias = config.add_bias_linear

        # Project to 4h.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output
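
# Note on the gated activation above (added commentary, not original code):
# `dense_h_to_4h` is a merged projection that emits 2 * ffn_hidden_size
# features, and `SiluAndMul` splits that tensor in half along the last dim
# and returns silu(x1) * x2, i.e. a SwiGLU-style MLP rather than a plain 4h
# projection with GELU. As a rough sketch, assuming the chatglm2/3-6b value
# ffn_hidden_size=13696: the merged projection produces 27392 features and
# the activation hands 13696 back to `dense_4h_to_h`.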


class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, linear_method)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output.
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Transformer class."""

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [GLMBlock(config, linear_method) for _ in range(self.num_layers)])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states
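
# Added note: `kv_caches` must hold one cache entry per GLMBlock, since
# kv_caches[i] feeds layer i above. Assuming the chatglm2/3-6b value
# num_layers=28, that is a list of 28 cache tensors managed by the engine's
# cache allocator.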


class ChatGLMModel(nn.Module):

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                linear_method=linear_method)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, linear_method)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           linear_method=linear_method)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states


class ChatGLMForCausalLM(nn.Module):
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.linear_method = linear_method
        self.transformer = ChatGLMModel(config, linear_method)
        self.logits_processor = LogitsProcessor(config.padded_vocab_size,
                                                config.tokenizer_vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.transformer.output_layer,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
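

# A minimal end-to-end usage sketch (not part of the upstream file). It assumes
# the Aphrodite engine exposes a vLLM-style `LLM` entrypoint and that its model
# registry maps the ChatGLM architecture to ChatGLMForCausalLM; adapt the
# import path and model name to your installation if they differ.
if __name__ == "__main__":
    from aphrodite import LLM, SamplingParams  # assumed top-level API

    llm = LLM(model="THUDM/chatglm3-6b", trust_remote_code=True)
    params = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate(["Give me a one-line summary of ChatGLM."], params)
    print(outputs[0].outputs[0].text)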