chatglm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/THUDM/ChatGLM2-6B
  4. """Inference-only ChatGLM model compatible with THUDM weights."""
  5. from typing import Iterable, List, Optional, Tuple
  6. import torch
  7. from torch import nn
  8. from torch.nn import LayerNorm
  9. from aphrodite.attention import Attention, AttentionMetadata
  10. from aphrodite.common.config import LoRAConfig
  11. from aphrodite.common.sequence import SamplerOutput
  12. from aphrodite.distributed import get_tensor_model_parallel_world_size
  13. from aphrodite.modeling.layers.activation import SiluAndMul
  14. from aphrodite.modeling.layers.layernorm import RMSNorm
  15. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  16. MergedColumnParallelLinear,
  17. QKVParallelLinear,
  18. RowParallelLinear)
  19. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  20. from aphrodite.modeling.layers.rotary_embedding import get_rope
  21. from aphrodite.modeling.layers.sampler import Sampler
  22. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  23. ParallelLMHead, VocabParallelEmbedding)
  24. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  25. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  26. from aphrodite.transformers_utils.configs import ChatGLMConfig
  27. class GLMAttention(nn.Module):
  28. def __init__(
  29. self,
  30. config,
  31. linear_method: Optional[LinearMethodBase] = None,
  32. ):
  33. super().__init__()
  34. self.hidden_size = config.hidden_size
  35. tp_size = get_tensor_model_parallel_world_size()
  36. self.total_num_heads = config.num_attention_heads
  37. assert self.total_num_heads % tp_size == 0
  38. self.num_heads = self.total_num_heads // tp_size
  39. self.multi_query_attention = config.multi_query_attention
  40. self.total_num_kv_heads = (config.multi_query_group_num
  41. if config.multi_query_attention else
  42. config.num_attention_heads)
  43. if self.total_num_kv_heads >= tp_size:
  44. # Number of KV heads is greater than TP size, so we partition
  45. # the KV heads across multiple tensor parallel GPUs.
  46. assert self.total_num_kv_heads % tp_size == 0
  47. else:
  48. # Number of KV heads is less than TP size, so we replicate
  49. # the KV heads across multiple tensor parallel GPUs.
  50. assert tp_size % self.total_num_kv_heads == 0
  51. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  52. self.head_dim = config.hidden_size // self.total_num_heads
  53. self.q_size = self.num_heads * self.head_dim
  54. self.kv_size = self.num_kv_heads * self.head_dim
  55. self.scaling = self.head_dim**-0.5
  56. self.query_key_value = QKVParallelLinear(
  57. self.hidden_size,
  58. self.head_dim,
  59. self.total_num_heads,
  60. self.total_num_kv_heads,
  61. bias=config.add_bias_linear or config.add_qkv_bias,
  62. linear_method=linear_method,
  63. )
  64. self.dense = RowParallelLinear(
  65. self.total_num_heads * self.head_dim,
  66. config.hidden_size,
  67. bias=config.add_bias_linear,
  68. linear_method=linear_method,
  69. )
  70. # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
  71. rope_ratio = getattr(config, "rope_ratio", 1.0)
  72. max_positions = getattr(config, "seq_length", 8192)
  73. self.rotary_emb = get_rope(
  74. self.head_dim,
  75. rotary_dim=self.head_dim // 2,
  76. max_position=max_positions,
  77. base=10000 * rope_ratio,
  78. is_neox_style=False,
  79. )
  80. self.attn = Attention(
  81. self.num_heads,
  82. self.head_dim,
  83. self.scaling,
  84. num_kv_heads=self.num_kv_heads,
  85. )
  86. def forward(
  87. self,
  88. hidden_states: torch.Tensor,
  89. position_ids: torch.Tensor,
  90. kv_cache: torch.Tensor,
  91. attn_metadata: AttentionMetadata,
  92. ) -> torch.Tensor:
  93. qkv, _ = self.query_key_value(hidden_states)
  94. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  95. q, k = self.rotary_emb(position_ids, q, k)
  96. context_layer = self.attn(
  97. q,
  98. k,
  99. v,
  100. kv_cache,
  101. attn_metadata,
  102. )
  103. attn_output, _ = self.dense(context_layer)
  104. return attn_output
  105. class GLMMLP(nn.Module):
  106. """MLP.
  107. MLP will take the input with h hidden state, project it to 4*h
  108. hidden dimension, perform nonlinear transformation, and project the
  109. state back into h hidden dimension.
  110. """
  111. def __init__(
  112. self,
  113. config,
  114. linear_method: Optional[LinearMethodBase] = None,
  115. ):
  116. super().__init__()
  117. self.add_bias = config.add_bias_linear
  118. # Project to 4h.
  119. self.dense_h_to_4h = MergedColumnParallelLinear(
  120. config.hidden_size,
  121. [config.ffn_hidden_size] * 2,
  122. bias=config.add_bias_linear,
  123. linear_method=linear_method,
  124. )
  125. self.activation_func = SiluAndMul()
  126. # Project back to h.
  127. self.dense_4h_to_h = RowParallelLinear(
  128. config.ffn_hidden_size,
  129. config.hidden_size,
  130. bias=config.add_bias_linear,
  131. linear_method=linear_method,
  132. )
  133. def forward(self, hidden_states):
  134. # [s, b, 4hp]
  135. intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
  136. intermediate_parallel = self.activation_func(intermediate_parallel)
  137. # [s, b, h]
  138. output, _ = self.dense_4h_to_h(intermediate_parallel)
  139. return output
  140. class GLMBlock(nn.Module):
  141. """A single transformer layer.
  142. Transformer layer takes input with size [s, b, h] and returns an
  143. output of the same size.
  144. """
  145. def __init__(
  146. self,
  147. config,
  148. linear_method: Optional[LinearMethodBase] = None,
  149. ):
  150. super().__init__()
  151. self.apply_residual_connection_post_layernorm = (
  152. config.apply_residual_connection_post_layernorm)
  153. self.fp32_residual_connection = config.fp32_residual_connection
  154. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  155. # Layernorm on the input data.
  156. self.input_layernorm = layer_norm_func(config.hidden_size,
  157. eps=config.layernorm_epsilon)
  158. # Self attention.
  159. self.self_attention = GLMAttention(config, linear_method)
  160. self.hidden_dropout = config.hidden_dropout
  161. # Layernorm on the attention output
  162. self.post_attention_layernorm = layer_norm_func(
  163. config.hidden_size, eps=config.layernorm_epsilon)
  164. # MLP
  165. self.mlp = GLMMLP(config, linear_method)
  166. def forward(
  167. self,
  168. hidden_states: torch.Tensor,
  169. position_ids: torch.Tensor,
  170. kv_cache: torch.Tensor,
  171. attn_metadata: AttentionMetadata,
  172. ) -> torch.Tensor:
  173. # hidden_states: [num_tokens, h]
  174. # Layer norm at the beginning of the transformer layer.
  175. layernorm_output = self.input_layernorm(hidden_states)
  176. # Self attention.
  177. attention_output = self.self_attention(
  178. hidden_states=layernorm_output,
  179. position_ids=position_ids,
  180. kv_cache=kv_cache,
  181. attn_metadata=attn_metadata,
  182. )
  183. # Residual connection.
  184. if self.apply_residual_connection_post_layernorm:
  185. residual = layernorm_output
  186. else:
  187. residual = hidden_states
  188. layernorm_input = residual + attention_output
  189. # Layer norm post the self attention.
  190. layernorm_output = self.post_attention_layernorm(layernorm_input)
  191. # Second residual connection.
  192. if self.apply_residual_connection_post_layernorm:
  193. residual = layernorm_output
  194. else:
  195. residual = layernorm_input
  196. output = self.mlp(layernorm_output) + residual
  197. return output
  198. class GLMTransformer(nn.Module):
  199. """Transformer class."""
  200. def __init__(
  201. self,
  202. config,
  203. linear_method: Optional[LinearMethodBase] = None,
  204. ):
  205. super().__init__()
  206. self.post_layer_norm = config.post_layer_norm
  207. # Number of layers.
  208. self.num_layers = config.num_layers
  209. # Transformer layers.
  210. self.layers = nn.ModuleList(
  211. [GLMBlock(config, linear_method) for i in range(self.num_layers)])
  212. if self.post_layer_norm:
  213. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  214. # Final layer norm before output.
  215. self.final_layernorm = layer_norm_func(
  216. config.hidden_size, eps=config.layernorm_epsilon)
  217. def forward(
  218. self,
  219. hidden_states: torch.Tensor,
  220. position_ids: torch.Tensor,
  221. kv_caches: List[torch.Tensor],
  222. attn_metadata: AttentionMetadata,
  223. ) -> torch.Tensor:
  224. for i in range(self.num_layers):
  225. layer = self.layers[i]
  226. hidden_states = layer(
  227. hidden_states=hidden_states,
  228. position_ids=position_ids,
  229. kv_cache=kv_caches[i],
  230. attn_metadata=attn_metadata,
  231. )
  232. # Final layer norm.
  233. if self.post_layer_norm:
  234. hidden_states = self.final_layernorm(hidden_states)
  235. return hidden_states
  236. class ChatGLMModel(nn.Module):
  237. def __init__(
  238. self,
  239. config,
  240. linear_method: Optional[LinearMethodBase] = None,
  241. ):
  242. super().__init__()
  243. self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
  244. config.hidden_size)
  245. self.num_layers = config.num_layers
  246. self.multi_query_group_num = config.multi_query_group_num
  247. self.kv_channels = config.kv_channels
  248. self.encoder = GLMTransformer(config, linear_method)
  249. self.output_layer = ParallelLMHead(config.padded_vocab_size,
  250. config.hidden_size)
  251. def forward(
  252. self,
  253. input_ids: torch.Tensor,
  254. position_ids: torch.Tensor,
  255. kv_caches: List[torch.Tensor],
  256. attn_metadata: AttentionMetadata,
  257. ) -> torch.Tensor:
  258. inputs_embeds = self.embedding(input_ids)
  259. # Run encoder.
  260. hidden_states = self.encoder(
  261. hidden_states=inputs_embeds,
  262. position_ids=position_ids,
  263. kv_caches=kv_caches,
  264. attn_metadata=attn_metadata,
  265. )
  266. return hidden_states
  267. class ChatGLMForCausalLM(nn.Module):
  268. packed_modules_mapping = {
  269. "query_key_value": ["query_key_value"],
  270. "dense_h_to_4h": ["dense_h_to_4h"]
  271. }
  272. # LoRA specific attributes
  273. supported_lora_modules = [
  274. "query_key_value",
  275. "dense",
  276. "dense_h_to_4h",
  277. "dense_4h_to_h",
  278. ]
  279. embedding_modules = {}
  280. embedding_padding_modules = []
  281. def __init__(
  282. self,
  283. config: ChatGLMConfig,
  284. linear_method: Optional[LinearMethodBase] = None,
  285. lora_config: Optional[LoRAConfig] = None,
  286. ):
  287. super().__init__()
  288. self.config: ChatGLMConfig = config
  289. self.linear_method = linear_method
  290. self.transformer = ChatGLMModel(config, linear_method)
  291. self.lm_head_weight = self.transformer.output_layer.weight
  292. self.logits_processor = LogitsProcessor(config.padded_vocab_size)
  293. self.sampler = Sampler()
  294. def forward(
  295. self,
  296. input_ids: torch.Tensor,
  297. positions: torch.Tensor,
  298. kv_caches: List[torch.Tensor],
  299. attn_metadata: AttentionMetadata,
  300. ) -> torch.Tensor:
  301. hidden_states = self.transformer(input_ids, positions, kv_caches,
  302. attn_metadata)
  303. return hidden_states
  304. def compute_logits(self, hidden_states: torch.Tensor,
  305. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  306. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  307. sampling_metadata)
  308. return logits
  309. def sample(
  310. self,
  311. logits: torch.Tensor,
  312. sampling_metadata: SamplingMetadata,
  313. ) -> Optional[SamplerOutput]:
  314. next_tokens = self.sampler(logits, sampling_metadata)
  315. return next_tokens
  316. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  317. params_dict = dict(self.named_parameters(remove_duplicate=False))
  318. for name, loaded_weight in weights:
  319. if "rotary_pos_emb.inv_freq" in name:
  320. continue
  321. if "word_embeddings" in name:
  322. name = name.replace(".word_embeddings", "")
  323. # Skip loading extra bias for GPTQ models.
  324. if name.endswith(".bias") and name not in params_dict:
  325. continue
  326. param = params_dict[name]
  327. weight_loader = getattr(param, "weight_loader",
  328. default_weight_loader)
  329. weight_loader(param, loaded_weight)