123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- """Utilities for selecting and loading neuron models."""
- import importlib
- import os
- from typing import Optional, Type
- import torch
- import torch.nn as nn
- import transformers
- from transformers import PretrainedConfig
- from aphrodite.common.config import ModelConfig, ParallelConfig, SchedulerConfig
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.sampler import Sampler
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.common.sequence import SamplerOutput
# Translation table from a dtype specification (either a string name or a
# torch dtype object) to the precision string passed as ``amp`` to the
# transformers-neuronx model's ``from_pretrained`` in ``get_neuron_model``.
TORCH_DTYPE_TO_NEURON_AMP = {
    # String dtype names.
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    # torch dtype objects.
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}
- # Models supported by Neuron.
- _NEURON_SUPPORTED_MODELS = {
- "LlamaForCausalLM": ("transformers_neuronx.llama.model",
- "LlamaForSampling", "LlamaForCausalLM"),
- "MistralForCausalLM": ("transformers_neuronx.mistral.model",
- "MistralForSampling", "MistralForCausalLM")
- }
class NeuronCasualLM(nn.Module):
    """Causal LM whose forward pass is delegated to a transformers-neuronx model.

    The (misspelled) class name is kept as-is because it is part of this
    module's interface.

    The neuronx model is created by :meth:`load_weights`; until that is
    called ``self.model`` is ``None`` and :meth:`forward` cannot run. The
    neuronx model's output is treated as logits: it is passed to a
    ``LogitsProcessor`` built with ``logits_as_input=True`` and then to
    Aphrodite's ``Sampler``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        super().__init__()
        self.config = config
        # Set by ``load_weights``; ``None`` until then.
        self.model = None
        # ``forward`` returns the neuronx model's output as logits, so the
        # processor is configured to take logits as its input.
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the neuronx model and return its output (used as logits).

        Args:
            input_ids: Token ids fed to the neuronx model.
            positions: Forwarded to the neuronx model as ``cache_ids``.
            input_block_ids: Forwarded to the neuronx model as ``start_ids``.

        Returns:
            The neuronx model's output, which the rest of this class treats
            as logits.
        """
        logits = self.model(input_ids,
                            cache_ids=positions,
                            start_ids=input_block_ids)
        return logits

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Apply the logits processor to the output of :meth:`forward`.

        ``hidden_states`` is expected to already be logits (the output of
        :meth:`forward`). The first processor argument is passed as ``None``;
        presumably it is the LM-head weight, which is unused with
        ``logits_as_input=True`` — NOTE(review): confirm.
        """
        logits = self.logits_processor(None, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using Aphrodite's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Create the neuronx model from a split checkpoint and move it to Neuron.

        If ``<model_name_or_path>/pytorch_model.bin`` is a directory, the
        checkpoint is used in place. Otherwise a split copy at
        ``<model_name_or_path>-split`` is used, and it is created first (by
        loading the Hugging Face model and calling ``save_pretrained_split``)
        if it does not exist yet.

        Args:
            model_name_or_path: Checkpoint path or name given to Hugging Face
                ``from_pretrained`` and used to derive the split directory.
            **kwargs: Forwarded unchanged to the neuronx model class's
                ``from_pretrained`` (e.g. ``tp_degree``, ``amp``,
                ``neuron_config``).

        Raises:
            ValueError: If none of ``self.config.architectures`` is
                supported on Neuron.
        """
        arch = _get_model_architecture(self.config)
        neuronx_module_path, neuronx_cls_name, hf_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_cls_name)

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            # ``pytorch_model.bin`` is a directory, not a file: treat the
            # checkpoint as already split (assumed to be the layout written
            # by ``save_pretrained_split``) and load it in place.
            split_model_dir = model_name_or_path
        elif not os.path.exists(split_model_dir):
            # No split copy yet: load the full Hugging Face model and write
            # a split copy to ``split_model_dir``.
            hf_model_cls = getattr(transformers, hf_cls_name)
            from transformers_neuronx.module import save_pretrained_split
            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
                                                    low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, split_model_dir)
            # Release the full model now instead of keeping it referenced
            # while the neuronx model is loaded and moved to Neuron below;
            # this lowers peak host memory.
            del hf_model

        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
                                                       **kwargs)
        self.model.to_neuron()
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture listed in ``config`` that Neuron supports.

    Args:
        config: Hugging Face model config whose ``architectures`` attribute
            lists candidate model class names.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If no listed architecture is supported, including when
            ``architectures`` is missing or ``None``.
    """
    # The attribute may be absent or explicitly ``None``. Normalize both to
    # an empty list so that the descriptive ValueError below is raised
    # instead of a TypeError from iterating ``None``.
    architectures = getattr(config, "architectures", None)
    if architectures is None:
        architectures = []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Construct a ``NeuronCasualLM``, load its weights, and return it in eval mode.

    Args:
        model_config: Provides the Hugging Face config (``hf_config``), the
            checkpoint location (``model``) and the ``dtype``, which is
            mapped through ``TORCH_DTYPE_TO_NEURON_AMP``.
        parallel_config: ``tensor_parallel_size`` becomes the neuronx
            ``tp_degree``.
        scheduler_config: ``max_num_seqs`` sets the batch size (and the
            continuous-batching shared-cache batch size); ``max_model_len``
            sets ``context_length_estimate`` and ``n_positions``.

    Returns:
        The loaded model after ``.eval()``.
    """
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             NeuronConfig)

    max_seqs = scheduler_config.max_num_seqs
    max_len = scheduler_config.max_model_len

    neuron_model = NeuronCasualLM(model_config.hf_config)

    neuron_config = NeuronConfig(
        continuous_batching=ContinuousBatchingConfig(
            batch_size_for_shared_caches=max_seqs))

    # Weights come from the cached or downloaded checkpoint; see
    # ``NeuronCasualLM.load_weights`` for the split-checkpoint handling.
    neuron_model.load_weights(
        model_config.model,
        tp_degree=parallel_config.tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=neuron_config,
        context_length_estimate=[max_len],
        n_positions=[max_len],
        batch_size=max_seqs,
    )
    return neuron_model.eval()
|