chatglm.py

# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"
    # Map standard Hugging Face attribute names to the ChatGLM-specific fields.
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 interleaved_qkv=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # Mirror seq_length for compatibility with long LoRA.
        self.max_position_embeddings = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
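

# A minimal usage sketch (illustrative, not part of the original file): it shows
# how `attribute_map` lets standard Hugging Face attribute names resolve to the
# ChatGLM-specific fields set in __init__ above. The argument values here are
# arbitrary examples.
if __name__ == "__main__":
    config = ChatGLMConfig(num_layers=2,
                           multi_query_attention=True,
                           multi_query_group_num=2)
    # "num_hidden_layers" is mapped to "num_layers" via attribute_map.
    assert config.num_hidden_layers == config.num_layers == 2
    # "n_head_kv" is mapped to "multi_query_group_num".
    assert config.n_head_kv == config.multi_query_group_num == 2
    # seq_length is mirrored into max_position_embeddings in __init__.
    assert config.max_position_embeddings == config.seq_length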