eagle.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import os
  2. from typing import Optional, Union
  3. from transformers import AutoConfig, PretrainedConfig
  4. class EAGLEConfig(PretrainedConfig):
  5. model_type = "eagle"
  6. def __init__(
  7. self,
  8. model: Union[PretrainedConfig, dict, None] = None,
  9. truncated_vocab_size: Optional[int] = None,
  10. **kwargs,
  11. ):
  12. model_config = (
  13. None
  14. if model is None
  15. else (
  16. AutoConfig.for_model(**model)
  17. if isinstance(model, dict)
  18. else model
  19. )
  20. )
  21. for k, v in kwargs.items():
  22. if (
  23. k != "architectures"
  24. and k != "model_type"
  25. and hasattr(model_config, k)
  26. ):
  27. setattr(model_config, k, v)
  28. self.model = model_config
  29. if self.model is None:
  30. self.truncated_vocab_size = None
  31. else:
  32. self.truncated_vocab_size = (
  33. self.model.vocab_size
  34. if truncated_vocab_size is None
  35. else truncated_vocab_size
  36. )
  37. if "architectures" not in kwargs:
  38. kwargs["architectures"] = ["EAGLEModel"]
  39. super().__init__(**kwargs)
  40. if self.model is not None:
  41. for k, v in self.model.to_dict().items():
  42. if not hasattr(self, k):
  43. setattr(self, k, v)
  44. @classmethod
  45. def from_pretrained(
  46. cls,
  47. pretrained_model_name_or_path: Union[str, os.PathLike],
  48. **kwargs,
  49. ) -> "EAGLEConfig":
  50. config_dict, kwargs = cls.get_config_dict(
  51. pretrained_model_name_or_path, **kwargs
  52. )
  53. return cls.from_dict(config_dict, **kwargs)