1
0

chatglm.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/THUDM/ChatGLM2-6B
  4. """Inference-only ChatGLM model compatible with THUDM weights."""
  5. from typing import Iterable, List, Optional, Tuple
  6. import torch
  7. from torch import nn
  8. from torch.nn import LayerNorm
  9. from aphrodite.attention import Attention, AttentionMetadata
  10. from aphrodite.common.config import CacheConfig, LoRAConfig
  11. from aphrodite.common.sequence import SamplerOutput
  12. from aphrodite.distributed import get_tensor_model_parallel_world_size
  13. from aphrodite.modeling.layers.activation import SiluAndMul
  14. from aphrodite.modeling.layers.layernorm import RMSNorm
  15. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  16. QKVParallelLinear,
  17. RowParallelLinear)
  18. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  19. from aphrodite.modeling.layers.rotary_embedding import get_rope
  20. from aphrodite.modeling.layers.sampler import Sampler
  21. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  22. ParallelLMHead, VocabParallelEmbedding)
  23. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  24. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  25. from aphrodite.quantization.base_config import QuantizationConfig
  26. from aphrodite.transformers_utils.configs import ChatGLMConfig
  27. class GLMAttention(nn.Module):
  28. def __init__(
  29. self,
  30. config,
  31. cache_config: Optional[CacheConfig] = None,
  32. quant_config: Optional[QuantizationConfig] = None,
  33. ):
  34. super().__init__()
  35. self.hidden_size = config.hidden_size
  36. tp_size = get_tensor_model_parallel_world_size()
  37. self.total_num_heads = config.num_attention_heads
  38. assert self.total_num_heads % tp_size == 0
  39. self.num_heads = self.total_num_heads // tp_size
  40. self.multi_query_attention = config.multi_query_attention
  41. self.total_num_kv_heads = (config.multi_query_group_num
  42. if config.multi_query_attention else
  43. config.num_attention_heads)
  44. if self.total_num_kv_heads >= tp_size:
  45. # Number of KV heads is greater than TP size, so we partition
  46. # the KV heads across multiple tensor parallel GPUs.
  47. assert self.total_num_kv_heads % tp_size == 0
  48. else:
  49. # Number of KV heads is less than TP size, so we replicate
  50. # the KV heads across multiple tensor parallel GPUs.
  51. assert tp_size % self.total_num_kv_heads == 0
  52. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  53. self.head_dim = config.hidden_size // self.total_num_heads
  54. self.q_size = self.num_heads * self.head_dim
  55. self.kv_size = self.num_kv_heads * self.head_dim
  56. self.scaling = self.head_dim**-0.5
  57. self.query_key_value = QKVParallelLinear(
  58. self.hidden_size,
  59. self.head_dim,
  60. self.total_num_heads,
  61. self.total_num_kv_heads,
  62. bias=config.add_bias_linear or config.add_qkv_bias,
  63. quant_config=quant_config,
  64. )
  65. self.dense = RowParallelLinear(
  66. self.total_num_heads * self.head_dim,
  67. config.hidden_size,
  68. bias=config.add_bias_linear,
  69. quant_config=quant_config,
  70. )
  71. # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
  72. rope_ratio = getattr(config, "rope_ratio", 1.0)
  73. max_positions = getattr(config, "seq_length", 8192)
  74. self.rotary_emb = get_rope(
  75. self.head_dim,
  76. rotary_dim=self.head_dim // 2,
  77. max_position=max_positions,
  78. base=10000 * rope_ratio,
  79. is_neox_style=False,
  80. )
  81. self.attn = Attention(self.num_heads,
  82. self.head_dim,
  83. self.scaling,
  84. num_kv_heads=self.num_kv_heads,
  85. cache_config=cache_config,
  86. quant_config=quant_config)
  87. def forward(
  88. self,
  89. hidden_states: torch.Tensor,
  90. position_ids: torch.Tensor,
  91. kv_cache: torch.Tensor,
  92. attn_metadata: AttentionMetadata,
  93. ) -> torch.Tensor:
  94. qkv, _ = self.query_key_value(hidden_states)
  95. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  96. q, k = self.rotary_emb(position_ids, q, k)
  97. context_layer = self.attn(
  98. q,
  99. k,
  100. v,
  101. kv_cache,
  102. attn_metadata,
  103. )
  104. attn_output, _ = self.dense(context_layer)
  105. return attn_output
  106. class GLMMLP(nn.Module):
  107. """MLP.
  108. MLP will take the input with h hidden state, project it to 4*h
  109. hidden dimension, perform nonlinear transformation, and project the
  110. state back into h hidden dimension.
  111. """
  112. def __init__(
  113. self,
  114. config,
  115. quant_config: Optional[QuantizationConfig] = None,
  116. ):
  117. super().__init__()
  118. self.add_bias = config.add_bias_linear
  119. # Project to 4h.
  120. self.dense_h_to_4h = MergedColumnParallelLinear(
  121. config.hidden_size,
  122. [config.ffn_hidden_size] * 2,
  123. bias=config.add_bias_linear,
  124. quant_config=quant_config,
  125. )
  126. self.activation_func = SiluAndMul()
  127. # Project back to h.
  128. self.dense_4h_to_h = RowParallelLinear(
  129. config.ffn_hidden_size,
  130. config.hidden_size,
  131. bias=config.add_bias_linear,
  132. quant_config=quant_config,
  133. )
  134. def forward(self, hidden_states):
  135. # [s, b, 4hp]
  136. intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
  137. intermediate_parallel = self.activation_func(intermediate_parallel)
  138. # [s, b, h]
  139. output, _ = self.dense_4h_to_h(intermediate_parallel)
  140. return output
  141. class GLMBlock(nn.Module):
  142. """A single transformer layer.
  143. Transformer layer takes input with size [s, b, h] and returns an
  144. output of the same size.
  145. """
  146. def __init__(
  147. self,
  148. config,
  149. cache_config: Optional[CacheConfig] = None,
  150. quant_config: Optional[QuantizationConfig] = None,
  151. ):
  152. super().__init__()
  153. self.apply_residual_connection_post_layernorm = (
  154. config.apply_residual_connection_post_layernorm)
  155. self.fp32_residual_connection = config.fp32_residual_connection
  156. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  157. # Layernorm on the input data.
  158. self.input_layernorm = layer_norm_func(config.hidden_size,
  159. eps=config.layernorm_epsilon)
  160. # Self attention.
  161. self.self_attention = GLMAttention(config, cache_config, quant_config)
  162. self.hidden_dropout = config.hidden_dropout
  163. # Layernorm on the attention output
  164. self.post_attention_layernorm = layer_norm_func(
  165. config.hidden_size, eps=config.layernorm_epsilon)
  166. # MLP
  167. self.mlp = GLMMLP(config, quant_config)
  168. def forward(
  169. self,
  170. hidden_states: torch.Tensor,
  171. position_ids: torch.Tensor,
  172. kv_cache: torch.Tensor,
  173. attn_metadata: AttentionMetadata,
  174. ) -> torch.Tensor:
  175. # hidden_states: [num_tokens, h]
  176. # Layer norm at the beginning of the transformer layer.
  177. layernorm_output = self.input_layernorm(hidden_states)
  178. # Self attention.
  179. attention_output = self.self_attention(
  180. hidden_states=layernorm_output,
  181. position_ids=position_ids,
  182. kv_cache=kv_cache,
  183. attn_metadata=attn_metadata,
  184. )
  185. # Residual connection.
  186. if self.apply_residual_connection_post_layernorm:
  187. residual = layernorm_output
  188. else:
  189. residual = hidden_states
  190. layernorm_input = residual + attention_output
  191. # Layer norm post the self attention.
  192. layernorm_output = self.post_attention_layernorm(layernorm_input)
  193. # Second residual connection.
  194. if self.apply_residual_connection_post_layernorm:
  195. residual = layernorm_output
  196. else:
  197. residual = layernorm_input
  198. output = self.mlp(layernorm_output) + residual
  199. return output
  200. class GLMTransformer(nn.Module):
  201. """Transformer class."""
  202. def __init__(
  203. self,
  204. config,
  205. cache_config: Optional[CacheConfig] = None,
  206. quant_config: Optional[QuantizationConfig] = None,
  207. ):
  208. super().__init__()
  209. self.post_layer_norm = config.post_layer_norm
  210. # Number of layers.
  211. self.num_layers = config.num_layers
  212. # Transformer layers.
  213. self.layers = nn.ModuleList([
  214. GLMBlock(config, cache_config, quant_config)
  215. for i in range(self.num_layers)
  216. ])
  217. if self.post_layer_norm:
  218. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  219. # Final layer norm before output.
  220. self.final_layernorm = layer_norm_func(
  221. config.hidden_size, eps=config.layernorm_epsilon)
  222. def forward(
  223. self,
  224. hidden_states: torch.Tensor,
  225. position_ids: torch.Tensor,
  226. kv_caches: List[torch.Tensor],
  227. attn_metadata: AttentionMetadata,
  228. ) -> torch.Tensor:
  229. for i in range(self.num_layers):
  230. layer = self.layers[i]
  231. hidden_states = layer(
  232. hidden_states=hidden_states,
  233. position_ids=position_ids,
  234. kv_cache=kv_caches[i],
  235. attn_metadata=attn_metadata,
  236. )
  237. # Final layer norm.
  238. if self.post_layer_norm:
  239. hidden_states = self.final_layernorm(hidden_states)
  240. return hidden_states
  241. class ChatGLMModel(nn.Module):
  242. def __init__(
  243. self,
  244. config,
  245. cache_config: Optional[CacheConfig] = None,
  246. quant_config: Optional[QuantizationConfig] = None,
  247. ):
  248. super().__init__()
  249. self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
  250. config.hidden_size)
  251. self.num_layers = config.num_layers
  252. self.multi_query_group_num = config.multi_query_group_num
  253. self.kv_channels = config.kv_channels
  254. self.encoder = GLMTransformer(config, cache_config, quant_config)
  255. self.output_layer = ParallelLMHead(config.padded_vocab_size,
  256. config.hidden_size)
  257. def forward(
  258. self,
  259. input_ids: torch.Tensor,
  260. position_ids: torch.Tensor,
  261. kv_caches: List[torch.Tensor],
  262. attn_metadata: AttentionMetadata,
  263. ) -> torch.Tensor:
  264. inputs_embeds = self.embedding(input_ids)
  265. # Run encoder.
  266. hidden_states = self.encoder(
  267. hidden_states=inputs_embeds,
  268. position_ids=position_ids,
  269. kv_caches=kv_caches,
  270. attn_metadata=attn_metadata,
  271. )
  272. return hidden_states
  273. class ChatGLMForCausalLM(nn.Module):
  274. packed_modules_mapping = {
  275. "query_key_value": ["query_key_value"],
  276. "dense_h_to_4h": ["dense_h_to_4h"]
  277. }
  278. # LoRA specific attributes
  279. supported_lora_modules = [
  280. "query_key_value",
  281. "dense",
  282. "dense_h_to_4h",
  283. "dense_4h_to_h",
  284. ]
  285. embedding_modules = {}
  286. embedding_padding_modules = []
  287. def __init__(
  288. self,
  289. config: ChatGLMConfig,
  290. cache_config: Optional[CacheConfig] = None,
  291. quant_config: Optional[QuantizationConfig] = None,
  292. lora_config: Optional[LoRAConfig] = None,
  293. ):
  294. super().__init__()
  295. self.config: ChatGLMConfig = config
  296. self.quant_config = quant_config
  297. self.max_position_embeddings = getattr(config, "max_sequence_length",
  298. 8192)
  299. self.transformer = ChatGLMModel(config, cache_config, quant_config)
  300. self.lm_head_weight = self.transformer.output_layer.weight
  301. self.logits_processor = LogitsProcessor(config.padded_vocab_size)
  302. self.sampler = Sampler()
  303. def forward(
  304. self,
  305. input_ids: torch.Tensor,
  306. positions: torch.Tensor,
  307. kv_caches: List[torch.Tensor],
  308. attn_metadata: AttentionMetadata,
  309. ) -> torch.Tensor:
  310. hidden_states = self.transformer(input_ids, positions, kv_caches,
  311. attn_metadata)
  312. return hidden_states
  313. def compute_logits(self, hidden_states: torch.Tensor,
  314. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  315. logits = self.logits_processor(self.lm_head_weight, hidden_states,
  316. sampling_metadata)
  317. return logits
  318. def sample(
  319. self,
  320. logits: torch.Tensor,
  321. sampling_metadata: SamplingMetadata,
  322. ) -> Optional[SamplerOutput]:
  323. next_tokens = self.sampler(logits, sampling_metadata)
  324. return next_tokens
  325. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  326. params_dict = dict(self.named_parameters(remove_duplicate=False))
  327. for name, loaded_weight in weights:
  328. if "rotary_pos_emb.inv_freq" in name:
  329. continue
  330. if "word_embeddings" in name:
  331. name = name.replace(".word_embeddings", "")
  332. # Skip loading extra bias for GPTQ models.
  333. if name.endswith(".bias") and name not in params_dict:
  334. continue
  335. param = params_dict[name]
  336. weight_loader = getattr(param, "weight_loader",
  337. default_weight_loader)
  338. weight_loader(param, loaded_weight)