medusa.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. from typing import Optional, Union
  3. from transformers import PretrainedConfig
  4. class MedusaConfig(PretrainedConfig):
  5. model_type = "medusa"
  6. def __init__(self,
  7. hidden_size: int = 4096,
  8. vocab_size: int = 32001,
  9. num_heads: int = 5,
  10. num_hidden_layers: int = 1,
  11. max_paths: int = 64,
  12. topk: int = 10,
  13. truncated_vocab_size: Optional[int] = None,
  14. **kwargs):
  15. self.hidden_size = hidden_size
  16. self.vocab_size = vocab_size
  17. self.num_heads = num_heads
  18. self.num_hidden_layers = num_hidden_layers
  19. self.max_paths = max_paths
  20. self.topk = topk
  21. self.max_seq_len = int(2**20)
  22. self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\
  23. else truncated_vocab_size
  24. if "architectures" not in kwargs:
  25. kwargs["architectures"] = ["MedusaModel"]
  26. super().__init__(**kwargs)
  27. @classmethod
  28. def from_pretrained(
  29. cls,
  30. pretrained_model_name_or_path: Union[str, os.PathLike],
  31. **kwargs,
  32. ) -> "MedusaConfig":
  33. config_dict, kwargs = cls.get_config_dict(
  34. pretrained_model_name_or_path, **kwargs)
  35. for k in list(config_dict.keys()):
  36. if 'num' in k:
  37. if 'heads' in k:
  38. config_dict["num_heads"] = config_dict.pop(k)
  39. elif 'layers' in k:
  40. config_dict["num_hidden_layers"] = config_dict.pop(k)
  41. return cls.from_dict(config_dict, **kwargs)
  42. @property
  43. def num_attention_heads(self):
  44. return 0
  45. @property
  46. def num_lookahead_tokens(self):
  47. return self.num_heads
  48. @num_lookahead_tokens.setter
  49. def num_lookahead_tokens(self, num_lookahead_tokens: int):
  50. self.num_heads = num_lookahead_tokens