grok.py 2.3 KB

from transformers import PretrainedConfig


class GrokConfig(PretrainedConfig):
    """Configuration for a Grok-style mixture-of-experts transformer.

    Defaults follow the values hard-coded below: 64 layers, 48 attention heads
    with 8 key/value heads (grouped-query attention), 8 local experts with 2
    routed per token, a 131,072-token vocabulary, and an 8,192-token context.
    """

    model_type = "grok"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        intermediate_size=32768,
        num_hidden_layers=64,
        num_attention_heads=48,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=1e5,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=8,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        output_multiplier_scale=0.5773502691896257,
        embedding_multiplier_scale=78.38367176906169,
        attn_output_multiplier=0.08838834764831845,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # For backward compatibility: fall back to standard multi-head
        # attention when no separate key/value head count is given.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        # Mixture-of-experts routing parameters.
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        # Fixed scaling factors applied to the output logits, the token
        # embeddings, and the attention output.
        self.output_multiplier_scale = output_multiplier_scale
        self.embedding_multiplier_scale = embedding_multiplier_scale
        self.attn_output_multiplier = attn_output_multiplier

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
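

# Usage sketch: building the default configuration and a smaller variant with
# illustrative override values; assumes the `transformers` library is installed.
if __name__ == "__main__":
    config = GrokConfig()
    print(config.model_type)            # "grok"
    print(config.num_local_experts)     # 8 experts, 2 routed per token
    print(config.to_json_string())      # serialize via PretrainedConfig

    # Any default can be overridden; extra keyword arguments are forwarded to
    # PretrainedConfig by super().__init__.
    small = GrokConfig(hidden_size=1024, num_hidden_layers=4, num_attention_heads=8)
    print(small.hidden_size, small.num_hidden_layers)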