chatglm.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/THUDM/ChatGLM2-6B
  4. """Inference-only ChatGLM model compatible with THUDM weights."""
  5. from typing import Iterable, List, Optional, Tuple
  6. import torch
  7. from torch import nn
  8. from torch.nn import LayerNorm
  9. from aphrodite.attention import Attention, AttentionMetadata
  10. from aphrodite.common.config import CacheConfig, LoRAConfig
  11. from aphrodite.common.sequence import IntermediateTensors
  12. from aphrodite.distributed import get_tensor_model_parallel_world_size
  13. from aphrodite.modeling.layers.activation import SiluAndMul
  14. from aphrodite.modeling.layers.layernorm import RMSNorm
  15. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  16. QKVParallelLinear,
  17. RowParallelLinear)
  18. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  19. from aphrodite.modeling.layers.rotary_embedding import get_rope
  20. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  21. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  22. ParallelLMHead, VocabParallelEmbedding)
  23. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  24. from aphrodite.modeling.models.interfaces import SupportsLoRA
  25. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  26. from aphrodite.quantization.base_config import QuantizationConfig
  27. from aphrodite.transformers_utils.configs import ChatGLMConfig
  28. class GLMAttention(nn.Module):
  29. def __init__(
  30. self,
  31. config,
  32. cache_config: Optional[CacheConfig] = None,
  33. quant_config: Optional[QuantizationConfig] = None,
  34. ):
  35. super().__init__()
  36. self.hidden_size = config.hidden_size
  37. tp_size = get_tensor_model_parallel_world_size()
  38. self.total_num_heads = config.num_attention_heads
  39. assert self.total_num_heads % tp_size == 0
  40. self.num_heads = self.total_num_heads // tp_size
  41. self.multi_query_attention = config.multi_query_attention
  42. self.total_num_kv_heads = (config.multi_query_group_num
  43. if config.multi_query_attention else
  44. config.num_attention_heads)
  45. if self.total_num_kv_heads >= tp_size:
  46. # Number of KV heads is greater than TP size, so we partition
  47. # the KV heads across multiple tensor parallel GPUs.
  48. assert self.total_num_kv_heads % tp_size == 0
  49. else:
  50. # Number of KV heads is less than TP size, so we replicate
  51. # the KV heads across multiple tensor parallel GPUs.
  52. assert tp_size % self.total_num_kv_heads == 0
  53. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  54. self.head_dim = config.hidden_size // self.total_num_heads
  55. self.q_size = self.num_heads * self.head_dim
  56. self.kv_size = self.num_kv_heads * self.head_dim
  57. self.scaling = self.head_dim**-0.5
  58. self.query_key_value = QKVParallelLinear(
  59. self.hidden_size,
  60. self.head_dim,
  61. self.total_num_heads,
  62. self.total_num_kv_heads,
  63. bias=config.add_bias_linear or config.add_qkv_bias,
  64. quant_config=quant_config,
  65. )
  66. self.dense = RowParallelLinear(
  67. self.total_num_heads * self.head_dim,
  68. config.hidden_size,
  69. bias=config.add_bias_linear,
  70. quant_config=quant_config,
  71. )
  72. # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
  73. rope_ratio = getattr(config, "rope_ratio", 1.0)
  74. max_positions = getattr(config, "seq_length", 8192)
  75. self.rotary_emb = get_rope(
  76. self.head_dim,
  77. rotary_dim=self.head_dim // 2,
  78. max_position=max_positions,
  79. base=10000 * rope_ratio,
  80. is_neox_style=False,
  81. )
  82. self.attn = Attention(self.num_heads,
  83. self.head_dim,
  84. self.scaling,
  85. num_kv_heads=self.num_kv_heads,
  86. cache_config=cache_config,
  87. quant_config=quant_config)
  88. def forward(
  89. self,
  90. hidden_states: torch.Tensor,
  91. position_ids: torch.Tensor,
  92. kv_cache: torch.Tensor,
  93. attn_metadata: AttentionMetadata,
  94. ) -> torch.Tensor:
  95. qkv, _ = self.query_key_value(hidden_states)
  96. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  97. q, k = self.rotary_emb(position_ids, q, k)
  98. context_layer = self.attn(
  99. q,
  100. k,
  101. v,
  102. kv_cache,
  103. attn_metadata,
  104. )
  105. attn_output, _ = self.dense(context_layer)
  106. return attn_output
  107. class GLMMLP(nn.Module):
  108. """MLP.
  109. MLP will take the input with h hidden state, project it to 4*h
  110. hidden dimension, perform nonlinear transformation, and project the
  111. state back into h hidden dimension.
  112. """
  113. def __init__(
  114. self,
  115. config,
  116. quant_config: Optional[QuantizationConfig] = None,
  117. ):
  118. super().__init__()
  119. self.add_bias = config.add_bias_linear
  120. # Project to 4h.
  121. self.dense_h_to_4h = MergedColumnParallelLinear(
  122. config.hidden_size,
  123. [config.ffn_hidden_size] * 2,
  124. bias=config.add_bias_linear,
  125. quant_config=quant_config,
  126. )
  127. self.activation_func = SiluAndMul()
  128. # Project back to h.
  129. self.dense_4h_to_h = RowParallelLinear(
  130. config.ffn_hidden_size,
  131. config.hidden_size,
  132. bias=config.add_bias_linear,
  133. quant_config=quant_config,
  134. )
  135. def forward(self, hidden_states):
  136. # [s, b, 4hp]
  137. intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
  138. intermediate_parallel = self.activation_func(intermediate_parallel)
  139. # [s, b, h]
  140. output, _ = self.dense_4h_to_h(intermediate_parallel)
  141. return output
  142. class GLMBlock(nn.Module):
  143. """A single transformer layer.
  144. Transformer layer takes input with size [s, b, h] and returns an
  145. output of the same size.
  146. """
  147. def __init__(
  148. self,
  149. config,
  150. cache_config: Optional[CacheConfig] = None,
  151. quant_config: Optional[QuantizationConfig] = None,
  152. ):
  153. super().__init__()
  154. self.apply_residual_connection_post_layernorm = (
  155. config.apply_residual_connection_post_layernorm)
  156. self.fp32_residual_connection = config.fp32_residual_connection
  157. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  158. # Layernorm on the input data.
  159. self.input_layernorm = layer_norm_func(config.hidden_size,
  160. eps=config.layernorm_epsilon)
  161. # Self attention.
  162. self.self_attention = GLMAttention(config, cache_config, quant_config)
  163. self.hidden_dropout = config.hidden_dropout
  164. # Layernorm on the attention output
  165. self.post_attention_layernorm = layer_norm_func(
  166. config.hidden_size, eps=config.layernorm_epsilon)
  167. # MLP
  168. self.mlp = GLMMLP(config, quant_config)
  169. def forward(
  170. self,
  171. hidden_states: torch.Tensor,
  172. position_ids: torch.Tensor,
  173. kv_cache: torch.Tensor,
  174. attn_metadata: AttentionMetadata,
  175. ) -> torch.Tensor:
  176. # hidden_states: [num_tokens, h]
  177. # Layer norm at the beginning of the transformer layer.
  178. layernorm_output = self.input_layernorm(hidden_states)
  179. # Self attention.
  180. attention_output = self.self_attention(
  181. hidden_states=layernorm_output,
  182. position_ids=position_ids,
  183. kv_cache=kv_cache,
  184. attn_metadata=attn_metadata,
  185. )
  186. # Residual connection.
  187. if self.apply_residual_connection_post_layernorm:
  188. residual = layernorm_output
  189. else:
  190. residual = hidden_states
  191. layernorm_input = residual + attention_output
  192. # Layer norm post the self attention.
  193. layernorm_output = self.post_attention_layernorm(layernorm_input)
  194. # Second residual connection.
  195. if self.apply_residual_connection_post_layernorm:
  196. residual = layernorm_output
  197. else:
  198. residual = layernorm_input
  199. output = self.mlp(layernorm_output) + residual
  200. return output
  201. class GLMTransformer(nn.Module):
  202. """Transformer class."""
  203. def __init__(
  204. self,
  205. config,
  206. cache_config: Optional[CacheConfig] = None,
  207. quant_config: Optional[QuantizationConfig] = None,
  208. ):
  209. super().__init__()
  210. self.post_layer_norm = config.post_layer_norm
  211. # Number of layers.
  212. self.num_layers = config.num_layers
  213. # Transformer layers.
  214. self.layers = nn.ModuleList([
  215. GLMBlock(config, cache_config, quant_config)
  216. for i in range(self.num_layers)
  217. ])
  218. if self.post_layer_norm:
  219. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  220. # Final layer norm before output.
  221. self.final_layernorm = layer_norm_func(
  222. config.hidden_size, eps=config.layernorm_epsilon)
  223. def forward(
  224. self,
  225. hidden_states: torch.Tensor,
  226. position_ids: torch.Tensor,
  227. kv_caches: List[torch.Tensor],
  228. attn_metadata: AttentionMetadata,
  229. ) -> torch.Tensor:
  230. for i in range(self.num_layers):
  231. layer = self.layers[i]
  232. hidden_states = layer(
  233. hidden_states=hidden_states,
  234. position_ids=position_ids,
  235. kv_cache=kv_caches[i],
  236. attn_metadata=attn_metadata,
  237. )
  238. # Final layer norm.
  239. if self.post_layer_norm:
  240. hidden_states = self.final_layernorm(hidden_states)
  241. return hidden_states
  242. class ChatGLMModel(nn.Module):
  243. def __init__(
  244. self,
  245. config,
  246. cache_config: Optional[CacheConfig] = None,
  247. quant_config: Optional[QuantizationConfig] = None,
  248. ):
  249. super().__init__()
  250. self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
  251. config.hidden_size)
  252. self.num_layers = config.num_layers
  253. self.multi_query_group_num = config.multi_query_group_num
  254. self.kv_channels = config.kv_channels
  255. self.encoder = GLMTransformer(config, cache_config, quant_config)
  256. self.output_layer = ParallelLMHead(config.padded_vocab_size,
  257. config.hidden_size,
  258. quant_config=quant_config)
  259. def forward(
  260. self,
  261. input_ids: torch.Tensor,
  262. position_ids: torch.Tensor,
  263. kv_caches: List[torch.Tensor],
  264. attn_metadata: AttentionMetadata,
  265. ) -> torch.Tensor:
  266. inputs_embeds = self.embedding(input_ids)
  267. # Run encoder.
  268. hidden_states = self.encoder(
  269. hidden_states=inputs_embeds,
  270. position_ids=position_ids,
  271. kv_caches=kv_caches,
  272. attn_metadata=attn_metadata,
  273. )
  274. return hidden_states
  275. class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
  276. packed_modules_mapping = {
  277. "query_key_value": ["query_key_value"],
  278. "dense_h_to_4h": ["dense_h_to_4h"]
  279. }
  280. # LoRA specific attributes
  281. supported_lora_modules = [
  282. "query_key_value",
  283. "dense",
  284. "dense_h_to_4h",
  285. "dense_4h_to_h",
  286. ]
  287. embedding_modules = {}
  288. embedding_padding_modules = []
  289. def __init__(
  290. self,
  291. config: ChatGLMConfig,
  292. cache_config: Optional[CacheConfig] = None,
  293. quant_config: Optional[QuantizationConfig] = None,
  294. lora_config: Optional[LoRAConfig] = None,
  295. ):
  296. super().__init__()
  297. self.config = config
  298. self.lora_config = lora_config
  299. self.quant_config = quant_config
  300. self.max_position_embeddings = getattr(config, "max_sequence_length",
  301. 8192)
  302. self.transformer = ChatGLMModel(config, cache_config, quant_config)
  303. self.lm_head = self.transformer.output_layer
  304. self.logits_processor = LogitsProcessor(config.padded_vocab_size)
  305. self.sampler = Sampler()
  306. def forward(
  307. self,
  308. input_ids: torch.Tensor,
  309. positions: torch.Tensor,
  310. kv_caches: List[torch.Tensor],
  311. attn_metadata: AttentionMetadata,
  312. intermediate_tensors: Optional[IntermediateTensors] = None,
  313. ) -> torch.Tensor:
  314. hidden_states = self.transformer(input_ids, positions, kv_caches,
  315. attn_metadata)
  316. return hidden_states
  317. def compute_logits(
  318. self,
  319. hidden_states: torch.Tensor,
  320. sampling_metadata: SamplingMetadata,
  321. ) -> Optional[torch.Tensor]:
  322. logits = self.logits_processor(self.lm_head, hidden_states,
  323. sampling_metadata)
  324. return logits
  325. def sample(
  326. self,
  327. logits: torch.Tensor,
  328. sampling_metadata: SamplingMetadata,
  329. ) -> Optional[SamplerOutput]:
  330. next_tokens = self.sampler(logits, sampling_metadata)
  331. return next_tokens
  332. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  333. params_dict = dict(self.named_parameters(remove_duplicate=False))
  334. for name, loaded_weight in weights:
  335. if "rotary_pos_emb.inv_freq" in name:
  336. continue
  337. if "word_embeddings" in name:
  338. name = name.replace(".word_embeddings", "")
  339. # Skip loading extra bias for GPTQ models.
  340. if name.endswith(".bias") and name not in params_dict:
  341. continue
  342. param = params_dict[name]
  343. weight_loader = getattr(param, "weight_loader",
  344. default_weight_loader)
  345. weight_loader(param, loaded_weight)