- """Utilities for selecting and loading neuron models."""
- import importlib
- import os
- from typing import Dict, List, Optional, Tuple
- import torch
- import torch.nn as nn
- import transformers
- from transformers import PretrainedConfig
- from aphrodite.common.config import (ModelConfig, ParallelConfig,
- SchedulerConfig)
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.quantization import get_quantization_config

TORCH_DTYPE_TO_NEURON_AMP = {
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}

# Models supported by Neuron. Each entry maps a Hugging Face architecture
# name to (transformers_neuronx module path, transformers_neuronx model
# class name, Hugging Face model class name).
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                         "LlamaForSampling", "LlamaForCausalLM"),
    "MistralForCausalLM": ("transformers_neuronx.mistral.model",
                           "MistralForSampling", "MistralForCausalLM"),
}


class NeuronCasualLM(nn.Module):
    """Wraps a transformers_neuronx model with Aphrodite's logits
    processing and sampling layers."""

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

        # Lazily initialized by load_weights().
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        logits = self.model(input_ids,
                            cache_ids=positions,
                            start_ids=input_block_ids)
        return logits

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(None, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, model_name_or_path: str, **kwargs):
        arch = _get_model_architecture(self.config)
        neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

        split_model_dir = f"{model_name_or_path}-split"
        if _is_pretrained_neuron_checkpoint(model_name_or_path):
            # The checkpoint is already Neuron-compatible; load it directly.
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            # Convert the Hugging Face checkpoint to the split format
            # expected by transformers_neuronx and cache it on disk.
            hf_model_cls = getattr(transformers, hf_model_cls_name)
            from transformers_neuronx.module import save_pretrained_split

            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
                                                    low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
                                                       **kwargs)
        self.model.to_neuron()


def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    # Check whether the Neuron checkpoint is saved in the old split format,
    # where the weights live in a `pytorch_model.bin` directory.
    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
        return True
    # Check whether the Neuron checkpoint is saved in the new format:
    # the config files plus at least one .safetensors shard.
    pretrained_split_files = ["config.json", "generation_config.json"]
    pretrained_split_format = ".safetensors"
    for file in pretrained_split_files:
        file_path = os.path.join(model_name_or_path, file)
        if not os.path.isfile(file_path):
            return False
    for file in os.listdir(model_name_or_path):
        if file.endswith(pretrained_split_format):
            return True
    return False
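
# For illustration, the two on-disk layouts accepted above (hypothetical
# paths and file names):
#
#   old split format                 new format
#   my-model/                        my-model/
#       pytorch_model.bin/               config.json
#           <per-parameter files>        generation_config.json
#                                        model.safetensors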


def _get_model_architecture(config: PretrainedConfig) -> str:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")


def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    env_value = os.getenv(env)
    if env_value is None:
        return default_value
    buckets_remove_empty = filter(
        lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
    buckets_int = map(int, buckets_remove_empty)
    buckets_list = list(buckets_int)
    return buckets_list
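
# For illustration (hypothetical values): with
# NEURON_CONTEXT_LENGTH_BUCKETS="128,512,1024" set in the environment,
# _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", [2048]) returns
# [128, 512, 1024]; with the variable unset, it returns [2048].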


def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    quant_config = dict(
        dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        quantize_method="vector_dynamic")
    neuron_quantization_config_builder = lambda quant: get_quantization_config(
        quant).from_config(quant_config).get_quant_method(None, "")
    # TODO: Add paged attention config to the default neuron arguments.
    default_neuron_args = dict(
        collectives_layout=LAYOUT_BSH,
        attention_layout=LAYOUT_BSH,
        fuse_qkv=True,
        quant=neuron_quantization_config_builder(model_config.quantization)
        if model_config.quantization else None,
        continuous_batching=continuous_batching_config,
        weight_tiling=bool(model_config.quantization))
    return default_neuron_args


def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    from transformers_neuronx.config import NeuronConfig
    overridden_neuron_config = overridden_neuron_config or {}
    default_neuron_config.update(overridden_neuron_config)
    return NeuronConfig(**default_neuron_config)
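
# For example (hypothetical override): passing {"fuse_qkv": False} as
# overridden_neuron_config replaces the default fuse_qkv=True set in
# _get_default_neuron_config; every key not overridden passes through to
# NeuronConfig unchanged.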


def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    # Create a model instance.
    model = NeuronCasualLM(model_config.hf_config)

    default_neuron_config_args = _get_default_neuron_config(
        model_config, parallel_config, scheduler_config)
    neuron_config = _get_neuron_config_after_override(
        default_neuron_config_args, model_config.override_neuron_config)

    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    # Load the weights from the cached or downloaded files.
    model.load_weights(model_config.model,
                       tp_degree=parallel_config.tensor_parallel_size,
                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
                       neuron_config=neuron_config,
                       context_length_estimate=context_length_estimates,
                       n_positions=n_positions,
                       batch_size=scheduler_config.max_num_seqs)

    return model.eval()
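
# Usage sketch (for illustration only; assumes fully populated config
# objects from the engine, which carry many more fields than shown here):
#
#   model = get_neuron_model(model_config, parallel_config,
#                            scheduler_config)
#   logits = model(input_ids, positions, input_block_ids)
#   logits = model.compute_logits(logits, sampling_metadata)
#   output = model.sample(logits, sampling_metadata)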