config.py

from typing import Dict, Optional

from transformers import AutoConfig, PretrainedConfig

from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                                  JAISConfig, MPTConfig,
                                                  RWConfig)

_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
    "chatglm": ChatGLMConfig,
    "dbrx": DbrxConfig,
    "mpt": MPTConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
    "jais": JAISConfig,
}

def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None,
               code_revision: Optional[str] = None) -> PretrainedConfig:
    try:
        config = AutoConfig.from_pretrained(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
            code_revision=code_revision)
    except ValueError as e:
        if (not trust_remote_code and
                "requires you to execute the configuration file" in str(e)):
            err_msg = (
                "Failed to load the model config. If the model is a custom "
                "model not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e
    # Swap in the model-specific config class from _CONFIG_REGISTRY for
    # architectures not natively supported by transformers.
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model,
                                              revision=revision,
                                              code_revision=code_revision)
    return config

def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multi-modal models.

    No-op for pure text models.
    """
    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if the transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
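
if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module). "gpt2" is
    # used here only as a small, publicly available checkpoint that needs no
    # remote code; any HF repo id or local path with a config.json works the
    # same way.
    cfg = get_config("gpt2", trust_remote_code=False)
    text_cfg = get_hf_text_config(cfg)
    print(type(cfg).__name__, getattr(text_cfg, "num_attention_heads", None))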