# mlp_speculator.py (2.3 KB)
  1. from typing import List, Optional
  2. from transformers import PretrainedConfig
  3. class MLPSpeculatorConfig(PretrainedConfig):
  4. model_type = "mlp_speculator"
  5. attribute_map = {
  6. "hidden_size": "emb_dim",
  7. }
  8. def __init__(self,
  9. vocab_size: int = 32000,
  10. emb_dim: int = 4096,
  11. inner_dim: int = 0,
  12. n_predict: int = 3,
  13. top_k_tokens_per_head: Optional[List[int]] = None,
  14. n_candidates: int = 5,
  15. tie_weights: bool = False,
  16. scale_input: bool = False,
  17. **kwargs):
  18. """
  19. Initialize an MLPSpeculatorConfig
  20. Args:
  21. vocab_size: int
  22. the model vocab size
  23. emb_dim: int
  24. the model embedding dimension
  25. inner_dim: int
  26. the inner dimension of the model. If 0, will be the emb_dim.
  27. n_predict: int
  28. the number of lookaheads for the speculator
  29. top_k_tokens_per_head: List[int]
  30. Number of tokens to consider from each head when forming the
  31. candidate tree.
  32. For each candidate branch in the tree, head n produces topk[n]
  33. additional sub-branches.
  34. NOTE: This parameter is currently unused.
  35. n_candidates: int
  36. number of child candidates to create per sequence
  37. tie_weights: bool
  38. If true, use a single set of weights for every model
  39. head/stage after the first. The initial projection
  40. from the base model may have a different size, so that
  41. stays separate.
  42. scale_input: bool
  43. if True, will scale the initial hidden states from
  44. the base model.
  45. """
  46. if top_k_tokens_per_head is None:
  47. top_k_tokens_per_head = [5, 4, 3]
  48. assert len(top_k_tokens_per_head) == n_predict
  49. self.vocab_size = vocab_size
  50. self.emb_dim = emb_dim
  51. self.inner_dim = inner_dim
  52. self.n_predict = n_predict
  53. self.top_k_tokens_per_head = top_k_tokens_per_head
  54. self.n_candidates = n_candidates
  55. self.num_lookahead_tokens = n_predict
  56. self.tie_weights = tie_weights
  57. self.scale_input = scale_input
  58. super().__init__(**kwargs)