jamba.py
# ruff: noqa: E501
"""Jamba model configuration"""
import math

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig


class JambaConfig(PretrainedConfig):
  7. r"""
  8. Args:
  9. vocab_size (`int`, *optional*, defaults to 65536):
  10. Vocabulary size of the Jurassic model. Defines the number of different tokens that can be represented by the
  11. `inputs_ids` passed when calling [`JurassicModel`]
  12. hidden_size (`int`, *optional*, defaults to 4096):
  13. Dimension of the hidden representations.
  14. intermediate_size (`int`, *optional*, defaults to 14336):
  15. Dimension of the MLP representations.
  16. num_hidden_layers (`int`, *optional*, defaults to 32):
  17. Number of hidden layers in the Transformer encoder.
  18. num_attention_heads (`int`, *optional*, defaults to 32):
  19. Number of attention heads for each attention layer in the Transformer encoder.
  20. num_key_value_heads (`int`, *optional*, defaults to 8):
  21. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  22. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  23. `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  24. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  25. by meanpooling all the original heads within that group. For more details checkout [this
  26. paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
  27. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  28. The non-linear activation function (function or string) in the decoder.
  29. # max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
  30. # The maximum sequence length that this model might ever be used with. Jurassic's sliding window attention
  31. # allows sequence of up to 4096*32 tokens.
  32. initializer_range (`float`, *optional*, defaults to 0.02):
  33. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  34. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  35. The epsilon used by the rms normalization layers.
  36. use_cache (`bool`, *optional*, defaults to `True`):
  37. Whether or not the model should return the last key/values attentions (not used by all models). Only
  38. relevant if `config.is_decoder=True`.
  39. pad_token_id (`int`, *optional*):
  40. The id of the padding token.
  41. bos_token_id (`int`, *optional*, defaults to 1):
  42. The id of the "beginning-of-sequence" token.
  43. eos_token_id (`int`, *optional*, defaults to 2):
  44. The id of the "end-of-sequence" token.
  45. use_positional_embeddings (`bool`, *optional, default False)
  46. flag indicating whether to use positional embeddings or not
  47. rope_theta (`float`, *optional*, defaults to 1000000.0):
  48. The base period of the RoPE embeddings.
  49. sliding_window (`int`, *optional*):
  50. Sliding window attention window size. If not specified, will default to `4096`.
  51. num_experts_per_tok (`int`, *optional*, defaults to 2):
  52. The number of experts to root per-token, can be also interpreted as the `top-p` routing
  53. parameter
  54. num_experts (`int`, *optional*, defaults to 16):
  55. Number of experts per Sparse MLP layer.
  56. expert_layer_period (`int`, *optional*, defaults to 2)
  57. Once in this many layers, we will have an expert layer
  58. expert_layer_offset(`int`, *optional*, defaults to 1)
  59. The first layer index that contains an expert mlp layer
  60. attn_layer_period (`int`, *optional*, defaults to 8)
  61. Once in this many layers, we will have a vanilla attention layer
  62. attn_layer_offset(`int`, *optional*, defaults to 4)
  63. The first layer index that contains a vanilla attention mlp layer
  64. """

    model_type = "jamba"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=65536,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_experts=16,
        expert_layer_offset=1,
        expert_layer_period=2,
        attn_layer_period=8,
        attn_layer_offset=4,
        use_mamba_kernels=True,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_dt_rank="auto",
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        mamba_inner_layernorms=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.attention_dropout = attention_dropout

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.expert_layer_period = expert_layer_period
        self.expert_layer_offset = expert_layer_offset
        self.attn_layer_period = attn_layer_period
        self.attn_layer_offset = attn_layer_offset

        self.use_mamba_kernels = use_mamba_kernels
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        # "auto" resolves the SSM dt rank to ceil(hidden_size / 16)
        self.mamba_dt_rank = math.ceil(
            self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
        self.mamba_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias
        self.mamba_inner_layernorms = mamba_inner_layernorms

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
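

# Illustrative sketch only (not part of the original file, and not the model's
# actual layer-construction code): one common reading of the *_layer_period /
# *_layer_offset fields documented above is that layer `i` uses attention when
# `i % attn_layer_period == attn_layer_offset` (Mamba otherwise) and an expert
# (MoE) MLP when `i % expert_layer_period == expert_layer_offset`. That mapping
# is an assumption inferred from the docstring, so treat this helper as a demo.
def _sketch_layer_pattern(config: JambaConfig):
    """Return a ("attention"|"mamba", "moe"|"mlp") tag per layer, for illustration only."""
    pattern = []
    for i in range(config.num_hidden_layers):
        mixer = "attention" if i % config.attn_layer_period == config.attn_layer_offset else "mamba"
        ffn = "moe" if i % config.expert_layer_period == config.expert_layer_offset else "mlp"
        pattern.append((mixer, ffn))
    return pattern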


AutoConfig.register("jamba", JambaConfig)
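

if __name__ == "__main__":
    # Minimal usage sketch, assuming only that `transformers` is installed; this
    # builds the configuration object, it does not load any model weights.
    config = JambaConfig()

    # With the defaults, mamba_dt_rank="auto" resolves to ceil(4096 / 16) == 256.
    print("mamba_dt_rank:", config.mamba_dt_rank)

    # Because of the AutoConfig.register call above (assuming the installed
    # transformers version does not already ship its own "jamba" config), the
    # "jamba" model type now resolves to this JambaConfig class.
    auto_config = AutoConfig.for_model("jamba")
    print("resolved config class:", type(auto_config).__name__)

    # Illustrative layer layout from the hypothetical _sketch_layer_pattern helper
    # defined above: with the defaults, attention would appear at layers 4, 12, 20,
    # 28 and every odd layer would use an expert (MoE) MLP.
    print(_sketch_layer_pattern(config)[:6])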