# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import LayerNorm

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs import ChatGLMConfig


class GLMAttention(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output
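
# NOTE: Worked example of the GQA/MQA head partitioning in GLMAttention above.
# Illustrative numbers only, assuming the ChatGLM2-6B configuration
# (hidden_size=4096, num_attention_heads=32, multi_query_attention=True,
# multi_query_group_num=2) and a tensor parallel size of 2:
#   head_dim           = 4096 // 32     = 128
#   num_heads          = 32 // 2        = 16   (query heads per rank)
#   total_num_kv_heads = 2                     (shared KV groups)
#   num_kv_heads       = max(1, 2 // 2) = 1    (KV heads per rank)
#   q_size, kv_size    = 2048, 128
# so in forward() each rank splits its QKV projection into [2048, 128, 128].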


class GLMMLP(nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output
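
# NOTE: dense_h_to_4h packs the gate and up projections into a single matmul,
# so its output width is 2 * ffn_hidden_size per tensor-parallel rank;
# SiluAndMul then splits that tensor in half and returns silu(gate) * up (the
# SwiGLU activation), which is why forward() needs no explicit split. Rough
# shape sketch, assuming the ChatGLM2-6B sizes (hidden_size=4096,
# ffn_hidden_size=13696) and a single rank:
#   [num_tokens, 4096] --dense_h_to_4h--> [num_tokens, 2 * 13696]
#                      --SiluAndMul-----> [num_tokens, 13696]
#                      --dense_4h_to_h--> [num_tokens, 4096]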


class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, cache_config, quant_config)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output.
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP.
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output
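
# NOTE: Sketch of the residual wiring in GLMBlock.forward(). With
# apply_residual_connection_post_layernorm=False (the setting used by the
# released ChatGLM2/3 configs, as far as we are aware), the block is pre-norm:
#   x1 = x  + self_attention(input_layernorm(x))
#   x2 = x1 + mlp(post_attention_layernorm(x1))
# If the flag is True, the residual is taken from the layernorm output rather
# than the block input on both connections.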


class GLMTransformer(nn.Module):
    """Transformer class."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList([
            GLMBlock(config, cache_config, quant_config)
            for i in range(self.num_layers)
        ])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states
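
# NOTE: The attribute is called "encoder" to match the THUDM checkpoint layout
# (keys of the form "transformer.encoder.layers.*"), but GLMTransformer is a
# decoder-only causal stack. ChatGLMForCausalLM below aliases output_layer as
# its lm_head, so the output projection is stored and loaded only once.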


class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            # Rotary inverse frequencies are recomputed by get_rope().
            if "rotary_pos_emb.inv_freq" in name:
                continue
            # The checkpoint nests the embedding under "word_embeddings".
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
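
# Minimal usage sketch (illustrative only, not part of the model definition).
# This module is normally instantiated indirectly through Aphrodite's model
# registry; assuming Aphrodite mirrors vLLM's top-level LLM/SamplingParams
# API, serving a ChatGLM checkpoint would look roughly like:
#
#   from aphrodite import LLM, SamplingParams
#
#   llm = LLM(model="THUDM/chatglm3-6b", trust_remote_code=True)
#   sampling = SamplingParams(temperature=0.8, max_tokens=128)
#   outputs = llm.generate(["What is the capital of France?"], sampling)
#
# trust_remote_code=True is needed because the ChatGLM config and tokenizer
# are shipped in the model repository rather than in transformers itself.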