chatglm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/THUDM/ChatGLM2-6B
  4. """Inference-only ChatGLM model compatible with THUDM weights."""
  5. from typing import List, Optional, Tuple
  6. import torch
  7. from torch import nn
  8. from torch.nn import LayerNorm
  9. from aphrodite.modeling.metadata import InputMetadata
  10. from aphrodite.modeling.layers.activation import SiluAndMul
  11. from aphrodite.modeling.layers.attention import PagedAttention
  12. from aphrodite.modeling.layers.layernorm import RMSNorm
  13. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  14. MergedColumnParallelLinear,
  15. QKVParallelLinear,
  16. RowParallelLinear)
  17. from aphrodite.modeling.layers.rotary_embedding import get_rope
  18. from aphrodite.modeling.layers.sampler import Sampler
  19. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  20. VocabParallelEmbedding, ParallelLMHead)
  21. from aphrodite.modeling.megatron.parallel_state import (
  22. get_tensor_model_parallel_world_size)
  23. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  24. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  25. hf_model_weights_iterator)
  26. from aphrodite.common.sequence import SamplerOutput
  27. from aphrodite.transformers_utils.configs import ChatGLMConfig
# Per-layer attention cache: a (key_cache, value_cache) pair of tensors,
# unpacked in that order by GLMAttention.forward before calling PagedAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]
  29. class GLMAttention(nn.Module):
  30. def __init__(
  31. self,
  32. config,
  33. linear_method: Optional[LinearMethodBase] = None,
  34. ):
  35. super().__init__()
  36. self.hidden_size = config.hidden_size
  37. tp_size = get_tensor_model_parallel_world_size()
  38. self.total_num_heads = config.num_attention_heads
  39. assert self.total_num_heads % tp_size == 0
  40. self.num_heads = self.total_num_heads // tp_size
  41. self.multi_query_attention = config.multi_query_attention
  42. self.total_num_kv_heads = (config.multi_query_group_num
  43. if config.multi_query_attention else
  44. config.num_attention_heads)
  45. if self.total_num_kv_heads >= tp_size:
  46. # Number of KV heads is greater than TP size, so we partition
  47. # the KV heads across multiple tensor parallel GPUs.
  48. assert self.total_num_kv_heads % tp_size == 0
  49. else:
  50. # Number of KV heads is less than TP size, so we replicate
  51. # the KV heads across multiple tensor parallel GPUs.
  52. assert tp_size % self.total_num_kv_heads == 0
  53. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  54. self.head_dim = config.hidden_size // self.total_num_heads
  55. self.q_size = self.num_heads * self.head_dim
  56. self.kv_size = self.num_kv_heads * self.head_dim
  57. self.scaling = self.head_dim**-0.5
  58. self.query_key_value = QKVParallelLinear(
  59. self.hidden_size,
  60. self.head_dim,
  61. self.total_num_heads,
  62. self.total_num_kv_heads,
  63. bias=config.add_bias_linear or config.add_qkv_bias,
  64. linear_method=linear_method,
  65. )
  66. self.dense = RowParallelLinear(
  67. self.total_num_heads * self.head_dim,
  68. config.hidden_size,
  69. bias=config.add_bias_linear,
  70. linear_method=linear_method,
  71. )
  72. # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
  73. rope_ratio = getattr(config, "rope_ratio", 1.0)
  74. max_positions = getattr(config, "seq_length", 8192)
  75. self.rotary_emb = get_rope(
  76. self.head_dim,
  77. rotary_dim=self.head_dim // 2,
  78. max_position=max_positions,
  79. base=10000 * rope_ratio,
  80. is_neox_style=False,
  81. )
  82. self.attn = PagedAttention(
  83. self.num_heads,
  84. self.head_dim,
  85. self.scaling,
  86. num_kv_heads=self.num_kv_heads,
  87. )
  88. def forward(
  89. self,
  90. hidden_states: torch.Tensor,
  91. position_ids: torch.Tensor,
  92. kv_cache: KVCache,
  93. input_metadata: InputMetadata,
  94. ) -> torch.Tensor:
  95. qkv, _ = self.query_key_value(hidden_states)
  96. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  97. q, k = self.rotary_emb(position_ids, q, k)
  98. key_cache, value_cache = kv_cache
  99. context_layer = self.attn(
  100. q,
  101. k,
  102. v,
  103. key_cache,
  104. value_cache,
  105. input_metadata,
  106. )
  107. attn_output, _ = self.dense(context_layer)
  108. return attn_output
  109. class GLMMLP(nn.Module):
  110. """MLP.
  111. MLP will take the input with h hidden state, project it to 4*h
  112. hidden dimension, perform nonlinear transformation, and project the
  113. state back into h hidden dimension.
  114. """
  115. def __init__(
  116. self,
  117. config,
  118. linear_method: Optional[LinearMethodBase] = None,
  119. ):
  120. super().__init__()
  121. self.add_bias = config.add_bias_linear
  122. # Project to 4h.
  123. self.dense_h_to_4h = MergedColumnParallelLinear(
  124. config.hidden_size,
  125. [config.ffn_hidden_size] * 2,
  126. bias=config.add_bias_linear,
  127. linear_method=linear_method,
  128. )
  129. self.activation_func = SiluAndMul()
  130. # Project back to h.
  131. self.dense_4h_to_h = RowParallelLinear(
  132. config.ffn_hidden_size,
  133. config.hidden_size,
  134. bias=config.add_bias_linear,
  135. linear_method=linear_method,
  136. )
  137. def forward(self, hidden_states):
  138. # [s, b, 4hp]
  139. intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
  140. intermediate_parallel = self.activation_func(intermediate_parallel)
  141. # [s, b, h]
  142. output, _ = self.dense_4h_to_h(intermediate_parallel)
  143. return output
  144. class GLMBlock(nn.Module):
  145. """A single transformer layer.
  146. Transformer layer takes input with size [s, b, h] and returns an
  147. output of the same size.
  148. """
  149. def __init__(
  150. self,
  151. config,
  152. linear_method: Optional[LinearMethodBase] = None,
  153. ):
  154. super().__init__()
  155. self.apply_residual_connection_post_layernorm = (
  156. config.apply_residual_connection_post_layernorm)
  157. self.fp32_residual_connection = config.fp32_residual_connection
  158. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  159. # Layernorm on the input data.
  160. self.input_layernorm = layer_norm_func(config.hidden_size,
  161. eps=config.layernorm_epsilon)
  162. # Self attention.
  163. self.self_attention = GLMAttention(config, linear_method)
  164. self.hidden_dropout = config.hidden_dropout
  165. # Layernorm on the attention output
  166. self.post_attention_layernorm = layer_norm_func(
  167. config.hidden_size, eps=config.layernorm_epsilon)
  168. # MLP
  169. self.mlp = GLMMLP(config, linear_method)
  170. def forward(
  171. self,
  172. hidden_states: torch.Tensor,
  173. position_ids: torch.Tensor,
  174. kv_cache: KVCache,
  175. input_metadata: InputMetadata,
  176. ) -> torch.Tensor:
  177. # hidden_states: [num_tokens, h]
  178. # Layer norm at the beginning of the transformer layer.
  179. layernorm_output = self.input_layernorm(hidden_states)
  180. # Self attention.
  181. attention_output = self.self_attention(
  182. hidden_states=layernorm_output,
  183. position_ids=position_ids,
  184. kv_cache=kv_cache,
  185. input_metadata=input_metadata,
  186. )
  187. # Residual connection.
  188. if self.apply_residual_connection_post_layernorm:
  189. residual = layernorm_output
  190. else:
  191. residual = hidden_states
  192. layernorm_input = residual + attention_output
  193. # Layer norm post the self attention.
  194. layernorm_output = self.post_attention_layernorm(layernorm_input)
  195. # Second residual connection.
  196. if self.apply_residual_connection_post_layernorm:
  197. residual = layernorm_output
  198. else:
  199. residual = layernorm_input
  200. output = self.mlp(layernorm_output) + residual
  201. return output
  202. class GLMTransformer(nn.Module):
  203. """Transformer class."""
  204. def __init__(
  205. self,
  206. config,
  207. linear_method: Optional[LinearMethodBase] = None,
  208. ):
  209. super().__init__()
  210. self.post_layer_norm = config.post_layer_norm
  211. # Number of layers.
  212. self.num_layers = config.num_layers
  213. # Transformer layers.
  214. self.layers = nn.ModuleList(
  215. [GLMBlock(config, linear_method) for i in range(self.num_layers)])
  216. if self.post_layer_norm:
  217. layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
  218. # Final layer norm before output.
  219. self.final_layernorm = layer_norm_func(
  220. config.hidden_size, eps=config.layernorm_epsilon)
  221. def forward(
  222. self,
  223. hidden_states: torch.Tensor,
  224. position_ids: torch.Tensor,
  225. kv_caches: List[KVCache],
  226. input_metadata: InputMetadata,
  227. ) -> torch.Tensor:
  228. for i in range(self.num_layers):
  229. layer = self.layers[i]
  230. hidden_states = layer(
  231. hidden_states=hidden_states,
  232. position_ids=position_ids,
  233. kv_cache=kv_caches[i],
  234. input_metadata=input_metadata,
  235. )
  236. # Final layer norm.
  237. if self.post_layer_norm:
  238. hidden_states = self.final_layernorm(hidden_states)
  239. return hidden_states
  240. class ChatGLMModel(nn.Module):
  241. def __init__(
  242. self,
  243. config,
  244. linear_method: Optional[LinearMethodBase] = None,
  245. ):
  246. super().__init__()
  247. self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
  248. config.hidden_size,
  249. linear_method=linear_method)
  250. self.num_layers = config.num_layers
  251. self.multi_query_group_num = config.multi_query_group_num
  252. self.kv_channels = config.kv_channels
  253. self.encoder = GLMTransformer(config, linear_method)
  254. self.output_layer = ParallelLMHead(config.padded_vocab_size,
  255. config.hidden_size,
  256. linear_method=linear_method)
  257. def forward(
  258. self,
  259. input_ids: torch.Tensor,
  260. position_ids: torch.Tensor,
  261. kv_caches: List[KVCache],
  262. input_metadata: InputMetadata,
  263. ) -> torch.Tensor:
  264. inputs_embeds = self.embedding(input_ids)
  265. # Run encoder.
  266. hidden_states = self.encoder(
  267. hidden_states=inputs_embeds,
  268. position_ids=position_ids,
  269. kv_caches=kv_caches,
  270. input_metadata=input_metadata,
  271. )
  272. return hidden_states
  273. class ChatGLMForCausalLM(nn.Module):
  274. def __init__(
  275. self,
  276. config: ChatGLMConfig,
  277. linear_method: Optional[LinearMethodBase] = None,
  278. ):
  279. super().__init__()
  280. self.config: ChatGLMConfig = config
  281. self.linear_method = linear_method
  282. self.transformer = ChatGLMModel(config, linear_method)
  283. # self.lm_head_weight = self.transformer.output_layer.weight
  284. self.sampler = Sampler(config.padded_vocab_size)
  285. def forward(
  286. self,
  287. input_ids: torch.Tensor,
  288. positions: torch.Tensor,
  289. kv_caches: List[KVCache],
  290. input_metadata: InputMetadata,
  291. ) -> torch.Tensor:
  292. hidden_states = self.transformer(input_ids, positions, kv_caches,
  293. input_metadata)
  294. return hidden_states
  295. def sample(
  296. self,
  297. hidden_states: torch.Tensor,
  298. sampling_metadata: SamplingMetadata,
  299. ) -> Optional[SamplerOutput]:
  300. next_tokens = self.sampler(
  301. self.transformer.output_layer(hidden_states), sampling_metadata)
  302. return next_tokens
  303. def load_weights(self,
  304. model_name_or_path: str,
  305. cache_dir: Optional[str] = None,
  306. load_format: str = "auto",
  307. revision: Optional[str] = None):
  308. params_dict = dict(self.named_parameters(remove_duplicate=False))
  309. for name, loaded_weight in hf_model_weights_iterator(
  310. model_name_or_path, cache_dir, load_format, revision,
  311. self.config):
  312. if "rotary_pos_emb.inv_freq" in name:
  313. continue
  314. if "word_embeddings" in name:
  315. name = name.replace(".word_embeddings", "")
  316. # Skip loading extra bias for GPTQ models.
  317. if name.endswith(".bias") and name not in params_dict:
  318. continue
  319. param = params_dict[name]
  320. weight_loader = getattr(param, "weight_loader",
  321. default_weight_loader)
  322. weight_loader(param, loaded_weight)