mlp_speculator.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. from typing import List, Optional
  2. from transformers import PretrainedConfig
  3. class MLPSpeculatorConfig(PretrainedConfig):
  4. model_type = "mlp_speculator"
  5. attribute_map = {
  6. "hidden_size": "emb_dim",
  7. }
  8. def __init__(self,
  9. vocab_size: int = 32000,
  10. emb_dim: int = 4096,
  11. inner_dim: int = 0,
  12. n_predict: int = 3,
  13. top_k_tokens_per_head: Optional[List[int]] = None,
  14. n_candidates: int = 5,
  15. **kwargs):
  16. """
  17. Initialize an MLPSpeculatorConfig
  18. Args:
  19. vocab_size: int
  20. the model vocab size
  21. emb_dim: int
  22. the model embedding dimension
  23. inner_dim: int
  24. the inner dimension of the model. If 0, will be the emb_dim.
  25. n_predict: int
  26. the number of lookaheads for the speculator
  27. top_k_tokens_per_head: List[int]
  28. Number of tokens to consider from each head when forming the
  29. candidate tree.
  30. For each candidate branch in the tree, head n produces topk[n]
  31. additional sub-branches.
  32. NOTE: This parameter is currently unused.
  33. n_candidates: int
  34. number of child candidates to create per sequence
  35. """
  36. if top_k_tokens_per_head is None:
  37. top_k_tokens_per_head = [5, 4, 3]
  38. assert len(top_k_tokens_per_head) == n_predict
  39. self.vocab_size = vocab_size
  40. self.emb_dim = emb_dim
  41. self.inner_dim = inner_dim
  42. self.n_predict = n_predict
  43. self.top_k_tokens_per_head = top_k_tokens_per_head
  44. self.n_candidates = n_candidates
  45. self.num_lookahead_tokens = n_predict
  46. super().__init__(**kwargs)