config.py

import contextlib
import os
from typing import Dict, Optional, Type

from loguru import logger
from transformers import GenerationConfig, PretrainedConfig

from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                                  JAISConfig, MedusaConfig,
                                                  MLPSpeculatorConfig,
                                                  MPTConfig, RWConfig)

APHRODITE_USE_MODELSCOPE = os.getenv("APHRODITE_USE_MODELSCOPE", "0") == "1"

if APHRODITE_USE_MODELSCOPE:
    from modelscope import AutoConfig
else:
    from transformers import AutoConfig
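
# The ModelScope toggle above is resolved once, at import time, so the
# environment variable must be set before this module is first imported.
# A hypothetical sketch of how a caller might opt in (the module path is
# an assumption based on the imports above):
#
#   import os
#   os.environ["APHRODITE_USE_MODELSCOPE"] = "1"  # set before importing
#   from aphrodite.transformers_utils import config  # uses modelscope.AutoConfig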

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "chatglm": ChatGLMConfig,
    "dbrx": DbrxConfig,
    "mpt": MPTConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
    "jais": JAISConfig,
    "mlp_speculator": MLPSpeculatorConfig,
    "medusa": MedusaConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)
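
# AutoConfig.register raises ValueError when a model type is already taken
# (e.g. a newer transformers release ships its own "mpt" config), so the
# contextlib.suppress above makes the registration idempotent. A minimal
# sketch of the same pattern in isolation (MyConfig is hypothetical):
#
#   class MyConfig(PretrainedConfig):
#       model_type = "my_model"
#
#   with contextlib.suppress(ValueError):
#       AutoConfig.register("my_model", MyConfig)  # no-op if already present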


def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None,
               code_revision: Optional[str] = None,
               rope_scaling: Optional[dict] = None,
               rope_theta: Optional[float] = None) -> PretrainedConfig:
    try:
        config = AutoConfig.from_pretrained(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
            code_revision=code_revision)
    except ValueError as e:
        if (not trust_remote_code and
                "requires you to execute the configuration file" in str(e)):
            err_msg = (
                "Failed to load the model config. If the model is a custom "
                "model not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model,
                                              revision=revision,
                                              code_revision=code_revision)
    for key, value in [("rope_scaling", rope_scaling),
                       ("rope_theta", rope_theta)]:
        if value is not None:
            logger.info(f"Updating {key} from "
                        f"{getattr(config, key, None)} to {value}")
            config.update({key: value})
    return config
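
# A hedged usage sketch for get_config; the model id and override value are
# illustrative, not taken from this file:
#
#   cfg = get_config("tiiuae/falcon-7b", trust_remote_code=False,
#                    rope_theta=1_000_000.0)
#   # If the checkpoint reports model_type "RefinedWebModel", cfg is an
#   # RWConfig, and rope_theta is overridden to 1e6 (logged via loguru).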


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multimodal models.
    No-op for pure text models.
    """
    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if the transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
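
# For a multimodal checkpoint such as LLaVA, the top-level config wraps a
# text_config describing the language model; pure text configs pass through
# unchanged. A sketch (model ids are illustrative):
#
#   llava_cfg = get_config("llava-hf/llava-1.5-7b-hf", trust_remote_code=False)
#   text_cfg = get_hf_text_config(llava_cfg)       # the inner LlamaConfig
#
#   opt_cfg = get_config("facebook/opt-125m", trust_remote_code=False)
#   assert get_hf_text_config(opt_cfg) is opt_cfg  # no-op for text-only models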


def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
) -> Optional[GenerationConfig]:
    try:
        return GenerationConfig.from_pretrained(
            model,
            revision=revision,
        )
    except OSError:  # Not found
        try:
            config = get_config(
                model,
                trust_remote_code=trust_remote_code,
                revision=revision,
            )
            return GenerationConfig.from_model_config(config)
        except OSError:  # Not found
            return None
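
# A hedged usage sketch: prefer the repo's generation_config.json, fall back
# to defaults derived from the model config, and return None if neither can
# be fetched (the model id is illustrative):
#
#   gen_cfg = try_get_generation_config("facebook/opt-125m",
#                                       trust_remote_code=False)
#   if gen_cfg is not None:
#       print(gen_cfg.max_length)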