# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    """Configuration class for ChatGLM models (e.g. ChatGLM2-6B)."""

    model_type = "chatglm"
    # Redirect reads/writes of the generic attribute names on the left to the
    # ChatGLM-specific attributes on the right.
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 interleaved_qkv=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # Mirror seq_length so the config stays compatible with long-context
        # LoRA handling.
        self.max_position_embeddings = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
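

# Illustration only, not part of the original module: a minimal sketch of how
# the `attribute_map` defined above behaves. `PretrainedConfig` redirects reads
# and writes of the mapped names ("num_hidden_layers", "n_head_kv") to the
# ChatGLM-specific attributes, so generic code can use the common names. The
# concrete values below are arbitrary example values.
if __name__ == "__main__":
    cfg = ChatGLMConfig(num_layers=2, multi_query_group_num=2)
    assert cfg.num_hidden_layers == 2  # redirected to cfg.num_layers
    assert cfg.n_head_kv == 2  # redirected to cfg.multi_query_group_num
    cfg.num_hidden_layers = 4  # writes through to cfg.num_layers
    assert cfg.num_layers == 4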