config.py

import os
from typing import Dict, Optional, Type

from loguru import logger
from transformers import PretrainedConfig

from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                                  JAISConfig, MPTConfig,
                                                  RWConfig)

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "chatglm": ChatGLMConfig,
    "dbrx": DbrxConfig,
    "mpt": MPTConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
    "jais": JAISConfig,
}

def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None,
               code_revision: Optional[str] = None,
               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
    try:
        # Load from ModelScope instead of the HuggingFace Hub when the
        # APHRODITE_USE_MODELSCOPE environment variable is set.
        if os.getenv("APHRODITE_USE_MODELSCOPE", "0") == "1":
            from modelscope import AutoConfig
        else:
            from transformers import AutoConfig
        config = AutoConfig.from_pretrained(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
            code_revision=code_revision)
    except ValueError as e:
        if (not trust_remote_code and
                "requires you to execute the configuration file" in str(e)):
            err_msg = (
                "Failed to load the model config. If the model is a custom "
                "model not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e
    # Some architectures ship their own config classes; prefer those over
    # the generic AutoConfig result.
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model,
                                              revision=revision,
                                              code_revision=code_revision)
    # Optionally override the model's RoPE scaling settings.
    if rope_scaling is not None:
        logger.info("Updating rope_scaling from "
                    f"{getattr(config, 'rope_scaling', None)} to "
                    f"{rope_scaling}")
        config.update({"rope_scaling": rope_scaling})
    return config
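
# Illustrative usage (not part of the original module): a rope_scaling
# override passed to get_config is written onto the loaded config, e.g.
#
#     cfg = get_config("meta-llama/Llama-2-7b-hf",
#                      trust_remote_code=False,
#                      rope_scaling={"type": "linear", "factor": 2.0})
#
# The model name and scaling dict above are arbitrary examples, not values
# taken from this codebase.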


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multi-modal models.
    No-op for pure text models.
    """
    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if the transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
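

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how the two helpers compose. The model name below is
# an arbitrary choice; loading it requires network access or a local
# HuggingFace cache.
if __name__ == "__main__":
    config = get_config("facebook/opt-125m", trust_remote_code=False)
    # For a text-only model this is a no-op; for a multi-modal config that
    # exposes `text_config` (e.g. LLaVA-style models), the text sub-config
    # is returned instead.
    text_config = get_hf_text_config(config)
    print(config.model_type, text_config.num_attention_heads)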