1
0

loader.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. """Utilities for selecting and loading models."""
  2. import contextlib
  3. import gc
  4. from contextlib import nullcontext
  5. from typing import Type
  6. from loguru import logger
  7. import torch
  8. import torch.nn as nn
  9. from aphrodite.common.config import DeviceConfig, ModelConfig
  10. from aphrodite.modeling.models import ModelRegistry
  11. from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
  12. from aphrodite.modeling.hf_downloader import (
  13. get_quant_config,
  14. initialize_dummy_weights,
  15. post_init_exl2,
  16. )
  17. from aphrodite.modeling.layers.quantization.bitsandbytes import (
  18. BNBLinearMethod,
  19. replace_quant_params,
  20. )
  21. from aphrodite.distributed import (
  22. get_tensor_model_parallel_world_size, )
  23. _VISION_MODEL_CLASSES = [
  24. LlavaForConditionalGeneration,
  25. ]
  26. @contextlib.contextmanager
  27. def _set_default_torch_dtype(dtype: torch.dtype):
  28. """Sets the default torch dtype to the given dtype."""
  29. old_dtype = torch.get_default_dtype()
  30. torch.set_default_dtype(dtype)
  31. yield
  32. torch.set_default_dtype(old_dtype)
  33. def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
  34. architectures = getattr(model_config.hf_config, "architectures", [])
  35. for arch in architectures:
  36. model_cls = ModelRegistry.load_model_cls(arch)
  37. if model_cls is not None:
  38. return model_cls
  39. raise ValueError(
  40. f"Model architectures {architectures} are not supported for now. "
  41. f"Supported architectures: {ModelRegistry.get_supported_archs()}")
  42. def get_model(model_config: ModelConfig, device_config: DeviceConfig,
  43. **kwargs) -> nn.Module:
  44. lora_config = kwargs.get("lora_config", None)
  45. vision_language_config = kwargs.get("vision_language_config", None)
  46. model_class = _get_model_architecture(model_config)
  47. # Get the (maybe quantized) linear method.
  48. linear_method = None
  49. if model_config.quantization is not None:
  50. quant_config = get_quant_config(model_config)
  51. capability = torch.cuda.get_device_capability()
  52. capability = capability[0] * 10 + capability[1]
  53. if capability < quant_config.get_min_capability():
  54. raise ValueError(
  55. f"The quantization method {model_config.quantization} is not "
  56. "supported for the current GPU. "
  57. f"Minimum capability: {quant_config.get_min_capability()}. "
  58. f"Current capability: {capability}.")
  59. supported_dtypes = quant_config.get_supported_act_dtypes()
  60. if model_config.dtype not in supported_dtypes:
  61. # set the dtype to float16 for quantized models
  62. model_config.dtype = torch.float16
  63. logger.warning("Model is quantized. Forcing float16 datatype.")
  64. linear_method = quant_config.get_linear_method()
  65. with _set_default_torch_dtype(model_config.dtype):
  66. # Create a model instance.
  67. # The weights will be initialized as empty tensors.
  68. with torch.device(device_config.device) if not (
  69. isinstance(linear_method, BNBLinearMethod)
  70. and linear_method.quant_config.from_float) else nullcontext():
  71. if hasattr(model_class, "supported_lora_modules"):
  72. model = model_class(model_config.hf_config, linear_method,
  73. lora_config)
  74. elif lora_config:
  75. raise ValueError(
  76. f"Model {model_class.__name__} does not support LoRA, "
  77. "but LoRA is enabled. Support for this model may "
  78. "be added in the future. If this is important to you, "
  79. "please open an issue on github.")
  80. else:
  81. if model_class not in _VISION_MODEL_CLASSES:
  82. model = model_class(model_config.hf_config, linear_method)
  83. else:
  84. model = model_class(model_config.hf_config,
  85. vision_language_config, linear_method)
  86. if model_config.load_format == "dummy":
  87. # NOTE: For accurate performance evaluation, we assign
  88. # random values to the weights.
  89. initialize_dummy_weights(model)
  90. else:
  91. # Load the weights from the cached or downloaded files.
  92. model.load_weights(model_config.model, model_config.download_dir,
  93. model_config.load_format, model_config.revision)
  94. # Patch for exl2 tensor parallel
  95. if model_config.quantization == "exl2":
  96. for _, module in model.named_modules():
  97. if "RowParallelLinear" in str(module.__class__):
  98. post_init_exl2(module)
  99. if isinstance(linear_method, BNBLinearMethod):
  100. replace_quant_params(
  101. model,
  102. quant_config=linear_method.quant_config,
  103. modules_to_not_convert="lm_head",
  104. )
  105. torch.cuda.synchronize()
  106. if linear_method.quant_config.from_float:
  107. model = model.cuda()
  108. gc.collect()
  109. torch.cuda.empty_cache()
  110. tp = get_tensor_model_parallel_world_size()
  111. logger.info(
  112. "Memory allocated for converted model: {} GiB x {} = {} "
  113. "GiB".format(
  114. round(
  115. torch.cuda.memory_allocated(
  116. torch.cuda.current_device()) /
  117. (1024 * 1024 * 1024),
  118. 2,
  119. ),
  120. tp,
  121. round(
  122. torch.cuda.memory_allocated(
  123. torch.cuda.current_device()) * tp /
  124. (1024 * 1024 * 1024),
  125. 2,
  126. ),
  127. ))
  128. logger.info(
  129. "Memory reserved for converted model: {} GiB x {} = {} "
  130. "GiB".format(
  131. round(
  132. torch.cuda.memory_reserved(torch.cuda.current_device())
  133. / (1024 * 1024 * 1024),
  134. 2,
  135. ),
  136. tp,
  137. round(
  138. torch.cuda.memory_reserved(torch.cuda.current_device())
  139. * tp / (1024 * 1024 * 1024),
  140. 2,
  141. ),
  142. ))
  143. return model.eval()