loader.py

"""Utilities for selecting and loading models."""
import contextlib
from typing import Optional, Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import ModelConfig, LoRAConfig
from aphrodite.modeling.models import ModelRegistry
from aphrodite.modeling.hf_downloader import (get_quant_config,
                                              initialize_dummy_weights)


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
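    """Returns the model class for the first architecture listed in the
    HF config that is registered in the ModelRegistry."""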
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def get_model(model_config: ModelConfig,
              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
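    """Creates a model instance on the GPU and loads its weights.

    Resolves the model class from the HF config, checks that any requested
    quantization method is supported by the current device and dtype, then
    either loads the real weights or initializes dummy ones.
    """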
    model_class = _get_model_architecture(model_config.hf_config)

    # Get the (maybe quantized) linear method.
    linear_method = None
    if model_config.quantization is not None:
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.hf_config,
                                        model_config.download_dir)
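        # Encode the device's (major, minor) compute capability as a
        # two-digit number, e.g. (8, 6) -> 86, so it can be compared
        # against the method's minimum required capability.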
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
        linear_method = quant_config.get_linear_method()

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        with torch.device("cuda"):
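            # Parameters created inside this block are allocated on the
            # GPU by default (torch.device works as a context manager in
            # PyTorch 2.0+).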
            if getattr(model_class, "supports_lora", False):
                model = model_class(model_config.hf_config, linear_method,
                                    lora_config)
            elif lora_config:
                raise ValueError(
                    f"Model {model_class.__name__} does not support LoRA, "
                    "but LoRA is enabled. Support for this model may "
                    "be added in the future. If this is important to you, "
                    "please open an issue on github.")
            else:
                model = model_class(model_config.hf_config, linear_method)
        if model_config.load_format == "dummy":
            # NOTE: For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format, model_config.revision)
    return model.eval()
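

# Usage sketch: a minimal illustration of how `get_model` is typically
# driven, assuming a ModelConfig built elsewhere in aphrodite (its exact
# constructor arguments are an assumption here, not defined in this file):
#
#     from aphrodite.common.config import ModelConfig
#
#     model_config = ModelConfig(...)   # model path, dtype, load_format, ...
#     model = get_model(model_config)   # nn.Module on CUDA, in eval() mode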