# loader.py

"""Utilities for selecting and loading models."""
import contextlib
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import ModelConfig
from aphrodite.modeling.models import (LlamaForCausalLM, GPTJForCausalLM,
                                       GPTNeoXForCausalLM)
from aphrodite.modeling.hf_downloader import (initialize_dummy_weights,
                                              get_quant_config)
_MODEL_REGISTRY = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
}
# Model classes that support quantization. Held as classes (not name
# strings) so the membership tests against `model_class` below work.
_QUANT_REGISTRY = {
    LlamaForCausalLM,
}
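
# The two registries drive all dispatch in this module: _MODEL_REGISTRY maps
# the class names listed in a checkpoint's config.json "architectures" field
# to in-tree implementations, and _QUANT_REGISTRY marks which of those
# classes accept a quantization config. Supporting a new model means adding
# one entry to _MODEL_REGISTRY.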
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if model construction fails.
        torch.set_default_dtype(old_dtype)
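
# Usage sketch (illustrative): while the context is active, tensor factory
# functions and nn.Module parameters default to the requested dtype, so a
# model can be built directly in e.g. float16 without per-layer casts:
#
#     with _set_default_torch_dtype(torch.float16):
#         layer = nn.Linear(16, 16)
#     assert layer.weight.dtype == torch.float16
#     # The previous default (typically torch.float32) is restored on exit.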
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not currently supported. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig) -> nn.Module:
    model_class = _get_model_architecture(model_config.hf_config)

    # Resolve the quantization config, if quantization was requested.
    quant_config = None
    if model_config.quantization is not None:
        if model_class not in _QUANT_REGISTRY:
            raise ValueError(
                f"Quantization is not supported for {model_class}.")
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.download_dir)
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. "
                f"Supported dtypes: {supported_dtypes}")
    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        if model_class in _QUANT_REGISTRY:
            model = model_class(model_config.hf_config, quant_config)
        else:
            model = model_class(model_config.hf_config)
        if model_config.load_format == "dummy":
            model = model.cuda()
            # NOTE: For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.load_format)
            model = model.cuda()
    return model.eval()
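
# End-to-end sketch (illustrative; the ModelConfig constructor arguments
# shown here are an assumption, see aphrodite.common.config for the real
# signature):
#
#     model_config = ModelConfig(model="meta-llama/Llama-2-7b-hf",
#                                dtype="float16", load_format="auto", ...)
#     model = get_model(model_config)  # nn.Module on the GPU, in eval mode
#
# get_model only selects the architecture and materializes the weights;
# KV-cache allocation and scheduling are handled elsewhere in the engine.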