# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
"""Dbrx configuration."""

from typing import Any, Optional

from loguru import logger
from transformers.configuration_utils import PretrainedConfig

DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore


class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.

    This class stores the configuration of a [`DbrxAttention`] class. It is
    used to instantiate attention layers according to the specified arguments,
    defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (`int`, *optional*, defaults to 1):
            For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base frequency for rope.
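
    Example:
        A minimal usage sketch added for illustration (it is not part of the
        upstream Databricks file); the argument values are arbitrary and only
        the class defined here is assumed:

    ```python
    >>> attn_config = DbrxAttentionConfig(clip_qkv=8.0, kv_n_heads=8)
    >>> attn_config.rope_theta
    10000.0
    ```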
  24. """

    def __init__(
        self,
        attn_pdrop: float = 0.0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.

    This class stores the configuration of a [`DbrxFFN`] class. It is used to
    instantiate feedforward layers according to the specified arguments,
    defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (dict, optional): A dict specifying the activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
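
    Example:
        A minimal usage sketch added for illustration (it is not part of the
        upstream Databricks file); the argument values are arbitrary and only
        the class defined here is assumed:

    ```python
    >>> ffn_config = DbrxFFNConfig(moe_num_experts=16, moe_top_k=4)
    >>> ffn_config.ffn_act_fn
    {'name': 'silu'}
    ```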
  81. """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class DbrxConfig(PretrainedConfig):
    """Configuration class for Dbrx.

    This is the configuration class to store the configuration of a
    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        d_model (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 2048):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `input_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.05):
            The aux loss factor for the total loss.

    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models."
            )

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
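

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the upstream Databricks file.
    # It illustrates how dict-valued `attn_config`/`ffn_config` arguments are
    # coerced into the typed sub-config classes above; the specific values are
    # arbitrary and for illustration only.
    config = DbrxConfig(
        attn_config={"clip_qkv": 8.0, "kv_n_heads": 8},
        ffn_config={"moe_num_experts": 16, "moe_top_k": 4},
    )
    print(config.attn_config.kv_n_heads)  # 8
    print(config.ffn_config.moe_top_k)    # 4
    print(config.hidden_size)             # 2048, resolved via attribute_map -> d_model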