# loader.py
import contextlib
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import ModelConfig
from aphrodite.modeling.models import (LlamaForCausalLM, GPTJForCausalLM,
                                       GPTNeoXForCausalLM, MistralForCausalLM)
from aphrodite.modeling.hf_downloader import (initialize_dummy_weights,
                                              get_quant_config)
from aphrodite.modeling.layers.quantized_linear.utils import quant_post_init
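
# Maps the architecture names declared in a Hugging Face config's
# `architectures` field to the corresponding Aphrodite implementations.
# "LLaMAForCausalLM" is a legacy spelling that resolves to the same class.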
_MODEL_REGISTRY = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "MistralForCausalLM": MistralForCausalLM,
}
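
# Only these architectures have quantized-linear support for each method;
# get_model() checks membership here before accepting a quantized config.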
_MODEL_CLASSES_SUPPORT_QUANTIZATION = {
    "awq": [LlamaForCausalLM, MistralForCausalLM],
    "gptq": [
        LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM,
        MistralForCausalLM
    ],
}
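

# Changing torch's default dtype here means each submodule's parameters are
# created directly in the target dtype (e.g. float16) rather than being
# materialized in float32 and converted afterwards.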
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)
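

# The HF config's `architectures` list is scanned in order; the first entry
# with a registered implementation wins.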
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig, max_tokens: int) -> nn.Module:
    model_class = _get_model_architecture(model_config.hf_config)

    # Get the quantization config.
    quant_config = None
    if model_config.quantization is not None:
        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION[
                model_config.quantization]:
            raise ValueError(
                f"Quantization is not supported for {model_class}.")
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.hf_config,
                                        model_config.download_dir)
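        # Encode the CUDA compute capability as a two-digit integer,
        # e.g. (8, 6) -> 86, for comparison against the method's minimum.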
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")
    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        if model_config.quantization is not None and (
                model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION[
                    model_config.quantization]):
            model = model_class(model_config.hf_config, quant_config)
        else:
            model = model_class(model_config.hf_config)
        if model_config.load_format == "dummy":
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format, model_config.revision)
            model = model.cuda()
    if model_config.quantization is not None:
        quant_post_init(model, max_tokens)
    return model.eval()
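

# Usage sketch (illustrative; not part of loader.py). A minimal example of
# how a caller might drive get_model(). The ModelConfig keyword arguments
# below are assumptions about aphrodite.common.config.ModelConfig and may
# not match its real constructor signature.
#
#     config = ModelConfig(
#         model="mistralai/Mistral-7B-v0.1",  # HF repo id or local path
#         quantization=None,                  # or "awq" / "gptq"
#         dtype=torch.float16,
#         load_format="auto",
#         download_dir=None,
#         revision=None,
#     )
#     model = get_model(config, max_tokens=4096)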