1
0

config.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import contextlib
  2. import os
  3. from pathlib import Path
  4. from typing import Dict, Optional, Type, Union
  5. from loguru import logger
  6. from transformers import GenerationConfig, PretrainedConfig
  7. from transformers.models.auto.modeling_auto import (
  8. MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
  9. from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
  10. InternVLChatConfig,
  11. JAISConfig, MedusaConfig,
  12. MLPSpeculatorConfig,
  13. MPTConfig, RWConfig)
  14. from aphrodite.transformers_utils.utils import check_gguf_file
  15. APHRODITE_USE_MODELSCOPE = os.getenv("APHRODITE_USE_MODELSCOPE", "0") == "1"
  16. if APHRODITE_USE_MODELSCOPE:
  17. from modelscope import AutoConfig
  18. else:
  19. from transformers import AutoConfig
  20. _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
  21. "chatglm": ChatGLMConfig,
  22. "dbrx": DbrxConfig,
  23. "mpt": MPTConfig,
  24. "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
  25. "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
  26. "jais": JAISConfig,
  27. "mlp_speculator": MLPSpeculatorConfig,
  28. "medusa": MedusaConfig,
  29. "internvl_chat": InternVLChatConfig,
  30. }
  31. for name, cls in _CONFIG_REGISTRY.items():
  32. with contextlib.suppress(ValueError):
  33. AutoConfig.register(name, cls)
  34. def get_config(
  35. model: Union[str, Path],
  36. trust_remote_code: bool,
  37. revision: Optional[str] = None,
  38. code_revision: Optional[str] = None,
  39. rope_scaling: Optional[dict] = None,
  40. rope_theta: Optional[float] = None,
  41. **kwargs,
  42. ) -> PretrainedConfig:
  43. # Separate model folder from file path for GGUF models
  44. is_gguf = check_gguf_file(model)
  45. if is_gguf:
  46. kwargs["gguf_file"] = Path(model).name
  47. model = Path(model).parent
  48. try:
  49. config = AutoConfig.from_pretrained(
  50. model,
  51. trust_remote_code=trust_remote_code,
  52. revision=revision,
  53. code_revision=code_revision,
  54. **kwargs)
  55. except ValueError as e:
  56. if (not trust_remote_code and
  57. "requires you to execute the configuration file" in str(e)):
  58. err_msg = (
  59. "Failed to load the model config. If the model is a custom "
  60. "model not yet available in the HuggingFace transformers "
  61. "library, consider setting `trust_remote_code=True` in LLM "
  62. "or using the `--trust-remote-code` flag in the CLI.")
  63. raise RuntimeError(err_msg) from e
  64. else:
  65. raise e
  66. if config.model_type in _CONFIG_REGISTRY:
  67. config_class = _CONFIG_REGISTRY[config.model_type]
  68. config = config_class.from_pretrained(model,
  69. revision=revision,
  70. code_revision=code_revision)
  71. # Special architecture mapping check for GGUF models
  72. if is_gguf:
  73. if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
  74. raise RuntimeError(
  75. f"Can't get gguf config for {config.model_type}.")
  76. model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
  77. config.update({"architectures": [model_type]})
  78. for key, value in [("rope_scaling", rope_scaling),
  79. ("rope_theta", rope_theta)]:
  80. if value is not None:
  81. logger.info(f"Updating {key} from "
  82. f"{getattr(config, key, None)} to {value}")
  83. config.update({key: value})
  84. return config
  85. def get_hf_text_config(config: PretrainedConfig):
  86. """Get the "sub" config relevant to llm for multi modal models.
  87. No op for pure text models.
  88. """
  89. if hasattr(config, "text_config"):
  90. # The code operates under the assumption that text_config should have
  91. # `num_attention_heads` (among others). Assert here to fail early
  92. # if transformers config doesn't align with this assumption.
  93. assert hasattr(config.text_config, "num_attention_heads")
  94. return config.text_config
  95. else:
  96. return config
  97. def try_get_generation_config(
  98. model: str,
  99. trust_remote_code: bool,
  100. revision: Optional[str] = None,
  101. ) -> Optional[GenerationConfig]:
  102. try:
  103. return GenerationConfig.from_pretrained(
  104. model,
  105. revision=revision,
  106. )
  107. except OSError: # Not found
  108. try:
  109. config = get_config(
  110. model,
  111. trust_remote_code=trust_remote_code,
  112. revision=revision,
  113. )
  114. return GenerationConfig.from_model_config(config)
  115. except OSError: # Not found
  116. return None