
vLLM Upstream Sync (#526)

* sync: part 1

* fix: aqlm kernel file

* smh

* typo in import

* more typos in imports

* more...

* implement LoaderConfig

* silly me

* one more import

* ray_utils -> ray_tools

* LoadConfig in arg tools

* max_log_probs -> max_logprobs

* extra loader config in args

* model_config -> load_config

* a few more imports

* some fixes for ops -> cache_ops

* worker -> task_handler

* testing sq+ fix

* Revert "testing sq+ fix"

This reverts commit 6898c3a8689bc3c60c3782b1cf0354443e727c5e.

* sq+: attempt 2

* sq+: attempt 2: indent fix
AlpinDale 8 months ago
parent
commit
fca911ee0a
100 changed files with 8383 additions and 3968 deletions
  1. + 2 - 2  CMakeLists.txt
  2. + 1 - 1  aphrodite/attention/backends/abstract.py
  3. + 15 - 11  aphrodite/attention/backends/flash_attn.py
  4. + 36 - 34  aphrodite/attention/backends/rocm_flash_attn.py
  5. + 9 - 12  aphrodite/attention/backends/torch_sdpa.py
  6. + 11 - 17  aphrodite/attention/backends/xformers.py
  7. + 1 - 3  aphrodite/attention/ops/paged_attn.py
  8. + 35 - 81  aphrodite/attention/ops/prefix_prefill.py
  9. + 7 - 0  aphrodite/attention/ops/triton_flash_attn.py
  10. + 17 - 7  aphrodite/attention/selector.py
  11. + 199 - 203  aphrodite/common/config.py
  12. + 70 - 9  aphrodite/common/logger.py
  13. + 151 - 31  aphrodite/common/utils.py
  14. + 62 - 55  aphrodite/engine/aphrodite_engine.py
  15. + 79 - 58  aphrodite/engine/args_tools.py
  16. + 77 - 69  aphrodite/engine/async_aphrodite.py
  17. + 37 - 68  aphrodite/engine/metrics.py
  18. + 5 - 2  aphrodite/engine/output_processor/interfaces.py
  19. + 12 - 8  aphrodite/engine/output_processor/multi_step.py
  20. + 10 - 6  aphrodite/engine/output_processor/single_step.py
  21. + 2 - 0  aphrodite/engine/output_processor/stop_checker.py
  22. + 3 - 1  aphrodite/engine/output_processor/util.py
  23. + 10 - 31  aphrodite/engine/ray_tools.py
  24. + 28 - 20  aphrodite/executor/cpu_executor.py
  25. + 17 - 14  aphrodite/executor/executor_base.py
  26. + 30 - 17  aphrodite/executor/gpu_executor.py
  27. + 24 - 3  aphrodite/executor/neuron_executor.py
  28. + 107 - 87  aphrodite/executor/ray_gpu_executor.py
  29. + 1 - 1  aphrodite/modeling/guided_decoding/outlines_logits_processors.py
  30. + 0 - 456  aphrodite/modeling/hf_downloader.py
  31. + 8 - 9  aphrodite/modeling/layers/activation.py
  32. + 3 - 2  aphrodite/modeling/layers/fused_moe/__init__.py
  33. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json
  34. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json
  35. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json
  36. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json
  37. + 128 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
  38. + 138 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
  39. + 110 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
  40. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
  41. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
  42. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
  43. + 128 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
  44. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
  45. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
  46. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
  47. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
  48. + 128 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
  49. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
  50. + 146 - 0  aphrodite/modeling/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
  51. + 108 - 67  aphrodite/modeling/layers/fused_moe/fused_moe.py
  52. + 79 - 242  aphrodite/modeling/layers/linear.py
  53. + 8 - 9  aphrodite/modeling/layers/logits_processor.py
  54. + 8 - 2  aphrodite/modeling/layers/ops/sample.py
  55. + 16 - 21  aphrodite/modeling/layers/rotary_embedding.py
  56. + 6 - 6  aphrodite/modeling/layers/sampler.py
  57. + 21 - 55  aphrodite/modeling/layers/vocab_parallel_embedding.py
  58. + 0 - 150  aphrodite/modeling/loader.py
  59. + 31 - 0  aphrodite/modeling/model_loader/__init__.py
  60. + 383 - 0  aphrodite/modeling/model_loader/loader.py
  61. + 15 - 13  aphrodite/modeling/model_loader/neuron.py
  62. + 363 - 0  aphrodite/modeling/model_loader/tensorizer.py
  63. + 41 - 0  aphrodite/modeling/model_loader/utils.py
  64. + 362 - 0  aphrodite/modeling/model_loader/weight_utils.py
  65. + 10 - 5  aphrodite/modeling/models/__init__.py
  66. + 79 - 114  aphrodite/modeling/models/baichuan.py
  67. + 16 - 35  aphrodite/modeling/models/bloom.py
  68. + 12 - 21  aphrodite/modeling/models/chatglm.py
  69. + 47 - 129  aphrodite/modeling/models/commandr.py
  70. + 22 - 33  aphrodite/modeling/models/dbrx.py
  71. + 5 - 12  aphrodite/modeling/models/decilm.py
  72. + 75 - 178  aphrodite/modeling/models/deepseek.py
  73. + 17 - 24  aphrodite/modeling/models/falcon.py
  74. + 95 - 154  aphrodite/modeling/models/gemma.py
  75. + 13 - 34  aphrodite/modeling/models/gpt2.py
  76. + 14 - 34  aphrodite/modeling/models/gpt_bigcode.py
  77. + 28 - 75  aphrodite/modeling/models/gpt_j.py
  78. + 22 - 42  aphrodite/modeling/models/gpt_neox.py
  79. + 31 - 84  aphrodite/modeling/models/internlm2.py
  80. + 333 - 0  aphrodite/modeling/models/jais.py
  81. + 39 - 124  aphrodite/modeling/models/llama.py
  82. + 16 - 19  aphrodite/modeling/models/llava.py
  83. + 531 - 0  aphrodite/modeling/models/minicpm.py
  84. + 121 - 199  aphrodite/modeling/models/mixtral.py
  85. + 404 - 0  aphrodite/modeling/models/mixtral_quant.py
  86. + 16 - 34  aphrodite/modeling/models/mpt.py
  87. + 55 - 86  aphrodite/modeling/models/olmo.py
  88. + 41 - 107  aphrodite/modeling/models/opt.py
  89. + 318 - 0  aphrodite/modeling/models/orion.py
  90. + 45 - 112  aphrodite/modeling/models/phi.py
  91. + 43 - 101  aphrodite/modeling/models/qwen.py
  92. + 83 - 136  aphrodite/modeling/models/qwen2.py
  93. + 12 - 21  aphrodite/modeling/models/qwen2_moe.py
  94. + 63 - 150  aphrodite/modeling/models/stablelm.py
  95. + 303 - 0  aphrodite/modeling/models/starcoder2.py
  96. + 365 - 0  aphrodite/modeling/models/xverse.py
  97. + 3 - 4  aphrodite/modeling/sampling_metadata.py
  98. + 26 - 5  aphrodite/processing/block/block_table.py
  99. + 18 - 2  aphrodite/processing/block/common.py
  100. + 22 - 11  aphrodite/processing/block/cpu_gpu_block_allocator.py

+ 2 - 2
CMakeLists.txt

@@ -193,12 +193,12 @@ set (APHRODITE_QUANT_EXT_SRC
   "kernels/quantization/squeezellm/quant_cuda_kernel.cu"
   "kernels/quantization/exl2/q_matrix.cu"
   "kernels/quantization/exl2/q_gemm_exl2.cu"
+  "kernels/quantization/fp8/fp8_cuda_kernels.cu"
   "kernels/quantization/quant_ops.cpp")
 
 if(APHRODITE_GPU_LANG STREQUAL "CUDA")
   list(APPEND APHRODITE_QUANT_EXT_SRC
-    "kernels/quantization/aqlm/aqlm_cuda_entry.cpp"
-    "kernels/quantization/aqlm/aqlm_cuda_kernel.cu"
+    "kernels/quantization/aqlm/gemm_kernels.cu"
     "kernels/quantization/awq/gemm_kernels.cu"
     "kernels/quantization/bitsandbytes/int4_fp16_gemm_kernels.cu"
     "kernels/quantization/bitsandbytes/format.cu"

+ 1 - 1
aphrodite/attention/backends/abstract.py

@@ -116,7 +116,7 @@ class AttentionImpl(ABC):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+        attn_metadata: AttentionMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         raise NotImplementedError

+ 15 - 11
aphrodite/attention/backends/flash_attn.py

@@ -1,4 +1,5 @@
 """Attention layer with Flash and PagedAttention.
+
 NOTE: At the moment, this file includes a lot of duplicated code from
 XFormers backend. The duplicated code will be removed once we use flash-attn or
 flashinfer for all the attention operations.
@@ -9,16 +10,12 @@ from typing import Dict, List, Optional, Tuple, Type
 import torch
 from flash_attn import flash_attn_varlen_func
 
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataPerStage)
+from aphrodite.attention.ops.paged_attn import (PagedAttention,
+                                                PagedAttentionMetadata)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -61,6 +58,7 @@ class FlashAttentionBackend(AttentionBackend):
 class FlashAttentionMetadata(AttentionMetadataPerStage,
                              PagedAttentionMetadata):
     """Metadata for FlashAttentionBackend.
+
     NOTE: Any python object stored here is not updated when it is
     cuda-graph replayed. If you have values that need to be changed
     dynamically, it should be stored in tensor. The tensor has to be
@@ -110,17 +108,23 @@ class FlashAttentionImpl(AttentionImpl):
     If the input tensors contain prompt tokens, the layout is as follows:
     |<--------------- num_prefill_tokens ----------------->|	
     |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
+
     Otherwise, the layout is as follows:	
     |<----------------- num_decode_tokens ------------------>|	
     |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
+
     Generation tokens can contain padding when cuda-graph is used.
     Currently, prompt tokens don't contain any padding.
+
     The prompts might have different lengths, while the generation tokens
     always have length 1.
+
     If chunked prefill is enabled, prefill tokens and decode tokens can be
     batched together in a flattened 1D query.
+
     |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
     |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
+
     Currently, cuda graph is disabled for chunked prefill, meaning there's no
     padding between prefill and decode tokens.
     """
@@ -163,6 +167,7 @@ class FlashAttentionImpl(AttentionImpl):
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
+
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
@@ -245,7 +250,6 @@ class FlashAttentionImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
-                    self.sliding_window[0],
                 )
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
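
The docstring above describes how, with chunked prefill, prefill and decode
tokens share one flattened 1D query. A minimal, self-contained sketch of that
layout (tensor names and sizes here are illustrative, not taken from the
backend):

import torch

# Hypothetical batch: two prefills of lengths 3 and 2, then three decodes of
# length 1 each, all flattened along a single token dimension.
prefill_lens = [3, 2]
num_prefill_tokens = sum(prefill_lens)
num_decode_tokens = 3
hidden_size = 8

query = torch.randn(num_prefill_tokens + num_decode_tokens, hidden_size)

# Prefill tokens come first, decode tokens follow, as in the layout diagram.
prefill_query = query[:num_prefill_tokens]  # |<-prefill_0->|<-prefill_1->|
decode_query = query[num_prefill_tokens:]   # |<-decode_0->|...|<-decode_2->|
assert prefill_query.shape[0] == 5 and decode_query.shape[0] == 3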

+ 36 - 34
aphrodite/attention/backends/rocm_flash_attn.py

@@ -6,16 +6,12 @@ from typing import Dict, List, Optional, Tuple, Type
 import torch
 from loguru import logger
 
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataPerStage)
+from aphrodite.attention.ops.paged_attn import (PagedAttention,
+                                                PagedAttentionMetadata)
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
@@ -58,6 +54,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
 class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
                                  PagedAttentionMetadata):
     """Metadata for FlashAttentionBackend.
+
     NOTE: Any python object stored here is not updated when it is
     cuda-graph replayed. If you have values that need to be changed
     dynamically, it should be stored in tensor. The tensor has to be
@@ -107,18 +104,23 @@ class ROCmFlashAttentionImpl(AttentionImpl):
     If the input tensors contain prompt tokens, the layout is as follows:
     |<--------------- num_prompt_tokens -------------->|
     |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+
     Otherwise, the layout is as follows:
     |<------------------ num_generation_tokens (M) ----------------->|
     |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+
     Generation tokens can contain padding when cuda-graph is used.
     Currently, prompt tokens don't contain any padding.
+
     The prompts might have different lengths, while the generation tokens
     always have length 1.
 
     If chunked prefill is enabled, prefill tokens and decode tokens can be
     batched together in a flattened 1D query.
+
     |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|	
     |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
+
     Currently, cuda graph is disabled for chunked prefill, meaning there's no
     padding between prefill and decode tokens.
     """
@@ -151,26 +153,30 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {suppored_head_sizes}.")
 
-        self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+        self.use_naive_attn = False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = (os.environ.get(
-            "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
-                                      in ("true", "1"))
-        if self.use_naive_attn:
-            # AMD Radeon 7900 series (gfx1100) currently does not support
-            # xFormers nor FlashAttention. As a temporary workaround, we use
-            # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention
-            logger.debug("Using naive attention in ROCmBackend")
-        elif self.use_triton_flash_attn:
-            from aphrodite.attention.ops.triton_flash_attn import (  # noqa: F401
-                triton_attention, )
+            "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
+        if self.use_triton_flash_attn:
+            from aphrodite.attention.ops.triton_flash_attention import \
+                triton_attention  # noqa: F401
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            from flash_attn import flash_attn_varlen_func  # noqa: F401
-            self.attn_func = flash_attn_varlen_func
-            logger.debug("Using CK FA in ROCmBackend")
+            # if not using triton, navi3x not use flash-attn either
+            if torch.cuda.get_device_capability()[0] == 11:
+                self.use_naive_attn = True
+            else:
+                try:
+                    from flash_attn import flash_attn_varlen_func  # noqa: F401
+                    self.attn_func = flash_attn_varlen_func
+                    logger.debug("Using CK FA in ROCmBackend")
+                except ModuleNotFoundError:
+                    self.use_naive_attn = True
+
+            if self.use_naive_attn:
+                self.attn_func = _naive_attention
+                logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -190,6 +196,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         kv_scale: float = 1.0,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
+
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
@@ -240,17 +247,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
+            assert prefill_meta.prompt_lens is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_naive_attn or self.use_triton_flash_attn:
+                if self.use_triton_flash_attn or self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
                     if self.use_naive_attn:
-                        out = self.attn_fuc(
+                        out = self.attn_func(
                             query,
                             key,
                             value,
@@ -302,7 +310,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
-                    self.sliding_window[0],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -320,6 +327,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 self.alibi_slopes,
                 kv_scale,
             )
+
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
 
@@ -331,7 +339,6 @@ def _naive_attention(
     prompt_lens: List[int],
     scale: float,
 ) -> torch.Tensor:
-    num_tokens = query.shape[0]  # noqa: F841
     output = torch.empty_like(query)
     start = 0
     for _, prompt_len in enumerate(prompt_lens):
@@ -346,10 +353,6 @@ def _naive_attention(
         output[start:end].copy_(out)
         start += prompt_len
 
-    # Using view got RuntimeError: view size is not compatible
-    # with input tensor's size and stride (at least one
-    # dimension spans across two contiguous subspaces).
-    # Use reshape instead.
     return output
 
 
@@ -366,7 +369,6 @@ def _naive_masked_attention(
                                       device=query.device),
                            diagonal=1)
     attn_mask = attn_mask * torch.finfo(query.dtype).min
-
     attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     attn_weights = attn_weights + attn_mask.float()
     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
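
The constructor changes above turn ROCm backend selection into a small
decision chain: Triton FA when APHRODITE_USE_TRITON_FLASH_ATTN is enabled
(the default), otherwise CK flash-attn, with naive attention as the fallback
for navi3x or when flash_attn is missing. A standalone sketch mirroring that
logic (not the actual class):

import os
import torch

def pick_rocm_attn_func() -> str:
    use_triton = os.environ.get("APHRODITE_USE_TRITON_FLASH_ATTN",
                                "True").lower() in ("true", "1")
    if use_triton:
        return "triton_attention"
    # Navi3x (device capability 11) does not use flash-attn either.
    if torch.cuda.get_device_capability()[0] == 11:
        return "_naive_attention"
    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
        return "flash_attn_varlen_func"  # CK FA
    except ModuleNotFoundError:
        return "_naive_attention"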

+ 9 - 12
aphrodite/attention/backends/sdpa.py → aphrodite/attention/backends/torch_sdpa.py

@@ -6,16 +6,12 @@ from typing import Dict, List, Optional, Tuple, Type
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataPerStage)
+from aphrodite.attention.ops.paged_attn import (PagedAttention,
+                                                PagedAttentionMetadata)
 
 
 class TorchSDPABackend(AttentionBackend):
@@ -111,7 +107,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: TorchSDPAMetadata,
+        attn_metadata: TorchSDPAMetadata,  # type: ignore
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -141,6 +137,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                                                 kv_scale)
 
         if attn_metadata.is_prompt:
+            assert attn_metadata.prompt_lens is not None
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
@@ -214,7 +211,7 @@ def _make_alibi_bias(
     attn_biases = []
     for prompt_len in prompt_lens:
         bias = torch.arange(prompt_len, dtype=dtype)
-        # NOTE: HF uses
+        # NOTE(zhuohan): HF uses
         #     `bias = bias[None, :].repeat(prompt_len, 1)`
         # here. We find that both biases give the same results, but
         # the bias below more accurately follows the original ALiBi

+ 11 - 17
aphrodite/attention/backends/xformers.py

@@ -4,22 +4,16 @@ from typing import Dict, List, Optional, Tuple, Type
 
 import torch
 from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (
-    AttentionBias,
-    BlockDiagonalCausalMask,
-    LowerTriangularMaskWithTensorBias,
-)
-
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from xformers.ops.fmha.attn_bias import (AttentionBias,
+                                         BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataPerStage)
+from aphrodite.attention.ops.paged_attn import (PagedAttention,
+                                                PagedAttentionMetadata)
 
 
 class XFormersBackend(AttentionBackend):
@@ -250,7 +244,6 @@ class XFormersImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
-                    self.sliding_window,
                 )
                 assert output[:num_prefill_tokens].shape == out.shape
                 output[:num_prefill_tokens] = out
@@ -293,6 +286,7 @@ class XFormersImpl(AttentionImpl):
             value: shape = [num_prefill_tokens, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         """
+        assert attn_metadata.prompt_lens is not None
         original_query = query
         if self.num_kv_heads != self.num_heads:
             # GQA/MQA requires the shape [B, M, G, H, K].
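
The GQA/MQA comment in the last hunk refers to xFormers' five-dimensional
layout. A sketch with hypothetical sizes showing the [B, M, G, H, K] reshape
(batch, tokens, KV-head groups, query heads per group, head dim):

import torch

num_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)

# Group query heads so each group shares one KV head, then add a batch dim.
query = query.view(1, num_tokens, num_kv_heads, num_queries_per_kv, head_size)
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size).unsqueeze(0)
assert query.shape == key.shape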

+ 1 - 3
aphrodite/attention/ops/paged_attn.py

@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 
-from aphrodite._C import cache_ops
 from aphrodite._C import ops
+from aphrodite._C import cache_ops
 from aphrodite.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -173,7 +173,6 @@ class PagedAttention:
         context_lens: torch.Tensor,
         max_subquery_len: int,
         alibi_slopes: Optional[torch.Tensor],
-        sliding_window: Optional[int],
     ) -> torch.Tensor:
         output = torch.empty_like(query)
         context_attention_fwd(
@@ -190,7 +189,6 @@ class PagedAttention:
             context_lens,
             max_subquery_len,
             alibi_slopes,
-            sliding_window,
         )
         return output
 

+ 35 - 81
aphrodite/attention/ops/prefix_prefill.py

@@ -50,7 +50,6 @@ if triton.__version__ >= "2.1.0":
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
-        SLIDING_WINDOW: tl.constexpr,
     ):
         cur_batch = tl.program_id(0)
         cur_head = tl.program_id(1)
@@ -63,53 +62,42 @@ if triton.__version__ >= "2.1.0":
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
         cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
 
-        # start position inside of the query
-        # generally, N goes over kv, while M goes over query_len
         block_start_loc = BLOCK_M * start_m
 
         # initialize offsets
-        # [N]; starts at 0
         offs_n = tl.arange(0, BLOCK_N)
-        # [D]; starts at 0
         offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
-        # [M]; starts at current position in query
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        # [M,D]
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
         dim_mask = tl.where(
-            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
-            0).to(tl.int1)  # [D]
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
 
         q = tl.load(Q + off_q,
                     mask=dim_mask[None, :] &
                     (offs_m[:, None] < cur_batch_query_len),
-                    other=0.0)  # [M,D]
+                    other=0.0)
 
-        # initialize pointer to m and l
-        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
-        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
-                       dtype=tl.float32)  # [M,D]
+        # # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
 
-        # compute query against context (no causal mask here)
         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                          ((start_n + offs_n) // block_size) * stride_b_loc_s,
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
-                         other=0)  # [N]
-            # [D,N]
+                         other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
                      cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
-            # [N,D]
             off_v = (
                 bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
@@ -118,39 +106,23 @@ if triton.__version__ >= "2.1.0":
             k = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
-                        other=0.0)  # [D,N]
+                        other=0.0)
 
-            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
-            if SLIDING_WINDOW > 0:
-                # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
-                # Q entries in sequence
-                # (start_n + offs_n[None, :]) are the positions of
-                # KV entries in sequence
-                # So the condition makes sure each entry in Q only attends
-                # to KV entries not more than SLIDING_WINDOW away.
-                #
-                # We can't use -inf here, because the
-                # sliding window may lead to the entire row being masked.
-                # This then makes m_ij contain -inf, which causes NaNs in
-                # exp().
-                qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
-                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
-                              -10000)
 
             # -- compute m_ij, p, l_ij
-            m_ij = tl.max(qk, 1)  # [M]
-            p = tl.exp(qk - m_ij[:, None])  # [M,N]
-            l_ij = tl.sum(p, 1)  # [M]
+            m_ij = tl.max(qk, 1)
+            p = tl.exp(qk - m_ij[:, None])
+            l_ij = tl.sum(p, 1)
             # -- update m_i and l_i
-            m_i_new = tl.maximum(m_i, m_ij)  # [M]
-            alpha = tl.exp(m_i - m_i_new)  # [M]
-            beta = tl.exp(m_ij - m_i_new)  # [M]
-            l_i_new = alpha * l_i + beta * l_ij  # [M]
-
+            m_i_new = tl.maximum(m_i, m_ij)
+            alpha = tl.exp(m_i - m_i_new)
+            beta = tl.exp(m_ij - m_i_new)
+            l_i_new = alpha * l_i + beta * l_ij
             # -- update output accumulator --
             # scale p
             p_scale = beta / l_i_new
@@ -162,7 +134,7 @@ if triton.__version__ >= "2.1.0":
             v = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
-                        other=0.0)  # [N,D]
+                        other=0.0)
 
             p = p.to(v.dtype)
             acc += tl.dot(p, v)
@@ -177,10 +149,8 @@ if triton.__version__ >= "2.1.0":
         k_ptrs = K + off_k
         v_ptrs = V + off_v
 
-        # block_mask is 0 when we're already past the current query length
         block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
-        # compute query against itself (with causal mask)
         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
@@ -193,13 +163,8 @@ if triton.__version__ >= "2.1.0":
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
             qk *= sm_scale
-            # apply causal mask
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                           float("-inf"))
-            if SLIDING_WINDOW > 0:
-                qk = tl.where(
-                    offs_m[:, None] -
-                    (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)
 
             # -- compute m_ij, p, l_ij
             m_ij = tl.max(qk, 1)
@@ -472,8 +437,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,  # head size
-        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
+        BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
     ):
         # attn_bias[]
@@ -494,24 +458,21 @@ if triton.__version__ >= "2.1.0":
 
         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
+        offs_d = tl.arange(0, BLOCK_DMODEL)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
-        dim_mask = tl.where(
-            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
-
-        q = tl.load(Q + off_q,
-                    mask=dim_mask[None, :] &
-                    (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
-                    other=0.0)
+        q = tl.load(
+            Q + off_q,
+            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
+            other=0.0)
 
         # # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
 
         alibi_slope = tl.load(Alibi_slopes + cur_head)
         alibi_start_q = tl.arange(
@@ -536,9 +497,8 @@ if triton.__version__ >= "2.1.0":
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=dim_mask[:, None] &
-                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
-                        other=0.0)  # [D,N]
+                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                        other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
@@ -572,8 +532,7 @@ if triton.__version__ >= "2.1.0":
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=dim_mask[None, :] &
-                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
+                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -606,9 +565,8 @@ if triton.__version__ >= "2.1.0":
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=dim_mask[:, None] &
-                        ((start_n + offs_n[None, :]) <
-                         cur_batch_seq_len - cur_batch_ctx_len),
+                        mask=(start_n + offs_n[None, :]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -644,9 +602,8 @@ if triton.__version__ >= "2.1.0":
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=dim_mask[None, :] &
-                        ((start_n + offs_n[:, None]) <
-                         cur_batch_seq_len - cur_batch_ctx_len),
+                        mask=(start_n + offs_n[:, None]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -664,8 +621,7 @@ if triton.__version__ >= "2.1.0":
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=dim_mask[None, :] &
-                 (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
+                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
         return
 
     @torch.inference_mode()
@@ -680,8 +636,7 @@ if triton.__version__ >= "2.1.0":
                               b_seq_len,
                               b_ctx_len,
                               max_input_len,
-                              alibi_slopes=None,
-                              sliding_window=None):
+                              alibi_slopes=None):
 
         cap = torch.cuda.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
@@ -689,7 +644,7 @@ if triton.__version__ >= "2.1.0":
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
         # round up Lk to a power of 2 - this is required for Triton block size
-        Lk_padded = triton.next_power_of_2(Lk)
+        Lk_padded = 2**((Lk - 1).bit_length())
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
@@ -699,6 +654,7 @@ if triton.__version__ >= "2.1.0":
 
         num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
+            assert Lk == Lk_padded
             _fwd_kernel_alibi[grid](
                 q,
                 k,
@@ -743,7 +699,6 @@ if triton.__version__ >= "2.1.0":
                 num_queries_per_kv=num_queries_per_kv,
                 BLOCK_M=BLOCK,
                 BLOCK_DMODEL=Lk,
-                BLOCK_DMODEL_PADDED=Lk_padded,
                 BLOCK_N=BLOCK,
                 num_warps=num_warps,
                 num_stages=1,
@@ -794,7 +749,6 @@ if triton.__version__ >= "2.1.0":
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
-            SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
             num_warps=num_warps,
             num_stages=1,
         )
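
One small change in context_attention_fwd above replaces triton.next_power_of_2
with a pure-Python equivalent: 2**((Lk - 1).bit_length()) rounds the head size
up to the next power of two. A quick check with illustrative head sizes:

def next_pow2(n: int) -> int:
    return 2 ** ((n - 1).bit_length())

for lk, expected in [(64, 64), (80, 128), (96, 128), (112, 128), (128, 128)]:
    assert next_pow2(lk) == expected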

+ 7 - 0
aphrodite/attention/ops/triton_flash_attn.py

@@ -2,16 +2,22 @@
 """
 Fused Attention
 ===============
+
 This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
 (https://tridao.me/publications/flash2/flash2.pdf)
 Credits: OpenAI kernel team, AMD ML Frameworks Triton team
+
 Features supported:
+
 1) Fwd with causal masking
 2) Any sequence lengths without padding (currently fwd kernel only)
 3) Support for different sequence lengths for q and k
 4) Nested tensor API currently does not support dropout or bias.
+
 Not currently supported:
+
 1) Non power of two head dims
+
 """
 
 import torch
@@ -413,6 +419,7 @@ def attn_fwd(
         off_h_k = off_h_q % hk
     else:
         off_h_k = off_h_q
+
     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
         n_extra_tokens = BLOCK_N - seqlen_k

+ 17 - 7
aphrodite/attention/selector.py

@@ -1,13 +1,15 @@
 import enum
+import os
 from functools import lru_cache
 from typing import Type
 
 import torch
 from loguru import logger
-
 from aphrodite.attention.backends.abstract import AttentionBackend
 from aphrodite.common.utils import is_cpu, is_hip
 
+APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"
+
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
@@ -21,20 +23,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
     backend = _which_attn_to_use(dtype)
     if backend == _Backend.FLASH_ATTN:
         logger.info("Using FlashAttention backend.")
-        from aphrodite.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
+        from aphrodite.attention.backends.flash_attn import \
+            FlashAttentionBackend  # noqa: F401
         return FlashAttentionBackend
     elif backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
-        from aphrodite.attention.backends.xformers import XFormersBackend  # noqa: F501
+        from aphrodite.attention.backends.xformers import \
+            XFormersBackend  # noqa: F401
         return XFormersBackend
     elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCm FlashAttention backend.")
-        from aphrodite.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
+        logger.info("Using ROCmFlashAttention backend.")
+        from aphrodite.attention.backends.rocm_flash_attn import \
+            ROCmFlashAttentionBackend  # noqa: F401
         return ROCmFlashAttentionBackend
     elif backend == _Backend.TORCH_SDPA:
         logger.info("Using Torch SDPA backend.")
-        from aphrodite.attention.backends.sdpa import TorchSDPABackend
+        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
         return TorchSDPABackend
     else:
         raise ValueError("Invalid attention backend.")
@@ -71,4 +75,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
             "Cannot use FlashAttention backend because the flash_attn package "
             "is not found. Please install it for better performance.")
         return _Backend.XFORMERS
+
+    backend_by_env_var = os.getenv(APHRODITE_ATTENTION_BACKEND)
+    if backend_by_env_var is not None:
+        return _Backend[backend_by_env_var]
+
+    # Default case.
     return _Backend.FLASH_ATTN
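
The new APHRODITE_ATTENTION_BACKEND environment variable selects a backend by
its _Backend member name. Note that, as placed in the diff, the override is
only reached after the CPU/ROCm/dtype/flash-attn checks, so it takes effect in
the default path. A hedged usage sketch (the printed name is what one would
expect when xformers is installed):

import os

# Must be a _Backend member name, since the selector does
# `_Backend[backend_by_env_var]`.
os.environ["APHRODITE_ATTENTION_BACKEND"] = "XFORMERS"

import torch
from aphrodite.attention.selector import get_attn_backend

backend_cls = get_attn_backend(torch.float16)
print(backend_cls.__name__)  # XFormersBackend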

+ 199 - 203
aphrodite/common/config.py

@@ -1,22 +1,29 @@
 import enum
-from typing import TYPE_CHECKING, Optional, Union, ClassVar
-from dataclasses import dataclass, fields
-import os
-from packaging.version import Version
-from loguru import logger
 import json
+import os
+from dataclasses import dataclass, field, fields
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 
 import torch
+from loguru import logger
+from packaging.version import Version
 from transformers import PretrainedConfig
 
-from aphrodite.transformers_utils.config import get_config, get_hf_text_config
-from aphrodite.common.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron,
-                                    get_nvcc_cuda_version)
+from aphrodite.common.utils import (get_cpu_memory, get_nvcc_cuda_version,
+                                    is_cpu, is_hip, is_neuron)
 from aphrodite.quantization import QUANTIZATION_METHODS
+from aphrodite.transformers_utils.config import get_config, get_hf_text_config
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+    from aphrodite.modeling.model_loader.loader import BaseModelLoader
+
+
+# If true, will load models from ModelScope instead of Hugging Face Hub.
+APHRODITE_USE_MODELSCOPE = os.environ.get("APHRODITE_USE_MODELSCOPE",
+                                     "False").lower() == "true"
+
 _GB = 1 << 30
 
 
@@ -30,18 +37,6 @@ class ModelConfig:
             available, and "slow" will always use the slow tokenizer.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
-        download_dir: Directory to download and load the weights, default to the
-            default cache directory of huggingface.
-        load_format: The format of the model weights to load:
-            "auto" will try to load the weights in the safetensors format and
-                fall back to the pytorch bin format if safetensors format is
-                not available.
-            "pt" will load the weights in the pytorch bin format.
-            "safetensors" will load the weights in the safetensors format.
-            "npcache" will load the weights in pytorch format and store
-                a numpy cache to speed up the loading.
-            "dummy" will initialize the weights with random values, which is
-                mainly for profiling.
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
@@ -50,7 +45,7 @@ class ModelConfig:
             a tag name, or a commit id. If unspecified, will use the default
             version.
         code_revision: The specific revision to use for the model code on
-            Hugging Face Hub. It can be a branch name, a tag name, or a 
+            Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
@@ -66,8 +61,8 @@ class ModelConfig:
         load_in_smooth: Whether to load the FP16 model in smoothquant format.
         quantization_param_path: Path to JSON file containing scaling factors.
             Used to load KV cache scaling factors into the model when KV cache
-            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also 
-            be used to load activation and weight scaling factors when the 
+            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
+            be used to load activation and weight scaling factors when the
             model dtype is FP8_E4M3 on ROCm.
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
@@ -75,6 +70,8 @@ class ModelConfig:
         max_context_len_to_capture: Maximum context len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode.
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer.
     """
 
     def __init__(
@@ -83,9 +80,6 @@ class ModelConfig:
         tokenizer: str,
         tokenizer_mode: str,
         trust_remote_code: bool,
-        download_dir: Optional[str],
-        load_format: str,
-        # dtype: str,
         dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
@@ -97,16 +91,15 @@ class ModelConfig:
         load_in_8bit: bool = False,
         load_in_smooth: bool = False,
         quantization_param_path: Optional[str] = None,
-        enforce_eager: bool = True,
+        enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
-        max_log_probs: int = 10,
+        max_logprobs: int = 5,
+        skip_tokenizer_init: bool = False,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
-        self.download_dir = download_dir
-        self.load_format = load_format
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
@@ -118,22 +111,8 @@ class ModelConfig:
         self.quantization_param_path = quantization_param_path
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
-        self.max_log_probs = max_log_probs
-
-        if os.environ.get("APHRODITE_USE_MODELSCOPE",
-                          "False").lower() == "true":
-            # download model from ModelScope hub,
-            # lazy import so that modelscope is not required for normal use.
-            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
-            if not os.path.exists(model):
-                model_path = snapshot_download(model_id=model,
-                                               cache_dir=download_dir,
-                                               revision=revision)
-            else:
-                model_path = model
-            self.model = model_path
-            self.download_dir = model_path
-            self.tokenizer = model_path
+        self.max_logprobs = max_logprobs
+        self.skip_tokenizer_init = skip_tokenizer_init
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision)
@@ -141,41 +120,11 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                      max_model_len)
-        self._verify_load_format()
-        self._verify_tokenizer_mode()
+        if not self.skip_tokenizer_init:
+            self._verify_tokenizer_mode()
         self._verify_quantization()
         self._verify_cuda_graph()
 
-    def _verify_load_format(self) -> None:
-        load_format = self.load_format.lower()
-        supported_load_format = [
-            "auto", "pt", "safetensors", "npcache", "dummy"
-        ]
-        rocm_not_supported_load_format = []
-        if load_format not in supported_load_format:
-            raise ValueError(
-                f"Unknown load format: {self.load_format}. Must be one of "
-                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
-        if is_hip() and load_format in rocm_not_supported_load_format:
-            rocm_supported_load_format = [
-                f for f in supported_load_format
-                if (f not in rocm_not_supported_load_format)
-            ]
-            raise ValueError(
-                f"load format \'{load_format}\' is not supported in ROCm. "
-                f"Supported load format are "
-                f"{rocm_supported_load_format}")
-
-        # TODO: Remove this check once HF updates the pt weights of Mixtral.
-        architectures = getattr(self.hf_config, "architectures", [])
-        # architectures can be None instead of []
-        if architectures and "MixtralForCausalLM" in architectures \
-            and load_format == "pt":
-            raise ValueError(
-                "Currently, the 'pt' format is not supported for Mixtral. "
-                "Please use the 'safetensors' format instead. ")
-        self.load_format = load_format
-
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
         if tokenizer_mode not in ["auto", "slow"]:
@@ -186,33 +135,33 @@ class ModelConfig:
 
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_not_supported_quantization = ["aqlm", "awq", "bnb", "quip"]
+        rocm_supported_quantization = ["gptq", "squeezellm"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
-        if self.model.endswith("gguf"):
-            if self.quantization is None:
-                self.quantization = "gguf"
-            elif self.quantization != "gguf":
-                raise ValueError(
-                    f"GGUF file cannot be used in ({self.quantization}).")
-
         # Parse quantization method from the HF model config, if available.
-        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
-        if hf_quant_config is not None:
-
-            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
-            # If the GPTQ model is serialized in marlin format, use marlin.
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
-                hf_quant_method = "marlin"
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
+                                or quant_cfg.get("is_marlin_format", False))
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                logger.info("The model is serialized in Marlin format. "
+                            "Using Marlin kernel.")
+                quant_method = "marlin"
+                if self.quantization == "gptq":
+                    self.quantization = quant_method
+
             if self.quantization is None:
-                self.quantization = hf_quant_method
-            elif self.quantization != hf_quant_method:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
                 raise ValueError(
                     "Quantization method specified in the model config "
-                    f"({hf_quant_method}) does not match the quantization "
+                    f"({quant_method}) does not match the quantization "
                     f"method specified in the `quantization` argument "
                     f"({self.quantization}).")
         if self.load_in_4bit:
@@ -280,7 +229,7 @@ class ModelConfig:
                     f"Unknown quantization method: {self.quantization}. Must "
                     f"be one of {supported_quantization}.")
             if is_hip(
-            ) and self.quantization in rocm_not_supported_quantization:
+            ) and self.quantization not in rocm_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     "supported in ROCm.")
@@ -317,6 +266,12 @@ class ModelConfig:
                 f"({pipeline_parallel_size}).")
 
     def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+
+        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
+        # addition to sliding window size. We check if that field is present
+        # and if it's False, return None.
         if (hasattr(self.hf_text_config, "use_sliding_window")
                 and not self.hf_text_config.use_sliding_window):
             return None
@@ -329,8 +284,8 @@ class ModelConfig:
         return self.hf_text_config.hidden_size
 
     def get_head_size(self) -> int:
-        if hasattr(self.hf_config, "head_dim"):
-            return self.hf_config.head_dim
+        if hasattr(self.hf_text_config, "head_dim"):
+            return self.hf_text_config.head_dim
         # FIXME: This may not be true for all models.
         return (self.hf_text_config.hidden_size //
                 self.hf_text_config.num_attention_heads)
@@ -397,11 +352,9 @@ class CacheConfig:
         gpu_memory_utilization: Fraction of GPU memory to use for the
             Aphrodite execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
-        cache_dtype: Data Type for KV cache storage.
-        cache_quant_params_path: Path to the scales and zero points
-            of KV cache quantization when cache_dtype is int8.
-        num_gpu_blocks_override: Number of GPU blocks to use. This overrides
-            the profiled num_gpu_blocks if specified. Does nothing if None.
+        cache_dtype: Data type for kv cache storage.
+        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
+            profiled num_gpu_blocks if specified. Does nothing if None.
     """
 
     def __init__(
@@ -410,10 +363,9 @@ class CacheConfig:
         gpu_memory_utilization: float,
         swap_space: int,
         cache_dtype: str,
-        # cache_quant_params_path: Optional[str] = None,
         num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
-        context_shift: bool = False,
+        enable_prefix_caching: bool = False,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
@@ -421,8 +373,7 @@ class CacheConfig:
         self.num_gpu_blocks_override = num_gpu_blocks_override
         self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
-        # self.cache_quant_params_path = cache_quant_params_path
-        self.context_shift = context_shift
+        self.enable_prefix_caching = enable_prefix_caching
         self._verify_args()
         self._verify_cache_dtype()
 
@@ -443,7 +394,6 @@ class CacheConfig:
 
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
-            # if self.cache_dtype in ["auto", "int8"]:
             pass
         elif self.cache_dtype == "fp8":
             if not is_hip():
@@ -484,10 +434,10 @@ class CacheConfig:
 @dataclass
 class TokenizerPoolConfig:
     """Configuration for the tokenizer pool.
-    
+
     Args:
-        pool_size: Number of tokenizer instances in the pool.
-        pool_type: Type of the tokenizer pool.
+        pool_size: Number of tokenizer workers in the pool.
+        pool_type: Type of the pool.
         extra_config: Additional config for the pool.
             The way the config will be used depends on the
             pool type.
@@ -498,7 +448,7 @@ class TokenizerPoolConfig:
 
     def __post_init__(self):
         if self.pool_type not in ("ray", ):
-            raise ValueError(f"Unknown pool type: {self.pool_type}.")
+            raise ValueError(f"Unknown pool type: {self.pool_type}")
         if not isinstance(self.extra_config, dict):
             raise ValueError("extra_config must be a dictionary.")
 
@@ -508,14 +458,15 @@ class TokenizerPoolConfig:
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
-        
+
         If tokenizer_pool_size is 0, return None.
-        
+
         Args:
             tokenizer_pool_size: Number of tokenizer workers in the pool.
-            tokenizer_pool_type: Type of the tokenizer pool.
+            tokenizer_pool_type: Type of the pool.
             tokenizer_pool_extra_config: Additional config for the pool.
-                The way the config will be used depends on the pool type.
+                The way the config will be used depends on the
+                pool type. This can be a JSON string (will be parsed).
         """
         if tokenizer_pool_size:
             if isinstance(tokenizer_pool_extra_config, str):
@@ -532,6 +483,65 @@ class TokenizerPoolConfig:
         return tokenizer_pool_config
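
# Sketch of driving the factory above from CLI-style values; the extra config
# may be a dict or a JSON string, and the "num_cpus" key is only an example of
# pool-type-specific configuration, not a documented option.
from aphrodite.common.config import TokenizerPoolConfig

pool_config = TokenizerPoolConfig.create_config(
    tokenizer_pool_size=4,
    tokenizer_pool_type="ray",
    tokenizer_pool_extra_config='{"num_cpus": 1}',
)
assert TokenizerPoolConfig.create_config(
    tokenizer_pool_size=0,
    tokenizer_pool_type="ray",
    tokenizer_pool_extra_config=None) is None   # size 0 disables the pool
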
 
 
+class LoadFormat(str, enum.Enum):
+    AUTO = "auto"
+    PT = "pt"
+    SAFETENSORS = "safetensors"
+    NPCACHE = "npcache"
+    DUMMY = "dummy"
+    TENSORIZER = "tensorizer"
+
+
+@dataclass
+class LoadConfig:
+    """
+        download_dir: Directory to download and load the weights; defaults to
+            the default Hugging Face cache directory.
+        load_format: The format of the model weights to load:
+            "auto" will try to load the weights in the safetensors format and
+                fall back to the pytorch bin format if safetensors format is
+                not available.
+            "pt" will load the weights in the pytorch bin format.
+            "safetensors" will load the weights in the safetensors format.
+            "npcache" will load the weights in pytorch format and store
+                a numpy cache to speed up the loading.
+            "dummy" will initialize the weights with random values, which is
+                mainly for profiling.
+            "tensorizer" will use CoreWeave's tensorizer library for
+                fast weight loading.
+    """
+
+    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
+    download_dir: Optional[str] = None
+    model_loader_extra_config: Optional[Union[str, dict]] = field(
+        default_factory=dict)
+
+    def __post_init__(self):
+        model_loader_extra_config = self.model_loader_extra_config or {}
+        if isinstance(model_loader_extra_config, str):
+            self.model_loader_extra_config = json.loads(
+                model_loader_extra_config)
+        self._verify_load_format()
+
+    def _verify_load_format(self) -> None:
+        if not isinstance(self.load_format, str):
+            return
+
+        load_format = self.load_format.lower()
+        self.load_format = LoadFormat(load_format)
+
+        rocm_not_supported_load_format: List[str] = []
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f for f in LoadFormat.__members__
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format '{load_format}' is not supported in ROCm. "
+                f"Supported load formats are "
+                f"{rocm_supported_load_format}")
+
+
 class ParallelConfig:
     """Configuration for the distributed execution.
 
@@ -546,7 +556,7 @@ class ParallelConfig:
             parallel and large models.
         disable_custom_all_reduce: Disable the custom all-reduce kernel and
             fall back to NCCL.
-        tokenizer_pool_config: Configuration for the tokenizer pool.
+        tokenizer_pool_config: Config for the tokenizer pool.
             If None, will use synchronous tokenization.
         ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
             https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
@@ -613,12 +623,9 @@ class SchedulerConfig:
             decoding to store KV activations of tokens which may or may not be
             accepted.
         delay_factor: Apply a delay (of delay factor multiplied by previous
-            prompt latency) before scheduling the next prompt.
-        policy: Policy of sequence scheduling (`fcfs` or `reorder`).
-        reorder_window: Allowed reorder window size (in sec) for `reorder`
-            policy.
-        enable_chunked_prefill: If True, prefill requests can be chunked
-            based on the remaining max_num_batched_tokens.
+            prompt latency) before scheduling the next prompt.
+        enable_chunked_prefill: If True, prefill requests can be chunked based
+            on the remaining max_num_batched_tokens.
     """
 
     def __init__(
@@ -629,8 +636,6 @@ class SchedulerConfig:
         use_v2_block_manager: bool = False,
         num_lookahead_slots: int = 0,
         delay_factor: float = 0.0,
-        policy: str = "fcfs",
-        reorder_window: float = 0.0,
         enable_chunked_prefill: bool = False,
     ) -> None:
         if max_num_batched_tokens is not None:
@@ -645,14 +650,14 @@ class SchedulerConfig:
                 self.max_num_batched_tokens = max(max_model_len, 2048)
         if enable_chunked_prefill:
             logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
+
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
         self.use_v2_block_manager = use_v2_block_manager
         self.num_lookahead_slots = num_lookahead_slots
         self.delay_factor = delay_factor
-        self.policy = policy
-        self.reorder_window = reorder_window
         self.chunked_prefill_enabled = enable_chunked_prefill
+
         self._verify_args()
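
# Sketch: with max_num_batched_tokens left unset and chunked prefill disabled,
# the batching ceiling defaults to max(max_model_len, 2048) as in __init__.
from aphrodite.common.config import SchedulerConfig

sched = SchedulerConfig(
    max_num_batched_tokens=None,
    max_num_seqs=256,
    max_model_len=4096,
)
assert sched.max_num_batched_tokens == 4096
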
 
     def _verify_args(self) -> None:
@@ -665,17 +670,13 @@ class SchedulerConfig:
                 "max_num_batched_tokens and makes Aphrodite reject longer "
                 "sequences. Please increase max_num_batched_tokens or "
                 "decrease max_model_len.")
+
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                 "be greater than or equal to max_num_seqs "
                 f"({self.max_num_seqs}).")
-        if self.reorder_window < 0:
-            raise ValueError(f"reorder_window ({self.reorder_window}) must "
-                             "be not be negative.")
-        if self.reorder_window != 0 and self.policy != 'reorder':
-            raise ValueError("fcfs policy doesn't support reorder_window "
-                             f"({self.reorder_window}).")
+
         if self.num_lookahead_slots < 0:
             raise ValueError(
                 "num_lookahead_slots "
@@ -688,14 +689,14 @@ class DeviceConfig:
     def __init__(self, device: str = "auto") -> None:
         if device == "auto":
             # Automated device type detection
-            if torch.cuda.is_available():
-                self.device_type = "cuda"
-            elif is_neuron():
+            if is_neuron():
                 self.device_type = "neuron"
             elif is_cpu():
                 self.device_type = "cpu"
             else:
-                raise RuntimeError("No supported device detected.")
+                # We don't call torch.cuda.is_available() here to
+                # avoid initializing CUDA before workers are forked
+                self.device_type = "cuda"
         else:
             # Device type is assigned explicitly
             self.device_type = device
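
# Sketch of the lazy device selection above: "auto" resolves to neuron or cpu
# only when those platforms are detected, and otherwise assumes CUDA without
# calling torch.cuda.is_available(), so workers can still be forked safely.
from aphrodite.common.config import DeviceConfig

device_config = DeviceConfig("auto")
print(device_config.device_type)   # "cuda" on a typical GPU host
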
@@ -710,6 +711,7 @@ class DeviceConfig:
 
 class SpeculativeConfig:
     """Configuration for speculative decoding.
+
     The configuration is currently specialized to draft-model speculative
     decoding with top-1 proposals.
     """
@@ -724,13 +726,13 @@ class SpeculativeConfig:
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
-        ngram_prompt_lookup_max: Optional[int],
-        ngram_prompt_lookup_min: Optional[int],
     ) -> Optional["SpeculativeConfig"]:
         """Create a SpeculativeConfig if possible, else return None.
+
         This function attempts to create a SpeculativeConfig object based on the
         provided parameters. If the necessary conditions are met, it returns an
         instance of SpeculativeConfig. Otherwise, it returns None.
+
         Args:
             target_model_config (ModelConfig): The configuration of the target
                 model.
@@ -767,6 +769,7 @@ class SpeculativeConfig:
                 "Expected both speculative_model and "
                 "num_speculative_tokens to be provided, but found "
                 f"{speculative_model=} and {num_speculative_tokens=}.")
+
         assert (speculative_model is not None
                 and num_speculative_tokens is not None)
 
@@ -786,55 +789,39 @@ class SpeculativeConfig:
         draft_code_revision = None
         draft_quantization = None
 
-        if speculative_model == "[ngram]":
-            assert (ngram_prompt_lookup_max is not None
-                    and ngram_prompt_lookup_max > 0)
-            if ngram_prompt_lookup_min is None:
-                ngram_prompt_lookup_min = 0
-            else:
-                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
-            draft_model_config = target_model_config
-            draft_parallel_config = target_parallel_config
-        else:
-            ngram_prompt_lookup_max = 0
-            ngram_prompt_lookup_min = 0
-            draft_model_config = ModelConfig(
-                model=speculative_model,
-                download_dir=target_model_config.download_dir,
-                load_format=target_model_config.load_format,
-                tokenizer=target_model_config.tokenizer,
-                tokenizer_mode=target_model_config.tokenizer_mode,
-                trust_remote_code=target_model_config.trust_remote_code,
-                dtype=target_model_config.dtype,
-                seed=target_model_config.seed,
-                revision=draft_revision,
-                code_revision=draft_code_revision,
-                tokenizer_revision=target_model_config.tokenizer_revision,
-                max_model_len=None,
-                quantization=draft_quantization,
-                enforce_eager=target_model_config.enforce_eager,
-                max_context_len_to_capture=target_model_config.
-                max_context_len_to_capture,
-                max_log_probs=target_model_config.max_log_probs,
-            )
-
-            draft_model_config.max_model_len = (
-                SpeculativeConfig._maybe_override_draft_max_model_len(
-                    speculative_max_model_len,
-                    draft_model_config.max_model_len,
-                    target_model_config.max_model_len,
-                ))
-
-            draft_parallel_config = (
-                SpeculativeConfig.create_draft_parallel_config(
-                    target_parallel_config))
+        draft_model_config = ModelConfig(
+            model=speculative_model,
+            tokenizer=target_model_config.tokenizer,
+            tokenizer_mode=target_model_config.tokenizer_mode,
+            trust_remote_code=target_model_config.trust_remote_code,
+            dtype=target_model_config.dtype,
+            seed=target_model_config.seed,
+            revision=draft_revision,
+            code_revision=draft_code_revision,
+            tokenizer_revision=target_model_config.tokenizer_revision,
+            max_model_len=None,
+            quantization=draft_quantization,
+            enforce_eager=target_model_config.enforce_eager,
+            max_context_len_to_capture=target_model_config.
+            max_context_len_to_capture,
+            max_logprobs=target_model_config.max_logprobs,
+        )
+
+        draft_model_config.max_model_len = (
+            SpeculativeConfig._maybe_override_draft_max_model_len(
+                speculative_max_model_len,
+                draft_model_config.max_model_len,
+                target_model_config.max_model_len,
+            ))
+
+        draft_parallel_config = (
+            SpeculativeConfig.create_draft_parallel_config(
+                target_parallel_config))
 
         return SpeculativeConfig(
             draft_model_config,
             draft_parallel_config,
             num_speculative_tokens,
-            ngram_prompt_lookup_max,
-            ngram_prompt_lookup_min,
         )
 
     @staticmethod
@@ -847,8 +834,10 @@ class SpeculativeConfig:
         the draft_max_model_len, but may be the target_max_model_len if it is
         less than the draft_max_model_len, or may be speculative_max_model_len
         if it is specified.
+
         This is necessary so that sequences do not exceed the capacity of the
         draft model or the target model.
+
         speculative_max_model_len is mainly used for testing that sequences can
         skip speculation.
         """
@@ -874,6 +863,7 @@ class SpeculativeConfig:
     def create_draft_parallel_config(
             target_parallel_config: ParallelConfig) -> ParallelConfig:
         """Create a parallel config for use by the draft worker.
+
         This is mostly a copy of the target parallel config. In the future the
         draft worker can have a different parallel strategy, e.g. TP=1.
         """
@@ -899,10 +889,9 @@ class SpeculativeConfig:
         draft_model_config: ModelConfig,
         draft_parallel_config: ParallelConfig,
         num_speculative_tokens: int,
-        ngram_prompt_lookup_max: int,
-        ngram_prompt_lookup_min: int,
     ):
         """Create a SpeculativeConfig object.
+
         Args:
             draft_model_config: ModelConfig for the draft model.
             draft_parallel_config: ParallelConfig for the draft model.
@@ -912,8 +901,6 @@ class SpeculativeConfig:
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
         self.num_speculative_tokens = num_speculative_tokens
-        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
-        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
 
         self._verify_args()
 
@@ -930,16 +917,14 @@ class SpeculativeConfig:
     def num_lookahead_slots(self) -> int:
         """The number of additional slots the scheduler should allocate per
         step, in addition to the slots allocated for each known token.
+
         This is equal to the number of speculative tokens, as each speculative
         token must be scored.
         """
         return self.num_speculative_tokens
 
     def __repr__(self) -> str:
-        if self.ngram_prompt_lookup_max > 0:
-            draft_model = "[ngram]"
-        else:
-            draft_model = self.draft_model_config.model
+        draft_model = self.draft_model_config.model
         num_spec_tokens = self.num_speculative_tokens
         return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
 
@@ -980,9 +965,12 @@ class LoRAConfig:
             self.lora_dtype = model_config.dtype
         elif isinstance(self.lora_dtype, str):
             self.lora_dtype = getattr(torch, self.lora_dtype)
-        if (model_config.quantization is not None
-                and model_config.quantization == "gguf"):
-            raise ValueError("LoRA is not supported with GGUF quantization.")
+        if model_config.quantization and model_config.quantization not in [
+                "awq", "gptq"
+        ]:
+            # TODO support all other quants
+            logger.warning(f"{model_config.quantization} quantization is not "
+                           "tested with LoRA yet.")
 
     def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
         if scheduler_config.max_num_batched_tokens > 65528:
@@ -999,11 +987,14 @@ class VisionLanguageConfig:
 
     class ImageInputType(enum.Enum):
         """Image input type into the vision language model.
+
         An image roughly goes through the following transformation:
         Raw image --> pixel values --> image features --> image embeddings.
+
         The difference between different image input types is where the
         image encoder (pixel values --> image features) is run.
         Different image input types also correspond to different tensor shapes.
+
         For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
         IMAGE_FEATURES: (1, 576, 1024).
         """
@@ -1075,7 +1066,7 @@ def _get_and_verify_dtype(
             k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
             if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
         ]
-        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
+        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                          f"Supported dtypes are {rocm_supported_dtypes}")
 
     # Verify the dtype.
@@ -1110,16 +1101,20 @@ def _get_and_verify_max_len(
         "max_seq_len",
         # ChatGLM2
         "seq_length",
+        # Command-R
+        "model_max_length",
         # Others
         "max_sequence_length",
         "max_seq_length",
         "seq_len",
     ]
+    max_len_key = None
     for key in possible_keys:
-        max_len_key = getattr(hf_config, key, None)
-        if max_len_key is not None:
-            derived_max_model_len = min(derived_max_model_len, max_len_key)
-            break
+        max_len = getattr(hf_config, key, None)
+        if max_len is not None:
+            max_len_key = key if max_len < derived_max_model_len \
+                else max_len_key
+            derived_max_model_len = min(derived_max_model_len, max_len)
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
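
# Stand-alone sketch of the key scan above: every context-length-like field on
# the HF config is inspected and the smallest value wins. SimpleNamespace
# stands in for a real config; the key list here is a subset of possible_keys.
from types import SimpleNamespace

hf_config = SimpleNamespace(max_seq_len=32768, model_max_length=8192)
derived_max_model_len = float("inf")
for key in ("max_seq_len", "seq_length", "model_max_length"):
    max_len = getattr(hf_config, key, None)
    if max_len is not None:
        derived_max_model_len = min(derived_max_model_len, max_len)
assert derived_max_model_len == 8192
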
@@ -1167,7 +1162,7 @@ class DecodingConfig:
         valid_guided_backends = ['outlines', 'lm-format-enforcer']
         backend = self.guided_decoding_backend
         if backend not in valid_guided_backends:
-            raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
+            raise ValueError(f"Invalid guided_decoding_backend '{backend},"
                              f"must be one of {valid_guided_backends}")
 
 
@@ -1182,6 +1177,7 @@ class EngineConfig:
     parallel_config: ParallelConfig
     scheduler_config: SchedulerConfig
     device_config: DeviceConfig
+    load_config: LoadConfig
     lora_config: Optional[LoRAConfig]
     vision_language_config: Optional[VisionLanguageConfig]
     speculative_config: Optional[SpeculativeConfig]

+ 70 - 9
aphrodite/common/logger.py

@@ -2,24 +2,24 @@
 Internal logging utility.
 """
 
+import datetime
 import logging
 import os
+import sys
+from functools import partial
+from typing import Optional
 
 from loguru import logger
 from rich.console import Console
 from rich.markup import escape
-from rich.progress import (
-    Progress,
-    TextColumn,
-    BarColumn,
-    TimeRemainingColumn,
-    TaskProgressColumn,
-    MofNCompleteColumn,
-)
+from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
+                           TaskProgressColumn, TextColumn, TimeRemainingColumn)
 
 RICH_CONSOLE = Console()
 LOG_LEVEL = os.getenv("APHRODITE_LOG_LEVEL", "INFO").upper()
 
+APHRODITE_CONFIGURE_LOGGING = int(os.getenv("APHRODITE_CONFIGURE_LOGGING", "1"))
+
 
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
@@ -130,4 +130,65 @@ def setup_logger():
     logger.log_once = log_once
 
 
-setup_logger()
+setup_logger()
+
+
+def _trace_calls(log_path, root_dir, frame, event, arg=None):
+    if event in ['call', 'return']:
+        # Extract the filename, line number, function name, and the code object
+        filename = frame.f_code.co_filename
+        lineno = frame.f_lineno
+        func_name = frame.f_code.co_name
+        if not filename.startswith(root_dir):
+            # only log the functions in the aphrodite root_dir
+            return
+        # Log every function call or return
+        try:
+            last_frame = frame.f_back
+            if last_frame is not None:
+                last_filename = last_frame.f_code.co_filename
+                last_lineno = last_frame.f_lineno
+                last_func_name = last_frame.f_code.co_name
+            else:
+                # initial frame
+                last_filename = ""
+                last_lineno = 0
+                last_func_name = ""
+            with open(log_path, 'a') as f:
+                if event == 'call':
+                    f.write(f"{datetime.datetime.now()} Call to"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" from {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+                else:
+                    f.write(f"{datetime.datetime.now()} Return from"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" to {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+        except NameError:
+            # modules are deleted during shutdown
+            pass
+    return partial(_trace_calls, log_path, root_dir)
+
+
+def enable_trace_function_call(log_file_path: str,
+                               root_dir: Optional[str] = None):
+    """
+    Enable tracing of every function call in code under `root_dir`.
+    This is useful for debugging hangs or crashes.
+    `log_file_path` is the path to the log file.
+    `root_dir` is the root directory of the code to trace. If None, it is the
+    aphrodite root directory.
+
+    Note that tracing is enabled per thread: only threads that call this
+    function have the trace enabled; other threads are not affected.
+    """
+    logger.warning(
+        "APHRODITE_TRACE_FUNCTION is enabled. It will record every"
+        " function executed by Python. This will slow down the code. It "
+        "is suggested to be used for debugging hang or crashes only.")
+    logger.info(f"Trace frame log is saved to {log_file_path}")
+    if root_dir is None:
+        # by default, this is the aphrodite root directory
+        root_dir = os.path.dirname(os.path.dirname(__file__))
+    sys.settrace(partial(_trace_calls, log_file_path, root_dir))
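
# Usage sketch: tracing every call under the aphrodite tree is very slow, so
# gate it behind an explicit opt-in. APHRODITE_TRACE_FUNCTION is the variable
# named in the warning above; how the engine itself reads it is not shown
# here, and the log path is a placeholder.
import os
from aphrodite.common.logger import enable_trace_function_call

if os.getenv("APHRODITE_TRACE_FUNCTION"):
    enable_trace_function_call("/tmp/aphrodite_trace.log")
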

+ 151 - 31
aphrodite/common/utils.py

@@ -1,15 +1,18 @@
 import asyncio
 import enum
 import gc
+import glob
 import os
 import socket
 import subprocess
 import uuid
-from collections import OrderedDict, defaultdict
+import warnings
+from collections import defaultdict
 from functools import lru_cache, partial
 from platform import uname
 from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Optional, Tuple, TypeVar, Union)
+                    Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
+                    Union)
 
 import psutil
 import torch
@@ -48,7 +51,7 @@ class Counter:
 class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
-        self.cache = OrderedDict[Hashable, T]()
+        self.cache: OrderedDict[Hashable, T] = OrderedDict()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
@@ -57,7 +60,7 @@ class LRUCache(Generic[T]):
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: Hashable) -> Optional[T]:
         return self.get(key)
 
     def __setitem__(self, key: Hashable, value: T) -> None:
@@ -73,7 +76,7 @@ class LRUCache(Generic[T]):
             key: Hashable,
             default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
-            value = self.cache[key]
+            value: Optional[T] = self.cache[key]
             self.cache.move_to_end(key)
         else:
             value = default_value
@@ -84,7 +87,7 @@ class LRUCache(Generic[T]):
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: T):
+    def _on_remove(self, key: Hashable, value: Optional[T]):
         pass
 
     def remove_oldest(self):
@@ -97,9 +100,11 @@ class LRUCache(Generic[T]):
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
+    def pop(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
-        value = self.cache.pop(key, default_value)
+        value: Optional[T] = self.cache.pop(key, default_value)
         if run_on_remove:
             self._on_remove(key, value)
         return value
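
# Usage sketch for the LRUCache above: dict-style access with least-recently-
# used eviction once capacity is exceeded; get() and pop() return None (or the
# given default) for missing keys.
from aphrodite.common.utils import LRUCache

cache: LRUCache[str] = LRUCache(capacity=2)
cache["a"] = "x"
cache["b"] = "y"
cache["c"] = "z"                 # evicts "a", the least recently used entry
assert cache.get("a") is None
assert cache.pop("b") == "y" and len(cache) == 1
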
@@ -156,6 +161,18 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
+@lru_cache(maxsize=None)
+def get_aphrodite_instance_id():
+    """
+    If the environment variable APHRODITE_INSTANCE_ID is set, return it.
+    Otherwise, return a random UUID.
+    The instance id represents a single Aphrodite instance. All processes in
+    the same instance should have the same instance id.
+    """
+    return os.environ.get("APHRODITE_INSTANCE_ID",
+                          f"aphrodite-instance-{random_uuid()}")
+
+
 @lru_cache(maxsize=None)
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
@@ -181,6 +198,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
 def merge_async_iterators(
         *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
     """Merge multiple asynchronous iterators into a single iterator.
+
     This method handles the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
@@ -203,34 +221,46 @@ def merge_async_iterators(
     ]
 
     async def consumer():
-        try:
-            while not all(finished) or not queue.empty():
-                item = await queue.get()
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-        except (Exception, asyncio.CancelledError) as e:
-            for task in _tasks:
-                # NOTE: Pass the error msg in cancel()
-                # when only Python 3.9+ is supported.
-                task.cancel()
-            raise e
+        while not all(finished) or not queue.empty():
+            item = await queue.get()
+            if isinstance(item, Exception):
+                raise item
+            yield item
         await asyncio.gather(*_tasks)
 
     return consumer()
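
# Usage sketch: two toy async generators merged into a single stream of
# (iterator_index, item) tuples, as described in the docstring above.
import asyncio
from aphrodite.common.utils import merge_async_iterators

async def numbers(prefix: str):
    for i in range(2):
        await asyncio.sleep(0)
        yield f"{prefix}{i}"

async def main():
    async for index, item in merge_async_iterators(numbers("a"), numbers("b")):
        print(index, item)

asyncio.run(main())
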
 
 
 def get_ip() -> str:
+    host_ip = os.environ.get("HOST_IP")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
         s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
-    except OSError:
-        # try ipv6
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
         s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        s.connect(("dns.google", 80))
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default."
+        "The value can be set by the environment variable HOST_IP.",
+        stacklevel=2)
+    return "0.0.0.0"
 
 
 def get_distributed_init_method(ip: str, port: int) -> str:
@@ -252,8 +282,12 @@ def get_open_port() -> int:
             return s.getsockname()[1]
 
 
-def set_cuda_visible_devices(device_ids: List[int]) -> None:
-    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+def update_environment_variables(envs: Dict[str, str]):
+    for k, v in envs.items():
+        if k in os.environ and os.environ[k] != v:
+            logger.warning(f"Overwriting environment variable {k} "
+                           f"from '{os.environ[k]}' to '{v}'")
+        os.environ[k] = v
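
# Usage sketch: this replaces the removed set_cuda_visible_devices helper;
# overwriting an existing variable is allowed but logged.
from aphrodite.common.utils import update_environment_variables

update_environment_variables({"CUDA_VISIBLE_DEVICES": "0,1"})
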
 
 
 def chunk_list(lst, chunk_size):
@@ -272,9 +306,8 @@ def get_nvcc_cuda_version() -> Optional[Version]:
     if not cuda_home:
         cuda_home = '/usr/local/cuda'
         if os.path.isfile(cuda_home + '/bin/nvcc'):
-            logger.info(
-                f'CUDA_HOME is not found in the environment. Using {cuda_home} '
-                'as CUDA_HOME.')
+            logger.info(f'CUDA_HOME is not found in the environment. '
+                        f'Using {cuda_home} as CUDA_HOME.')
         else:
             logger.warning(
                 f'Not found nvcc in {cuda_home}. Skip cuda version check!')
@@ -293,7 +326,7 @@ def _generate_random_fp8(
     high: float,
 ) -> None:
     # NOTE: Due to NaN and Inf representation for fp8 data type,
-    # we may get Inf or NaN if we directly use torch.randint
+    # Inf or NaN may occur if we directly use torch.randint
     # to generate random data for fp8 data.
     # For example, s.11111.00 in fp8e5m2 format represents Inf.
     #     | E4M3        | E5M2
@@ -315,7 +348,7 @@ def create_kv_caches_with_random(
     head_size: int,
     cache_dtype: Optional[Union[str, torch.dtype]],
     model_dtype: Optional[Union[str, torch.dtype]] = None,
-    seed: Optional[int] = 0,
+    seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
@@ -421,7 +454,7 @@ class CudaMemoryProfiler:
         gc.collect()
 
 
-def str_to_int_tuple(s: str) -> Tuple[int]:
+def str_to_int_tuple(s: str) -> Tuple[int, ...]:
     """Convert a string to a tuple of integers."""
     try:
         return tuple(map(int, s.split(",")))
@@ -444,6 +477,7 @@ def make_tensor_with_pad(
     device: Optional[Union[str, torch.device]],
 ) -> torch.Tensor:
     """Make a padded tensor of a 2D inputs.
+
     The padding is applied to the end of each inner list until it reaches
     `max_len`.
     """
@@ -486,3 +520,89 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
         merged_dict[key].extend(value)
 
     return dict(merged_dict)
+
+
+def init_cached_hf_modules():
+    """
+    Lazy initialization of the Hugging Face modules.
+    """
+    from transformers.dynamic_module_utils import init_hf_modules
+    init_hf_modules()
+
+
+def nccl_integrity_check(filepath):
+    """
+    When the library is corrupted, we cannot catch the exception in Python;
+    it will crash the process. Instead, we use the exit code of `ldd` to
+    check whether the library is corrupted. If it is not, we return the
+    version of the library.
+    """
+    exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
+    if exit_code != 0:
+        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
+    import ctypes
+
+    nccl = ctypes.CDLL(filepath)
+    version = ctypes.c_int()
+    nccl.ncclGetVersion.restype = ctypes.c_int
+    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
+    result = nccl.ncclGetVersion(ctypes.byref(version))
+    assert result == 0
+    return version.value
+
+
+@lru_cache(maxsize=None)
+def find_library(lib_name: str) -> str:
+    """
+    Find the library file in the system.
+    `lib_name` is the full filename, with both prefix and suffix.
+    This function resolves `lib_name` to the full path of the library.
+    """
+    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
+    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
+    # `/sbin/ldconfig` should exist in all Linux systems.
+    # `/sbin/ldconfig` searches the library in the system
+    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
+    # each line looks like the following:
+    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
+    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
+    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
+    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
+    if not locs and env_ld_library_path:
+        locs = [
+            os.path.join(dir, lib_name)
+            for dir in env_ld_library_path.split(":")
+            if os.path.exists(os.path.join(dir, lib_name))
+        ]
+    if not locs:
+        raise ValueError(f"Cannot find {lib_name} in the system.")
+    return locs[0]
+
+
+def find_nccl_library():
+    so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
+
+    aphrodite_nccl_path = None
+    if torch.version.cuda is not None:
+        cuda_major = torch.version.cuda.split(".")[0]
+        path = os.path.expanduser(
+            f"~/.config/aphrodite/nccl/cu{cuda_major}/libnccl.so.*")
+        files = glob.glob(path)
+        aphrodite_nccl_path = files[0] if files else None
+
+    # manually load the nccl library
+    if so_file:
+        logger.info(
+            "Found nccl from environment variable "
+            f"APHRODITE_NCCL_SO_PATH={so_file}"
+        )
+    else:
+        if torch.version.cuda is not None:
+            so_file = aphrodite_nccl_path or find_library("libnccl.so.2")
+        elif torch.version.hip is not None:
+            so_file = find_library("librccl.so.1")
+        else:
+            raise ValueError("NCCL only supports CUDA and ROCm backends.")
+        logger.info(f"Found nccl from library {so_file}")
+    return so_file
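
# Usage sketch: resolve the NCCL (or RCCL on ROCm) shared object that will be
# loaded, optionally forced through APHRODITE_NCCL_SO_PATH. Requires a CUDA or
# ROCm build of torch; otherwise the function raises.
from aphrodite.common.utils import find_nccl_library

print(find_nccl_library())
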

+ 62 - 55
aphrodite/engine/aphrodite_engine.py

@@ -6,9 +6,9 @@ from transformers import GenerationConfig, PreTrainedTokenizer
 
 import aphrodite
 from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                                     LoRAConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, SpeculativeConfig,
-                                     VisionLanguageConfig)
+                                     LoadConfig, LoRAConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     SpeculativeConfig, VisionLanguageConfig)
 from aphrodite.common.logger import setup_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
@@ -84,6 +84,7 @@ class AphroditeEngine:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
+        load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
@@ -97,15 +98,14 @@ class AphroditeEngine:
             f"Model = {model_config.model!r}\n"
             f"Speculative Config = {speculative_config!r}\n"
             f"DataType = {model_config.dtype}\n"
-            f"Model Load Format = {model_config.load_format}\n"
+            f"Model Load Format = {load_config.load_format}\n"
             f"Number of GPUs = {parallel_config.tensor_parallel_size}\n"
             f"Disable Custom All-Reduce = "
             f"{parallel_config.disable_custom_all_reduce}\n"
             f"Quantization Format = {model_config.quantization}\n"
             f"Context Length = {model_config.max_model_len}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
-            f"KV Cache Data Type = {cache_config.cache_dtype}\n"
-            f"KV Cache Params Path = {model_config.quantization_param_path}\n"
+            f"KV Cache DataType = {cache_config.cache_dtype}\n"
             f"Device = {device_config.device}\n"
             f"Guided Decoding Backend = {decoding_config!r}\n")
         # TODO: Print more configs in debug mode.
@@ -118,12 +118,18 @@ class AphroditeEngine:
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.speculative_config = speculative_config
+        self.load_config = load_config
         self.decoding_config = decoding_config or DecodingConfig()
         self.log_stats = log_stats
-        self._verify_args()
 
-        self._init_tokenizer()
-        self.detokenizer = Detokenizer(self.tokenizer)
+        if not self.model_config.skip_tokenizer_init:
+            self.tokenizer: BaseTokenizerGroup
+            self._init_tokenizer()
+            self.detokenizer = Detokenizer(self.tokenizer)
+        else:
+            self.detokenizer = None
+            self.tokenizer = None
+
         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
             model_config)
@@ -137,13 +143,16 @@ class AphroditeEngine:
             lora_config=lora_config,
             vision_language_config=vision_language_config,
             speculative_config=speculative_config,
+            load_config=load_config,
         )
 
         self._initialize_kv_caches()
 
-        # Ping the tokenizer to ensure it is loaded if
-        # it runs on a separate process.
-        self.tokenizer.ping()
+        if self.tokenizer:
+            # Ping the tokenizer to ensure liveness if it runs in a
+            # different process.
+            self.tokenizer.ping()
 
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
@@ -154,8 +163,7 @@ class AphroditeEngine:
         if self.log_stats:
             self.stat_logger = StatLogger(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
-                labels=dict(model_name=model_config.model),
-            )
+                labels=dict(model_name=model_config.model))
             self.stat_logger.info("cache_config", self.cache_config)
 
         # Create sequence output processor, e.g. for beam search or
@@ -175,6 +183,7 @@ class AphroditeEngine:
 
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
+
         The workers will determine the number of blocks in both the GPU cache
         and the swap CPU cache.
         """
@@ -193,7 +202,10 @@ class AphroditeEngine:
         self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
     @classmethod
-    def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
+    def from_engine_args(
+        cls,
+        engine_args: EngineArgs,
+    ) -> "AphroditeEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
@@ -216,9 +228,11 @@ class AphroditeEngine:
             executor_class = GPUExecutor
 
         # Create the LLM engine.
-        engine = cls(**engine_config.to_dict(),
-                     executor_class=executor_class,
-                     log_stats=not engine_args.disable_log_stats)
+        engine = cls(
+            **engine_config.to_dict(),
+            executor_class=executor_class,
+            log_stats=not engine_args.disable_log_stats,
+        )
         return engine
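
# Usage sketch for the classmethod above; the model name is a placeholder and
# building the engine downloads weights and requires a supported device.
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m", enforce_eager=True)
engine = AphroditeEngine.from_engine_args(engine_args)
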
 
     def __reduce__(self):
@@ -241,19 +255,11 @@ class AphroditeEngine:
             max_input_length=None,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code,
-            revision=self.model_config.revision)
+            revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
-        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
+        self.tokenizer = get_tokenizer_group(
             self.parallel_config.tokenizer_pool_config, **init_kwargs)
 
-        if len(self.get_tokenizer()) != self.model_config.get_vocab_size():
-            logger.warning(
-                f"The tokenizer's vocabulary size {len(self.get_tokenizer())}"
-                f" does not match the model's vocabulary size "
-                f"{self.model_config.get_vocab_size()}.")
-        self.model_config.hf_config.tokenizer_vocab_size = len(
-            self.get_tokenizer())
-
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -264,7 +270,7 @@ class AphroditeEngine:
 
     def encode_request(
         self,
-        request_id: str,
+        request_id: str,  # pylint: disable=unused-argument
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -301,7 +307,7 @@ class AphroditeEngine:
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
-            multi_modal_data: The multimodal data for the request.
+            multi_modal_data: Multi-modal data for the request.
 
         Details:
             - Set arrival_time to the current time if it is None.
@@ -330,41 +336,38 @@ class AphroditeEngine:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
-        max_log_probs = self.get_model_config().max_log_probs
+        max_logprobs = self.get_model_config().max_logprobs
         if (sampling_params.logprobs
-                and sampling_params.logprobs > max_log_probs) or (
+                and sampling_params.logprobs > max_logprobs) or (
                     sampling_params.prompt_logprobs
-                    and sampling_params.prompt_logprobs > max_log_probs):
+                    and sampling_params.prompt_logprobs > max_logprobs):
             raise ValueError(f"Cannot request more than "
-                             f"{max_log_probs} logprobs. "
-                             "Please increase the max_log_probs.")
+                             f"{max_logprobs} logprobs.")
         if arrival_time is None:
-            arrival_time = time.monotonic()
+            arrival_time = time.time()
         prompt_token_ids = self.encode_request(
             request_id=request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request,
-        )
+            lora_request=lora_request)
 
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
-        eos_token_id = self.tokenizer.get_lora_tokenizer(
-            lora_request).eos_token_id
-        seq = Sequence(
-            seq_id,
-            prompt,
-            prompt_token_ids,
-            block_size,
-            eos_token_id,
-            lora_request,
-        )
+        eos_token_id = None
+        if self.tokenizer:
+            eos_token_id = self.tokenizer.get_lora_tokenizer(
+                lora_request).eos_token_id
+        else:
+            logger.warning("Use None for EOS token id because tokenizer is "
+                           "not initialized")
+        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+                       eos_token_id, lora_request)
 
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
-        # Inject the eos token id into the sampling_params to support min_tokens
+        # inject the eos token id into the sampling_params to support min_tokens
         # processing
         sampling_params.eos_token_id = seq.eos_token_id
         sampling_params.update_from_generation_config(
@@ -413,20 +416,24 @@ class AphroditeEngine:
             scheduled_seq_groups: List[SequenceGroup],
             ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.
-        
+
         Returns RequestOutputs that can be returned to the client.
         """
+
         now = time.time()
+
         # Organize outputs by [sequence group][step] instead of
         # [step][sequence group].
         output_by_sequence_group = create_output_by_sequence_group(
             sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
+
         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
                                                 output_by_sequence_group):
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)
+
             # If all sequences in the sequence group are in DECODE, then we can
             # process the output tokens. Otherwise, they are (chunked) prefill
             # samples and should not be processed.
@@ -447,7 +454,6 @@ class AphroditeEngine:
         for seq_group in ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
-
         return request_outputs
 
     def step(self) -> List[RequestOutput]:
@@ -534,13 +540,14 @@ class AphroditeEngine:
             scheduler_outputs: Optional[SchedulerOutputs],
             model_output: Optional[List[SamplerOutput]] = None) -> Stats:
         """Get Stats to be Logged to Prometheus.
+
         Args:
             scheduler_outputs: Optional, used to populate metrics related to
                 the scheduled batch,
             model_output: Optional, used to emit speculative decoding metrics
                 which are created by the workers.
         """
-        now = time.monotonic()
+        now = time.time()
 
         # KV Cache Usage in %.
         num_total_gpu = self.cache_config.num_gpu_blocks
@@ -548,10 +555,10 @@ class AphroditeEngine:
         gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
 
         num_total_cpu = self.cache_config.num_cpu_blocks
-        cpu_cache_usage = 0.0
+        cpu_cache_usage = 0.
         if num_total_cpu > 0:
-            num_free_cpu = (
-                self.scheduler.block_manager.get_num_free_cpu_blocks())
+            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
+            )
             cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
 
         # Scheduler State

+ 79 - 58
aphrodite/engine/args_tools.py

@@ -4,10 +4,10 @@ from dataclasses import dataclass
 from typing import Optional
 
 from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                                     EngineConfig, LoRAConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     SpeculativeConfig, TokenizerPoolConfig,
-                                     VisionLanguageConfig)
+                                     EngineConfig, LoadConfig, LoRAConfig,
+                                     ModelConfig, ParallelConfig,
+                                     SchedulerConfig, SpeculativeConfig,
+                                     TokenizerPoolConfig, VisionLanguageConfig)
 from aphrodite.common.utils import str_to_int_tuple
 from aphrodite.quantization import QUANTIZATION_METHODS
 
@@ -18,6 +18,7 @@ class EngineArgs:
 
     model: str
     tokenizer: Optional[str] = None
+    skip_tokenizer_init: bool = False
     tokenizer_mode: str = "auto"
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
@@ -32,13 +33,13 @@ class EngineArgs:
     tensor_parallel_size: int = 1
     max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
-    context_shift: bool = False
+    enable_prefix_caching: bool = False
     use_v2_block_manager: bool = False
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
-    max_log_probs: int = 10  # OpenAI default is 5, setting to 10 because ST
+    max_logprobs: int = 10  # OpenAI default is 5, setting to 10 because ST
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
@@ -63,6 +64,7 @@ class EngineArgs:
     ray_workers_use_nsight: bool = False
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
+    model_loader_extra_config: Optional[dict] = None
     # Related to Vision-language models such as llava
     image_input_type: Optional[str] = None
     image_token_id: Optional[int] = None
@@ -75,8 +77,6 @@ class EngineArgs:
     speculative_model: Optional[str] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
-    ngram_prompt_lookup_max: Optional[int] = None
-    ngram_prompt_lookup_min: Optional[int] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -103,6 +103,10 @@ class EngineArgs:
             default=EngineArgs.tokenizer,
             help="name or path of the huggingface tokenizer to use",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="Skip initialization of tokenizer and detokenizer")
         parser.add_argument(
             "--revision",
             type=str,
@@ -150,33 +154,40 @@ class EngineArgs:
             "huggingface",
         )
         parser.add_argument(
-            "--load-format",
+            '--load-format',
             type=str,
             default=EngineArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
-            help="The format of the model weights to load. "
-            '"auto" will try to load the weights in the safetensors format '
-            "and fall back to the pytorch bin format if safetensors format "
-            "is not available. "
-            '"pt" will load the weights in the pytorch bin format. '
-            '"safetensors" will load the weights in the safetensors format. '
-            '"npcache" will load the weights in pytorch format and store '
-            "a numpy cache to speed up the loading. "
-            '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
-        )
-        parser.add_argument(
-            "--dtype",
+            choices=[
+                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
+            ],
+            help='The format of the model weights to load.\n\n'
+            '* "auto" will try to load the weights in the safetensors format '
+            'and fall back to the pytorch bin format if safetensors format '
+            'is not available.\n'
+            '* "pt" will load the weights in the pytorch bin format.\n'
+            '* "safetensors" will load the weights in the safetensors format.\n'
+            '* "npcache" will load the weights in pytorch format and store '
+            'a numpy cache to speed up the loading.\n'
+            '* "dummy" will initialize the weights with random values, '
+            'which is mainly for profiling.\n'
+            '* "tensorizer" will load the weights using tensorizer from '
+            'CoreWeave which assumes tensorizer_uri is set to the location of '
+            'the serialized weights.')
+        parser.add_argument(
+            '--dtype',
             type=str,
             default=EngineArgs.dtype,
             choices=[
-                "auto", "half", "float16", "bfloat16", "float", "float32"
+                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
             ],
-            help="data type for model weights and activations. "
-            'The "auto" option will use FP16 precision '
-            "for FP32 and FP16 models, and BF16 precision "
-            "for BF16 models.",
-        )
+            help='Data type for model weights and activations.\n\n'
+            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+            'BF16 precision for BF16 models.\n'
+            '* "half" for FP16. Recommended for AWQ quantization.\n'
+            '* "float16" is the same as "half".\n'
+            '* "bfloat16" for a balance between precision and range.\n'
+            '* "float" is shorthand for FP32 precision.\n'
+            '* "float32" for FP32 precision.')
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
@@ -210,7 +221,11 @@ class EngineArgs:
             default='outlines',
             choices=['outlines', 'lm-format-enforcer'],
             help='Which engine will be used for guided decoding'
-            ' (JSON schema / regex etc)')
+            ' (JSON schema / regex etc) by default. Currently supports '
+            'https://github.com/outlines-dev/outlines and '
+            'https://github.com/noamgat/lm-format-enforcer.'
+            ' Can be overridden per request via guided_decoding_backend'
+            ' parameter.')
         # Parallel arguments
         parser.add_argument(
             "--worker-use-ray",
@@ -223,15 +238,14 @@ class EngineArgs:
             "-pp",
             type=int,
             default=EngineArgs.pipeline_parallel_size,
-            help="number of pipeline stages",
-        )
+            help="number of pipeline stages. Currently not supported.")
         parser.add_argument(
             "--tensor-parallel-size",
             "-tp",
             type=int,
             default=EngineArgs.tensor_parallel_size,
-            help="number of tensor parallel replicas",
-        )
+            help="number of tensor parallel replicas, i.e. the number of GPUs "
+            "to use.")
         parser.add_argument(
             "--max-parallel-loading-workers",
             type=int,
@@ -254,6 +268,7 @@ class EngineArgs:
             help="token block size",
         )
         parser.add_argument(
+            "--enable-prefix-caching",
             "--context-shift",
             action="store_true",
             help="Enable context shifting.",
@@ -308,9 +323,9 @@ class EngineArgs:
             help="maximum number of sequences per iteration",
         )
         parser.add_argument(
-            "--max-log-probs",
+            "--max-logprobs",
             type=int,
-            default=EngineArgs.max_log_probs,
+            default=EngineArgs.max_logprobs,
             help="maximum number of log probabilities to "
             "return.",
         )
@@ -500,19 +515,15 @@ class EngineArgs:
             help="The maximum sequence length supported by the "
             "draft model. Sequences over this length will skip "
             "speculation.")
-        parser.add_argument(
-            "--ngram-prompt-lookup-max",
-            type=int,
-            default=None,
-            help='Max size of window for ngram prompt lookup in speculative '
-            'decoding.')
+        parser.add_argument("--model-loader-extra-config",
+                            type=str,
+                            default=EngineArgs.model_loader_extra_config,
+                            help="Extra config for model loader. "
+                            "This will be passed to the model loader "
+                            "corresponding to the chosen load_format. "
+                            "This should be a JSON string that will be "
+                            "parsed into a dictionary.")
 
-        parser.add_argument(
-            '--ngram-prompt-lookup-min',
-            type=int,
-            default=None,
-            help='Min size of window for ngram prompt lookup in speculative '
-            'decoding.')
         return parser
 
     @classmethod
@@ -530,8 +541,6 @@ class EngineArgs:
             self.tokenizer,
             self.tokenizer_mode,
             self.trust_remote_code,
-            self.download_dir,
-            self.load_format,
             self.dtype,
             self.seed,
             self.revision,
@@ -545,8 +554,10 @@ class EngineArgs:
             self.quantization_param_path,
             self.enforce_eager,
             self.max_context_len_to_capture,
-            self.max_log_probs,
+            self.max_logprobs,
+            self.skip_tokenizer_init,
         )
+
         cache_config = CacheConfig(
             self.block_size,
             self.gpu_memory_utilization,
@@ -555,8 +566,9 @@ class EngineArgs:
             # self.kv_quant_params_path,
             self.num_gpu_blocks_override,
             model_config.get_sliding_window(),
-            self.context_shift,
+            self.enable_prefix_caching,
         )
+
         parallel_config = ParallelConfig(
             self.pipeline_parallel_size,
             self.tensor_parallel_size,
@@ -570,6 +582,7 @@ class EngineArgs:
             ),
             self.ray_workers_use_nsight,
         )
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
@@ -579,9 +592,8 @@ class EngineArgs:
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
-            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
-            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
         )
+
         scheduler_config = SchedulerConfig(
             self.max_num_batched_tokens,
             self.max_num_seqs,
@@ -593,14 +605,21 @@ class EngineArgs:
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
         )
-        lora_config = (LoRAConfig(
+
+        lora_config = LoRAConfig(
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
             lora_dtype=self.lora_dtype,
-            max_cpu_loras=self.max_cpu_loras
-            if self.max_cpu_loras and self.max_cpu_loras > 0 else None,
-        ) if self.enable_lora else None)
+            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
+            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
+        
+        load_config = LoadConfig(
+            load_format=self.load_format,
+            download_dir=self.download_dir,
+            model_loader_extra_config=self.model_loader_extra_config,
+        )
+
         if self.image_input_type:
             if (not self.image_token_id or not self.image_input_shape
                     or not self.image_feature_size):
@@ -619,6 +638,7 @@ class EngineArgs:
 
         decoding_config = DecodingConfig(
             guided_decoding_backend=self.guided_decoding_backend)
+
         return EngineConfig(model_config=model_config,
                             cache_config=cache_config,
                             parallel_config=parallel_config,
@@ -627,6 +647,7 @@ class EngineArgs:
                             lora_config=lora_config,
                             vision_language_config=vision_language_config,
                             speculative_config=speculative_config,
+                            load_config=load_config,
                             decoding_config=decoding_config)
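
For reference, the new loader-related flags all funnel into the LoadConfig built above. A minimal sketch of that wiring, mirroring create_engine_config; the flag values are illustrative and the load format name is an assumption, while the JSON string is handed through as-is per the --model-loader-extra-config help text:

    # Roughly equivalent command line (illustrative values):
    #   --load-format safetensors --download-dir /tmp/models \
    #   --model-loader-extra-config '{"some_option": true}'
    from aphrodite.common.config import LoadConfig

    load_config = LoadConfig(
        load_format="safetensors",                          # assumed format name
        download_dir="/tmp/models",
        model_loader_extra_config='{"some_option": true}',  # JSON string
    )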
 
 

+ 77 - 69
aphrodite/engine/async_aphrodite.py

@@ -2,22 +2,23 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
+                    Optional, Set, Tuple, Type, Union)
+
 from loguru import logger
 from transformers import PreTrainedTokenizer
 
-from aphrodite.lora.request import LoRARequest
 from aphrodite.common.config import ModelConfig
-from aphrodite.engine.args_tools import AsyncEngineArgs
-from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.engine.ray_tools import initialize_ray_cluster, ray
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import MultiModalData
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.aphrodite_engine import AphroditeEngine
+from aphrodite.engine.ray_tools import initialize_ray_cluster, ray
+from aphrodite.lora.request import LoRARequest
 
 ENGINE_ITERATION_TIMEOUT_S = int(
-    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "120"))
+    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "60"))
 
 
 class AsyncEngineDeadError(RuntimeError):
@@ -28,8 +29,7 @@ def _raise_exception_on_finish(
         task: asyncio.Task, error_callback: Callable[[Exception],
                                                      None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
-           "Please open an issue on Github. Include your full error "
-           "log after killing the process with Ctrl+C.")
+           "Please open an issue on Github.")
 
     exception = None
     try:
@@ -54,7 +54,7 @@ class AsyncStream:
 
     def __init__(self, request_id: str) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
+        self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
 
     def put(self, item: Union[RequestOutput, Exception]) -> None:
@@ -322,29 +322,33 @@ class AsyncAphrodite:
         self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
-        self.background_loop = None
+        self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
         # task as well to prevent it from being garbage
         # collected
-        self._background_loop_unshielded = None
+        self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None
 
+        # Lazy initialized fields
+        self._request_tracker: RequestTracker
+
     @classmethod
-    def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = True) -> "AsyncAphrodite":
+    def from_engine_args(
+        cls,
+        engine_args: AsyncEngineArgs,
+        start_engine_loop: bool = True,
+    ) -> "AsyncAphrodite":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
 
         if engine_config.device_config.device_type == "neuron":
-            raise NotImplementedError("Neuron is not supported for "
-                                      "async engine yet.")
+            from aphrodite.executor.neuron_executor import NeuronExecutorAsync
+            executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "cpu":
-            from aphrodite.executor.cpu_executor import CPUExecutor
-            executor_class = CPUExecutor
+            from aphrodite.executor.cpu_executor import CPUExecutorAsync
+            executor_class = CPUExecutorAsync
         elif engine_config.parallel_config.worker_use_ray:
             initialize_ray_cluster(engine_config.parallel_config)
             from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
@@ -355,24 +359,28 @@ class AsyncAphrodite:
             from aphrodite.executor.gpu_executor import GPUExecutorAsync
             executor_class = GPUExecutorAsync
         # Create the async LLM engine.
-        engine = cls(engine_config.parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     **engine_config.to_dict(),
-                     executor_class=executor_class,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
+        engine = cls(
+            engine_config.parallel_config.worker_use_ray,
+            engine_args.engine_use_ray,
+            **engine_config.to_dict(),
+            executor_class=executor_class,
+            log_requests=not engine_args.disable_log_requests,
+            log_stats=not engine_args.disable_log_stats,
+            max_log_len=engine_args.max_log_len,
+            start_engine_loop=start_engine_loop,
+        )
         return engine
 
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
+                and self._background_loop_unshielded is not None
                 and not self._background_loop_unshielded.done())
 
     @property
     def is_stopped(self) -> bool:
-        return self.errored or (self.background_loop is not None
+        return self.errored or (self.background_loop is not None and
+                                self._background_loop_unshielded is not None
                                 and self._background_loop_unshielded.done())
 
     @property
@@ -388,7 +396,7 @@ class AsyncAphrodite:
 
     async def get_tokenizer(self) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()
+            return await self.engine.get_tokenizer.remote()  # type: ignore
         else:
             return self.engine.get_tokenizer()
 
@@ -441,7 +449,8 @@ class AsyncAphrodite:
             # TODO: Maybe add add_request_batch to reduce Ray overhead
             try:
                 if self.engine_use_ray:
-                    await self.engine.add_request.remote(**new_request)
+                    await self.engine.add_request.remote(  # type: ignore
+                        **new_request)
                 else:
                     await self.engine.add_request_async(**new_request)
             except ValueError as e:
@@ -456,7 +465,7 @@ class AsyncAphrodite:
             await self._engine_abort(finished_requests)
 
         if self.engine_use_ray:
-            request_outputs = await self.engine.step.remote()
+            request_outputs = await self.engine.step.remote()  # type: ignore
         else:
             request_outputs = await self.engine.step_async()
 
@@ -469,33 +478,29 @@ class AsyncAphrodite:
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_ids)
+            await self.engine.abort_request.remote(request_ids)  # type: ignore
         else:
             self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
         has_requests_in_progress = False
-        try:
-            while True:
-                if not has_requests_in_progress:
-                    logger.debug("Waiting for new requests...")
-                    await self._request_tracker.wait_for_new_requests()
-                    logger.debug("Got new requests!")
-
-                # Abort if iteration takes too long due to unrecoverable errors
-                # (eg. NCCL timeouts).
-                try:
-                    has_requests_in_progress = await asyncio.wait_for(
-                        self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
-                except asyncio.TimeoutError as exc:
-                    logger.error(
-                        "Engine iteration timed out. This should never happen!"
-                    )
-                    self.set_errored(exc)
-                    raise
-                await asyncio.sleep(0)
-        except KeyboardInterrupt:
-            logger.info("Engine loop interrupted. Exiting gracefully.")
+        while True:
+            if not has_requests_in_progress:
+                logger.debug("Waiting for new requests...")
+                await self._request_tracker.wait_for_new_requests()
+                logger.debug("Got new requests!")
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                has_requests_in_progress = await asyncio.wait_for(
+                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+            except asyncio.TimeoutError as exc:
+                logger.error(
+                    "Engine iteration timed out. This should never happen!")
+                self.set_errored(exc)
+                raise
+            await asyncio.sleep(0)
 
     async def add_request(
         self,
@@ -505,6 +510,7 @@ class AsyncAphrodite:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
@@ -518,6 +524,7 @@ class AsyncAphrodite:
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
                         f"sampling_params: {sampling_params}, "
+                        f"prompt_token_ids: {shortened_token_ids}, "
                         f"lora_request: {lora_request}.")
 
         if not self.is_running:
@@ -534,11 +541,12 @@ class AsyncAphrodite:
             arrival_time = time.time()
 
         if self.engine_use_ray:
-            prompt_token_ids = await self.engine.encode_request_async.remote(
-                request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
-                lora_request=lora_request)
+            prompt_token_ids = await (
+                self.engine.encode_request_async.remote(  # type: ignore
+                    request_id=request_id,
+                    prompt=prompt,
+                    prompt_token_ids=prompt_token_ids,
+                    lora_request=lora_request))
         else:
             prompt_token_ids = await self.engine.encode_request_async(
                 request_id=request_id,
@@ -552,7 +560,9 @@ class AsyncAphrodite:
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
-            lora_request=lora_request)
+            lora_request=lora_request,
+            multi_modal_data=multi_modal_data,
+        )
 
         return stream
 
@@ -563,6 +573,7 @@ class AsyncAphrodite:
         request_id: str,
         prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
@@ -578,6 +589,7 @@ class AsyncAphrodite:
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi-modal data per request.
 
         Yields:
             The output `RequestOutput` objects from the AphroditeEngine for the
@@ -628,8 +640,7 @@ class AsyncAphrodite:
             >>> ...
         """
         # Preprocess the request.
-        # This should not be used for logging, as it is monotonic time.
-        arrival_time = time.monotonic()
+        arrival_time = time.time()
 
         try:
             stream = await self.add_request(
@@ -639,14 +650,11 @@ class AsyncAphrodite:
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
+                multi_modal_data=multi_modal_data,
             )
 
             async for request_output in stream:
                 yield request_output
-        except asyncio.exceptions.CancelledError:
-            logger.info(f"Request {request_id} cancelled.")
-            self._abort(request_id)
-            raise
         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the
             # request.
@@ -686,13 +694,13 @@ class AsyncAphrodite:
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the Aphrodite engine."""
         if self.engine_use_ray:
-            return await self.engine.get_model_config.remote()
+            return await self.engine.get_model_config.remote()  # type: ignore
         else:
             return self.engine.get_model_config()
 
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
-            await self.engine.do_log_stats.remote()
+            await self.engine.do_log_stats.remote()  # type: ignore
         else:
             self.engine.do_log_stats()
 
@@ -705,7 +713,7 @@ class AsyncAphrodite:
 
         if self.engine_use_ray:
             try:
-                await self.engine.check_health.remote()
+                await self.engine.check_health.remote()  # type: ignore
             except ray.exceptions.RayActorError as e:
                 raise RuntimeError("Engine is dead.") from e
         else:
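
Tying the signature changes together, a minimal usage sketch of the async engine as modified above. The model name and prompt are placeholders, and the environment override simply restores the previous 120-second iteration timeout (the default is now 60):

    import asyncio
    import os

    # Must be set before the module is imported, since the timeout is read at
    # import time.
    os.environ["APHRODITE_ENGINE_ITERATION_TIMEOUT_S"] = "120"

    from aphrodite.common.sampling_params import SamplingParams
    from aphrodite.engine.args_tools import AsyncEngineArgs
    from aphrodite.engine.async_aphrodite import AsyncAphrodite

    async def main():
        engine = AsyncAphrodite.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model
        params = SamplingParams(max_tokens=16)
        # multi_modal_data is a new optional per-request argument; it defaults
        # to None for text-only requests.
        async for output in engine.generate("Hello", params, request_id="req-0"):
            print(output)

    asyncio.run(main())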

+ 37 - 68
aphrodite/engine/metrics.py

@@ -1,15 +1,17 @@
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
 
 import numpy as np
-from loguru import logger
 from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
                                disable_created_metrics)
+from loguru import logger
+
 
 if TYPE_CHECKING:
     from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
+
 disable_created_metrics()
 
 # The begin-* and end* here are used by the documentation generator
@@ -27,97 +29,61 @@ class Metrics:
 
         # Config Information
         self.info_cache_config = Info(
-            name="aphrodite:cache_config",
-            documentation="information of cache_config",
-        )
+            name='aphrodite:cache_config',
+            documentation='information of cache_config')
 
         # System stats
         self.gauge_scheduler_running = Gauge(
             name="aphrodite:num_requests_running",
             documentation="Number of requests currently running on GPU.",
-            labelnames=labelnames,
-        )
+            labelnames=labelnames)
         self.gauge_scheduler_swapped = Gauge(
             name="aphrodite:num_requests_swapped",
             documentation="Number of requests swapped to CPU.",
-            labelnames=labelnames,
-        )
+            labelnames=labelnames)
         self.gauge_scheduler_waiting = Gauge(
             name="aphrodite:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
-            labelnames=labelnames,
-        )
+            labelnames=labelnames)
         self.gauge_gpu_cache_usage = Gauge(
             name="aphrodite:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames,
-        )
+            labelnames=labelnames)
         self.gauge_cpu_cache_usage = Gauge(
             name="aphrodite:cpu_cache_usage_perc",
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames,
-        )
+            labelnames=labelnames)
 
         # Raw stats from last model iteration
         self.counter_prompt_tokens = Counter(
             name="aphrodite:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
-            labelnames=labelnames,
-        )
+            labelnames=labelnames)
         self.counter_generation_tokens = Counter(
             name="aphrodite:generation_tokens_total",
             documentation="Number of generation tokens processed.",
-            labelnames=labelnames,
-        )
+            labelnames=labelnames)
         self.histogram_time_to_first_token = Histogram(
             name="aphrodite:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
-                0.1,
-                0.25,
-                0.5,
-                0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-            ],
-        )
+                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
+                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
+            ])
         self.histogram_time_per_output_token = Histogram(
             name="aphrodite:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.01,
-                0.025,
-                0.05,
-                0.075,
-                0.1,
-                0.15,
-                0.2,
-                0.3,
-                0.4,
-                0.5,
-                0.75,
-                1.0,
-                2.5,
-            ],
-        )
+                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
+                1.0, 2.5
+            ])
         self.histogram_e2e_request_latency = Histogram(
             name="aphrodite:e2e_request_latency_seconds",
             documentation="Histogram of end to end request latency in seconds.",
             labelnames=labelnames,
-            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
-        )
+            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
 
         # Legacy metrics
         self.gauge_avg_prompt_throughput = Gauge(
@@ -138,7 +104,6 @@ class Metrics:
 @dataclass
 class Stats:
     """Created by AphroditeEngine for use by StatLogger."""
-
     now: float
 
     # System stats.
@@ -158,12 +123,18 @@ class Stats:
     spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
 
+class SupportsMetricsInfo(Protocol):
+
+    def metrics_info(self) -> Dict[str, str]:
+        ...
+
+
 class StatLogger:
     """StatLogger is used AphroditeEngine to log to Promethus and Stdout."""
 
     def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
         # Metadata for logging locally.
-        self.last_local_log = time.monotonic()
+        self.last_local_log = time.time()
         self.local_interval = local_interval
 
         # Tracked stats over current local logging interval.
@@ -174,7 +145,7 @@ class StatLogger:
         self.labels = labels
         self.metrics = Metrics(labelnames=list(labels.keys()))
 
-    def info(self, type: str, obj: object) -> None:
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         if type == "cache_config":
             self.metrics.info_cache_config.info(obj.metrics_info())
 
@@ -219,9 +190,8 @@ class StatLogger:
                                  generation_throughput: float) -> None:
         # Logs metrics to prometheus that are computed every logging_interval.
         # Support legacy gauge metrics that make throughput calculations on
-        # the Aphrodite side.
-        # Moving forward, we should use counters like counter_prompt_tokens,
-        # counter_generation_tokens
+        # the Aphrodite side. Moving forward, we should use counters like
+        # counter_prompt_tokens, counter_generation_tokens
         # Which log raw data and calculate summaries using rate() on the
         # grafana/prometheus side.
         self.metrics.gauge_avg_prompt_throughput.labels(
@@ -231,8 +201,8 @@ class StatLogger:
 
     def log(self, stats: Stats) -> None:
         """Called by AphroditeEngine.
-        Logs to prometheus and tracked stats every iteration.
-        Logs to Stdout every self.local_interval seconds."""
+           Logs to prometheus and tracked stats every iteration.
+           Logs to Stdout every self.local_interval seconds."""
 
         # Log to prometheus.
         self._log_prometheus(stats)
@@ -243,22 +213,21 @@ class StatLogger:
 
         # Log locally every local_interval seconds.
         if self._local_interval_elapsed(stats.now):
-            # Compute summary metrics for tracked stats (and log them to
-            # prometheus if applicable).
+            # Compute summary metrics for tracked stats (and log them
+            # to prometheus if applicable).
             prompt_throughput = self._get_throughput(self.num_prompt_tokens,
                                                      now=stats.now)
             generation_throughput = self._get_throughput(
                 self.num_generation_tokens, now=stats.now)
             self._log_prometheus_interval(
                 prompt_throughput=prompt_throughput,
-                generation_throughput=generation_throughput,
-            )
+                generation_throughput=generation_throughput)
 
             # Log to stdout.
             logger.info(
                 f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
-                f"Avg generation throughput: {generation_throughput:.1f} "
-                "tokens/s, "
+                f"Avg generation throughput: "
+                f"{generation_throughput:.1f} tokens/s, "
                 f"Running: {stats.num_running} reqs, "
                 f"Swapped: {stats.num_swapped} reqs, "
                 f"Pending: {stats.num_waiting} reqs, "

+ 5 - 2
aphrodite/engine/output_processor/interfaces.py

@@ -1,11 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Callable, List
 
 from transformers import PreTrainedTokenizer
 
 from aphrodite.common.config import SchedulerConfig
 from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                        SequenceGroupOutput)
+from aphrodite.common.utils import Counter
 from aphrodite.engine.output_processor.stop_checker import StopChecker
 from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
@@ -15,6 +16,7 @@ class SequenceGroupOutputProcessor(ABC):
     """Interface for logic that processes new token ids in sequence groups,
     managing detokenization, stop checking, and freeing/forking sequences with
     the scheduler.
+
     This is highly coupled with the LLMEngine and should be seen as an extension
     of it. The logic is separated to simplify the LLMEngine class and allow
     separate implementations for single-step decoding (which supports beam
@@ -27,11 +29,12 @@ class SequenceGroupOutputProcessor(ABC):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: "StopChecker",
     ):
         """Create an output processor.
+
         This returns a single-step output processor if num_lookahead_slots is
         zero, else returns a multi-step output processor.
         """

+ 12 - 8
aphrodite/engine/output_processor/multi_step.py

@@ -1,16 +1,17 @@
-from typing import Callable, Iterable, List
+from typing import Callable, List
 
 from transformers import PreTrainedTokenizer
 
-from aphrodite.processing.scheduler import Scheduler
-from aphrodite.engine.output_processor.interfaces import (
-    SequenceGroupOutputProcessor)
+from aphrodite.engine.output_processor.interfaces import \
+    SequenceGroupOutputProcessor
 from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.processing.scheduler import Scheduler
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (Logprob, Sequence, SequenceGroup,
-                                       SequenceGroupOutput, SequenceOutput,
-                                       SequenceStatus)
+                                SequenceGroupOutput, SequenceOutput,
+                                SequenceStatus)
 from aphrodite.transformers_utils.detokenizer import Detokenizer
+from aphrodite.common.utils import Counter
 
 
 class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@@ -21,6 +22,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
     This is currently mutually exclusive with advanced sampling techniques like
     beam search, which motivates the separation of this logic from the single
     step output processor.
+
     This class is responsible for things such as correctly appending all new
     token ids to their sequence, detokenizing new token ids, truncating new
     output tokens after an eos token, and correctly handling the case where the
@@ -31,7 +33,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         self,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: StopChecker,
     ):
@@ -44,8 +46,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput]) -> None:
         """Append new tokens in the outputs to sequences in the sequence group.
+
         This only supports sequence groups of size 1. It supports greater than
         one new token per sequence.
+
         This applies logic like stop condition checking and detokenization,
         including freeing finished sequences. It also handles cases where there
         are tokens emitted after the EOS token.
@@ -119,4 +123,4 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 break
 
         if seq.is_finished():
-            self.scheduler.free_seq(seq)
+            self.scheduler.free_seq(seq)

+ 10 - 6
aphrodite/engine/output_processor/single_step.py

@@ -1,10 +1,11 @@
-from typing import Iterable, List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 from aphrodite.common.config import SchedulerConfig
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                        SequenceGroupOutput, SequenceOutput,
                                        SequenceStatus)
+from aphrodite.common.utils import Counter
 from aphrodite.engine.output_processor.interfaces import \
     SequenceGroupOutputProcessor
 from aphrodite.engine.output_processor.stop_checker import StopChecker
@@ -18,6 +19,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
     scheduling of the next batch. Output processing logic includes
     detokenization, and determining if a sequence is finished (e.g. via max len
     or eos token).
+
     The SingleStepOutputProcessor is specialized to the case where the model
     emits at most a single token per invocation, which precludes configurations
     such as speculative decoding or multi-step decoding. This enables beam
@@ -30,7 +32,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         stop_checker: StopChecker,
     ):
         self.scheduler_config = scheduler_config
@@ -43,6 +45,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                         outputs: List[SequenceGroupOutput]) -> None:
         """Append all new tokens to sequences in the sequence group. Fork any
         surviving beam candidates; free any unsurviving ones.
+
         Invokes detokenizer to detokenize new tokens, and also marks sequences
         as finished if they meet stop conditions.
         """
@@ -55,7 +58,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
 
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
+        if prompt_logprobs is not None and \
+            seq_group.sampling_params.detokenize and self.detokenizer:
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
             seq_group.prompt_logprobs = prompt_logprobs
@@ -64,7 +68,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
-        parent_child_dict = {
+        parent_child_dict: Dict[int, List[SequenceOutput]] = {
             parent_seq.seq_id: []
             for parent_seq in parent_seqs
         }
@@ -87,7 +91,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 continue
             # Fork the parent sequence if there are multiple child samples.
             for child_sample in child_samples[:-1]:
-                new_child_seq_id = next(self.seq_counter)
+                new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
                                       child_sample.logprobs)
@@ -101,7 +105,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            if seq_group.sampling_params.detokenize:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
                 new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, seq_group.sampling_params)
             else:

+ 2 - 0
aphrodite/engine/output_processor/stop_checker.py

@@ -22,6 +22,7 @@ class StopChecker:
     def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
                             sampling_params: SamplingParams) -> None:
         """Stop the finished sequences.
+
        new_char_count is the number of chars added to the
            sequence's output text for the newly generated token
         """
@@ -72,6 +73,7 @@ class StopChecker:
                             sampling_params: SamplingParams) -> Optional[str]:
         """Check if any stop strings are matched and truncate sequence
         output text accordingly.
+
         Returns the stop string if matched or else None.
         """
         if not new_char_count:

+ 3 - 1
aphrodite/engine/output_processor/util.py

@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group = [[] for _ in range(num_seq_groups)]
+    output_by_sequence_group: List[List[SamplerOutput]] = [
+        [] for _ in range(num_seq_groups)
+    ]
     for step in sampler_outputs:
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)
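
As a concrete illustration of the [step][sequence group] -> [sequence group][step] transpose this helper performs (plain strings stand in for the SamplerOutput entries):

    # Two steps, three sequence groups, organized as [step][sequence group]:
    sampler_outputs = [["s0g0", "s0g1", "s0g2"],
                       ["s1g0", "s1g1", "s1g2"]]

    transposed = [[] for _ in range(3)]
    for step in sampler_outputs:
        for i, group_output in enumerate(step):
            transposed[i].append(group_output)

    # transposed is now organized as [sequence group][step]:
    # [["s0g0", "s1g0"], ["s0g1", "s1g1"], ["s0g2", "s1g2"]]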

+ 10 - 31
aphrodite/engine/ray_tools.py

@@ -1,47 +1,28 @@
 import pickle
+from typing import List, Optional, Tuple
 
-from typing import Optional, List, Tuple
 from loguru import logger
 
 from aphrodite.common.config import ParallelConfig
-from aphrodite.common.utils import is_hip, set_cuda_visible_devices, get_ip
+from aphrodite.common.utils import get_ip, is_hip
+from aphrodite.task_handler.worker_base import WorkerWrapperBase
+
 
 try:
     import ray
 
-    class RayWorkerAphrodite:
+    class RayWorkerWrapper(WorkerWrapperBase):
         """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to be
         lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
-        def __init__(self, init_cached_hf_modules=False) -> None:
-            if init_cached_hf_modules:
-                from transformers.dynamic_module_utils import init_hf_modules
-                init_hf_modules()
-            self.worker = None
+        def __init__(self, *args, **kwargs) -> None:
+            super().__init__(*args, **kwargs)
             # Since the compiled DAG runs its main execution
             # in a different thread that calls cuda.set_device,
             # this flag indicates whether set_device has been
             # called on that thread.
             self.compiled_dag_cuda_device_set = False
 
-        def init_worker(self, worker_init_fn):
-            self.worker = worker_init_fn()
-
-        def __getattr__(self, name):
-            return getattr(self.worker, name)
-
-        def execute_method(self, method, *args, **kwargs):
-            try:
-                executor = getattr(self, method)
-                return executor(*args, **kwargs)
-            except Exception as e:
-                # exceptions in ray worker may cause deadlock
-                # print the error and inform the user to solve the error
-                msg = (f"Error executing method {method}. "
-                       "This might cause deadlock in distributed execution.")
-                logger.exception(msg)
-                raise e
-
         def get_node_ip(self) -> str:
             return get_ip()
 
@@ -50,9 +31,6 @@ try:
             gpu_ids = ray.get_gpu_ids()
             return node_id, gpu_ids
 
-        def set_cuda_visible_devices(self, device_ids) -> None:
-            set_cuda_visible_devices(device_ids)
-
         def execute_model_compiled_dag_remote(self, ignored):
             """Used only when compiled DAG is enabled."""
             import torch
@@ -68,8 +46,8 @@ except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
                    "`pip install ray`.")
-    ray = None
-    RayWorkerAphrodite = None
+    ray = None  # type: ignore
+    RayWorkerWrapper = None  # type: ignore
 
 
 def initialize_ray_cluster(
@@ -77,6 +55,7 @@ def initialize_ray_cluster(
     ray_address: Optional[str] = None,
 ):
     """Initialize the distributed cluster with Ray.
+
     It will connect to the Ray cluster and create a placement group
     for the workers, which includes the specification of the resources
     for each distributed worker.
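
A minimal sketch of how the executor layer is expected to call this helper when Ray is in use; the model name is a placeholder and ray must be installed:

    from aphrodite.engine.args_tools import EngineArgs
    from aphrodite.engine.ray_tools import initialize_ray_cluster

    engine_config = EngineArgs(model="facebook/opt-125m",  # placeholder model
                               tensor_parallel_size=2,
                               worker_use_ray=True).create_engine_config()
    # Connects to (or starts) a Ray cluster and creates the placement group
    # used by the Ray workers; ray_address=None lets Ray pick the cluster.
    initialize_ray_cluster(engine_config.parallel_config, ray_address=None)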

+ 28 - 20
aphrodite/executor/cpu_executor.py

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Set
+from typing import Dict, List, Set, Tuple
 
 import torch
 from loguru import logger
@@ -8,7 +8,7 @@ from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
-from aphrodite.executor.executor_base import ExecutorBase
+from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
 
 
@@ -28,8 +28,8 @@ class CPUExecutor(ExecutorBase):
     def _init_worker(self):
         from aphrodite.task_handler.cpu_worker import CPUWorker
 
-        assert (self.parallel_config.world_size == 1
-                ), "CPUExecutor only supports single CPU socket currently."
+        assert self.parallel_config.world_size == 1, (
+            "CPUExecutor only supports single CPU socket currently.")
 
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
@@ -39,17 +39,19 @@ class CPUExecutor(ExecutorBase):
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
+            load_config=self.load_config,
             local_rank=0,
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
@@ -62,7 +64,6 @@ class CPUExecutor(ExecutorBase):
         # NOTE: We log here to avoid multiple logs when number of workers is
         # greater than one. We could log in the engine, but not all executors
         # have GPUs.
-
         # NOTE: `cpu block` for CPU backend is located on CPU memory but is
         # referred as `gpu block`. Because we want to reuse the existing block
         # management procedure.
@@ -84,6 +85,23 @@ class CPUExecutor(ExecutorBase):
         )
         return output
 
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.driver_worker.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.driver_worker.list_loras()
+
+    def check_health(self) -> None:
+        # CPUExecutor will always be healthy as long as
+        # it's running.
+        return
+
+
+class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
+
     async def execute_model_async(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -97,20 +115,10 @@ class CPUExecutor(ExecutorBase):
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots,
-        )
+            num_lookahead_slots=num_lookahead_slots)
         return output
 
-    def add_lora(self, lora_request: LoRARequest) -> bool:
-        return self.driver_worker.add_lora(lora_request)
-
-    def remove_lora(self, lora_id: int) -> bool:
-        return self.driver_worker.remove_lora(lora_id)
-
-    def list_loras(self) -> Set[int]:
-        return self.driver_worker.list_loras()
-
-    def check_health(self) -> None:
+    async def check_health_async(self) -> None:
         # CPUExecutor will always be healthy as long as
         # it's running.
         return
@@ -139,9 +147,9 @@ def _verify_and_get_scheduler_config(
 
 def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
     _GB = 1 << 30
-    if config.context_shift:
+    if config.enable_prefix_caching:
         logger.warning("Prefix caching is not supported on CPU, disable it.")
-        config.context_shift = False
+        config.enable_prefix_caching = False
 
     kv_cache_space_str = os.getenv("APHRODITE_CPU_KVCACHE_SPACE", "0")
     kv_cache_space = int(kv_cache_space_str)
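
The CPU backend sizes its KV cache from this environment variable; it is read as an integer and, given the _GB constant above, presumably interpreted in GiB. A minimal sketch of setting it before engine start (the value is illustrative):

    import os

    # Reserve roughly 8 GiB of CPU memory for the KV cache; this must be set
    # before the CPUExecutor builds its cache config.
    os.environ["APHRODITE_CPU_KVCACHE_SPACE"] = "8"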

+ 17 - 14
aphrodite/executor/executor_base.py

@@ -1,16 +1,17 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, Set, Tuple
 
-from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     LoRAConfig, VisionLanguageConfig,
-                                     SpeculativeConfig)
-from aphrodite.lora.request import LoRARequest
+from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
+                                     SchedulerConfig, SpeculativeConfig,
+                                     VisionLanguageConfig)
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.lora.request import LoRARequest
 
 
 class ExecutorBase(ABC):
     """Base class for all executors.
+
     An executor is responsible for executing the model on a specific device
     type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
     that can execute the model on multiple devices.
@@ -23,6 +24,7 @@ class ExecutorBase(ABC):
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
+        load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
@@ -30,6 +32,7 @@ class ExecutorBase(ABC):
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
+        self.load_config = load_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
@@ -46,9 +49,11 @@ class ExecutorBase(ABC):
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available blocks for the GPU KV cache and
         swappable CPU KV cache.
+
         Normally, this should simply delegate to the underlying Worker. Some
         ExecutorBase may require modification of the result, e.g. to ensure the
         selected cache sizes are compatible with all workers.
+
         Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
         are blocks that are "active" on the device and can be appended to.
         num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
@@ -64,15 +69,13 @@ class ExecutorBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
-        """Executes one model step on the given sequences."""
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]],
+                      num_lookahead_slots: int) -> List[SamplerOutput]:
+        """Executes at least one model step on the given sequences."""
         raise NotImplementedError
 
     @abstractmethod

+ 30 - 17
aphrodite/executor/gpu_executor.py

@@ -1,22 +1,19 @@
-from typing import Dict, List, Set
+from typing import Dict, List, Set, Tuple
 
 from loguru import logger
 
-from aphrodite.lora.request import LoRARequest
-from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (
-    get_ip,
-    get_open_port,
-    get_distributed_init_method,
-    make_async,
-)
+from aphrodite.common.utils import (get_distributed_init_method, get_ip,
+                                    get_open_port, make_async)
+from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from aphrodite.lora.request import LoRARequest
 
 
 class GPUExecutor(ExecutorBase):
 
     def _init_executor(self) -> None:
         """Initialize the worker and load the model.
+
         If speculative decoding is enabled, we instead create the speculative
         worker.
         """
@@ -30,8 +27,8 @@ class GPUExecutor(ExecutorBase):
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         from aphrodite.task_handler.worker import Worker
 
-        assert (self.parallel_config.world_size == 1
-                ), "GPUExecutor only supports single GPU."
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
 
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
@@ -41,6 +38,7 @@ class GPUExecutor(ExecutorBase):
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
+            load_config=self.load_config,
             local_rank=0,
             rank=0,
             distributed_init_method=distributed_init_method,
@@ -56,6 +54,7 @@ class GPUExecutor(ExecutorBase):
         """
         assert self.speculative_config is not None
 
+        from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
         from aphrodite.spec_decode.spec_decode_worker import SpecDecodeWorker
         from aphrodite.task_handler.worker import Worker
 
@@ -68,6 +67,7 @@ class GPUExecutor(ExecutorBase):
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
+            load_config=self.load_config,
             local_rank=0,
             rank=0,
             distributed_init_method=distributed_init_method,
@@ -76,11 +76,25 @@ class GPUExecutor(ExecutorBase):
             is_driver_worker=True,
         )
 
-        spec_decode_worker = SpecDecodeWorker.create_worker(
-            scorer_worker=target_worker,
-            speculative_config=self.speculative_config,
+        draft_worker = MultiStepWorker(
+            model_config=self.speculative_config.draft_model_config,
+            parallel_config=self.speculative_config.draft_parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            # TODO allow draft-model specific load config.
+            load_config=self.load_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
+            is_driver_worker=True,
         )
 
+        spec_decode_worker = SpecDecodeWorker.from_workers(
+            proposer_worker=draft_worker, scorer_worker=target_worker)
+
         assert self.parallel_config.world_size == 1, (
             "GPUExecutor only supports single GPU.")
 
@@ -89,7 +103,7 @@ class GPUExecutor(ExecutorBase):
         # Load model handled in spec decode worker.
         self.driver_worker.init_device()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
@@ -159,6 +173,5 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots,
-        )
+            num_lookahead_slots=num_lookahead_slots)
         return output

+ 24 - 3
aphrodite/executor/neuron_executor.py

@@ -1,8 +1,10 @@
-from typing import Dict, List, Set
+from typing import Dict, List, Set, Tuple
 
+
+from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
-from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import make_async
 
 
 class NeuronExecutor(ExecutorBase):
@@ -29,7 +31,7 @@ class NeuronExecutor(ExecutorBase):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
@@ -70,3 +72,22 @@ class NeuronExecutor(ExecutorBase):
         # NeuronExecutor will always be healthy as long as
         # it's running.
         return
+
+
+class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        output = await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list, )
+        return output
+
+    async def check_health_async(self) -> None:
+        # NeuronExecutor will always be healthy as long as
+        # it's running.
+        return
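
NeuronExecutorAsync leans on the make_async helper, which, as used here, wraps a blocking callable so that calling it returns an awaitable that runs off the event loop. A self-contained sketch of that pattern with a stand-in blocking function:

    import asyncio
    import time

    from aphrodite.common.utils import make_async

    def blocking_step(x: int) -> int:
        time.sleep(0.1)  # stand-in for a blocking model step
        return x * 2

    async def main():
        # The wrapped call runs in a background executor, so the event loop
        # stays responsive while it completes.
        result = await make_async(blocking_step)(21)
        print(result)  # 42

    asyncio.run(main())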

+ 107 - 87
aphrodite/executor/ray_gpu_executor.py

@@ -1,19 +1,19 @@
 import asyncio
-import copy
-from collections import defaultdict
 import os
 import pickle
-from typing import TYPE_CHECKING, Any, Dict, List, Set, Optional
+from collections import defaultdict
+from itertools import islice, repeat
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 from loguru import logger
 
-from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import (get_aphrodite_instance_id,
+                                    get_distributed_init_method, get_ip,
+                                    get_open_port, make_async)
+from aphrodite.engine.ray_tools import RayWorkerWrapper, ray
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
-from aphrodite.common.utils import (set_cuda_visible_devices, get_ip,
-                                    get_open_port, get_distributed_init_method,
-                                    make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -21,6 +21,7 @@ if ray is not None:
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -74,9 +75,9 @@ class RayGPUExecutor(ExecutorBase):
 
         # The driver dummy worker does not actually use any resources.
         # It holds the resource for the driver worker.
-        self.driver_dummy_worker: RayWorkerAphrodite = None
+        self.driver_dummy_worker: RayWorkerWrapper = None
         # The remaining workers are the actual ray actors.
-        self.workers: List[RayWorkerAphrodite] = []
+        self.workers: List[RayWorkerWrapper] = []
 
         if self.parallel_config.ray_workers_use_nsight:
             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
@@ -97,13 +98,22 @@ class RayGPUExecutor(ExecutorBase):
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
+            )(RayWorkerWrapper).remote(
+                worker_module_name="aphrodite.task_handler.worker",
+                worker_class_name="Worker",
+                trust_remote_code=self.model_config.trust_remote_code,
+            )
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
                 # If the worker is on the same node as the driver, we use it
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
+                self.driver_worker = RayWorkerWrapper(
+                    worker_module_name="aphrodite.task_handler.worker",
+                    worker_class_name="Worker",
+                    trust_remote_code=self.model_config.trust_remote_code,
+                )
             else:
                 # Else, added to the list of workers.
                 self.workers.append(worker)
@@ -115,79 +125,59 @@ class RayGPUExecutor(ExecutorBase):
                 "GPU node.")
 
         # Get the set of GPU IDs used on each node.
-        driver_node_id, driver_gpu_ids = ray.get(
-            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
-        worker_node_and_gpu_ids = ray.get(
-            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
+                                                    use_dummy_driver=True)
 
         node_workers = defaultdict(list)
         node_gpus = defaultdict(list)
 
-        node_workers[driver_node_id].append(0)
-        node_gpus[driver_node_id].extend(driver_gpu_ids)
-        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
-                                               start=1):
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
             node_workers[node_id].append(i)
             node_gpus[node_id].extend(gpu_ids)
         for node_id, gpu_ids in node_gpus.items():
             node_gpus[node_id] = sorted(gpu_ids)
 
-        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
-        set_cuda_visible_devices(node_gpus[driver_node_id])
-        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
-            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()
+
+        # Set environment variables for the driver and workers.
+        all_args_to_update_environment_variables = [({
+            "CUDA_VISIBLE_DEVICES":
+            ",".join(map(str, node_gpus[node_id])),
+            "APHRODITE_INSTANCE_ID":
+            APHRODITE_INSTANCE_ID,
+            "APHRODITE_TRACE_FUNCTION":
+            os.getenv("APHRODITE_TRACE_FUNCTION", "0"),
+        }, ) for (node_id, _) in worker_node_and_gpu_ids]
+        self._run_workers("update_environment_variables",
+                          all_args=all_args_to_update_environment_variables)
 
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker
-
-        model_config = copy.deepcopy(self.model_config)
-        parallel_config = copy.deepcopy(self.parallel_config)
-        scheduler_config = copy.deepcopy(self.scheduler_config)
-        device_config = copy.deepcopy(self.device_config)
-        lora_config = copy.deepcopy(self.lora_config)
-        cache_config = copy.deepcopy(self.cache_config)
-        vision_language_config = copy.deepcopy(self.vision_language_config)
-
-        # Initialize the actual workers with the Worker class.
-        for rank, (worker, (node_id, _)) in enumerate(
-                zip(self.workers, worker_node_and_gpu_ids),
-                start=1,
-        ):
+        def collect_arg_helper_func(**kwargs):
+            # avoid writing `{"name": value}` manually
+            return kwargs
+
+        # Initialize the actual workers inside worker wrapper.
+        init_worker_all_kwargs = []
+        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
             local_rank = node_workers[node_id].index(rank)
-            worker.init_worker.remote(
-                lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config=model_config,
-                    parallel_config=parallel_config,
-                    scheduler_config=scheduler_config,
-                    device_config=device_config,
-                    cache_config=cache_config,
+            init_worker_all_kwargs.append(
+                collect_arg_helper_func(
+                    model_config=self.model_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config,
+                    device_config=self.device_config,
+                    cache_config=self.cache_config,
+                    load_config=self.load_config,
                     local_rank=local_rank,
                     rank=rank,
                     distributed_init_method=distributed_init_method,
-                    lora_config=lora_config,
-                    vision_language_config=vision_language_config,
+                    lora_config=self.lora_config,
+                    vision_language_config=self.vision_language_config,
+                    is_driver_worker=rank == 0,
                 ))
-
-        # Initialize the driver worker with the Worker class.
-        driver_rank = 0
-        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
-        self.driver_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            local_rank=driver_local_rank,
-            rank=driver_rank,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
-        )
+        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
         self._run_workers(
@@ -196,13 +186,15 @@ class RayGPUExecutor(ExecutorBase):
             max_parallel_loading_workers,
         )
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
+
         This invokes `determine_num_available_blocks` on each worker and takes
         the min of the results, guaranteeing that the selected cache sizes are
         compatible with all workers.
+
         Returns:
-            - tuple[num_gpu_blocks, num_cpu_blocks]
+            - Tuple[num_gpu_blocks, num_cpu_blocks]
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers("determine_num_available_blocks", )
@@ -212,6 +204,7 @@ class RayGPUExecutor(ExecutorBase):
         # operators can be applied to all workers.
         num_gpu_blocks = min(b[0] for b in num_blocks)
         num_cpu_blocks = min(b[1] for b in num_blocks)
+
         return num_gpu_blocks, num_cpu_blocks
 
     def initialize_cache(self, num_gpu_blocks: int,
@@ -241,7 +234,7 @@ class RayGPUExecutor(ExecutorBase):
                       blocks_to_swap_in: Dict[int, int],
                       blocks_to_swap_out: Dict[int, int],
                       blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int = 0) -> List[SamplerOutput]:
+                      num_lookahead_slots: int = 0) -> SamplerOutput:
         all_outputs = self._run_workers(
             "execute_model",
             driver_kwargs={
@@ -277,38 +270,62 @@ class RayGPUExecutor(ExecutorBase):
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
+        use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers. Can be used in the following
+        ways:
+
+        - args/kwargs: All workers share the same args/kwargs
+        - args/kwargs and driver_args/driver_kwargs: Driver worker has
+          different args
+        - all_args/all_kwargs: args/kwargs for each worker are specified
+          individually
+        """
 
         if max_concurrent_workers:
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
+        if driver_args is None:
+            driver_args = args if all_args is None else all_args[0]
+        if driver_kwargs is None:
+            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
+        count = len(self.workers)
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, 1, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, 1, None)
+
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
             # input. TODO: Fix it.
+            assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
-                worker.execute_method.remote(method, *args, **kwargs)
-                for worker in self.workers
+                worker.execute_method.remote(method, *worker_args,
+                                             **worker_kwargs)
+                for (worker, worker_args, worker_kwargs
+                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
             ]
 
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
         # Start the driver worker after all the ray workers.
-        driver_worker_output = getattr(self.driver_worker,
-                                       method)(*driver_args, **driver_kwargs)
-
+        if not use_dummy_driver:
+            driver_worker_output = self.driver_worker.execute_method(
+                method, *driver_args, **driver_kwargs)
+        else:
+            driver_worker_output = ray.get(
+                self.driver_dummy_worker.execute_method.remote(
+                    method, *driver_args, **driver_kwargs))
         # Get the results of the ray workers.
         if self.workers:
             if use_ray_compiled_dag:
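
The repeat()/islice() logic in this hunk decides which arguments reach the driver worker and which reach the Ray workers. The standalone sketch below reproduces only that fan-out so the three calling conventions from the docstring can be checked in isolation; the helper name is made up for illustration.

    from itertools import islice, repeat
    from typing import Any, Dict, List, Optional, Tuple

    def resolve_worker_args(
        num_workers: int,
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[str, Any]] = None,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
    ):
        """Return (driver_args, driver_kwargs, worker_args, worker_kwargs)."""
        kwargs = kwargs or {}
        # Entry 0 of all_args/all_kwargs is reserved for the driver worker.
        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
        # The remaining entries (or the shared args/kwargs) go to the Ray workers.
        worker_args = repeat(args, num_workers) if all_args is None \
            else islice(all_args, 1, None)
        worker_kwargs = repeat(kwargs, num_workers) if all_kwargs is None \
            else islice(all_kwargs, 1, None)
        return driver_args, driver_kwargs, list(worker_args), list(worker_kwargs)

    # Shared args: the driver and both workers see the same tuple.
    print(resolve_worker_args(2, args=(8, 16)))
    # Per-worker kwargs: the first dict is the driver's, the rest are the workers'.
    print(resolve_worker_args(2, all_kwargs=[{"rank": 0}, {"rank": 1}, {"rank": 2}]))
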
@@ -334,7 +351,7 @@ class RayGPUExecutor(ExecutorBase):
             raise ValueError(f"Ray version {required_version} or greater is "
                              f"required, but found {current_version}")
 
-        from ray.dag import MultiOutputNode, InputNode
+        from ray.dag import InputNode, MultiOutputNode
         assert self.parallel_config.worker_use_ray
 
         # Right now, compiled DAG requires at least 1 arg. We send
@@ -366,11 +383,15 @@ class RayGPUExecutor(ExecutorBase):
 
 class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_executor = make_async(self.driver_worker.execute_method)
+
     async def _run_workers_async(
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:
@@ -382,9 +403,8 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
         if driver_kwargs is None:
             driver_kwargs = kwargs
 
-        # Run the driver worker asynchronously.
-        driver_executor = make_async(getattr(self.driver_worker, method))
-        coros.append(driver_executor(*driver_args, **driver_kwargs))
+        coros.append(
+            self.driver_executor(method, *driver_args, **driver_kwargs))
 
         # Run the ray workers asynchronously.
         for worker in self.workers:
@@ -399,7 +419,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int = 0,
+        num_lookahead_slots: int,
     ) -> SamplerOutput:
         all_outputs = await self._run_workers_async(
             "execute_model",

+ 1 - 1
aphrodite/modeling/guided_decoding/outlines_logits_processors.py

@@ -1,6 +1,6 @@
 # Copyright 2024- the Outlines developers
 # This file is adapted from
-# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
+# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/aphrodite.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 0 - 456
aphrodite/modeling/hf_downloader.py

@@ -1,456 +0,0 @@
-"""Utilities for downloading and initializing model weights."""
-import fnmatch
-import glob
-import json
-import os
-from collections import defaultdict
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
-
-import filelock
-import huggingface_hub.constants
-import numpy as np
-import torch
-from huggingface_hub import HfFileSystem, snapshot_download
-from loguru import logger
-from safetensors.torch import load_file, safe_open, save_file
-from tqdm.auto import tqdm
-from transformers import PretrainedConfig, AutoModelForCausalLM
-
-from aphrodite.common.config import ModelConfig
-from aphrodite.quantization.gguf_utils import (GGUFReader, get_tensor_name_map,
-                                               MODEL_ARCH_NAMES)
-from aphrodite.common.logger import get_loading_progress_bar
-from aphrodite.quantization import (QuantizationConfig,
-                                    get_quantization_config)
-from aphrodite.quantization.schema import QuantParamSchema
-
-_xdg_cache_home = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
-_aphrodite_filelocks_path = os.path.join(_xdg_cache_home, "aphrodite/locks/")
-
-
-def enable_hf_transfer():
-    """automatically activates hf_transfer
-    """
-    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
-        try:
-            # enable hf hub transfer if available
-            import hf_transfer  # type: ignore # noqa
-            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
-        except ImportError:
-            pass
-
-
-enable_hf_transfer()
-
-
-class Disabledtqdm(tqdm):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs, disable=True)
-
-
-def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
-    lock_dir = cache_dir if cache_dir is not None else _aphrodite_filelocks_path
-    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
-    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
-    lock = filelock.SoftFileLock(os.path.join(lock_dir, lock_file_name))
-    return lock
-
-
-def _shared_pointers(tensors):
-    ptrs = defaultdict(list)
-    for k, v in tensors.items():
-        ptrs[v.data_ptr()].append(k)
-    failing = []
-    for _, names in ptrs.items():
-        if len(names) > 1:
-            failing.append(names)
-    return failing
-
-
-def convert_bin_to_safetensor_file(
-    pt_filename: str,
-    sf_filename: str,
-) -> None:
-    loaded = torch.load(pt_filename, map_location="cpu")
-    if "state_dict" in loaded:
-        loaded = loaded["state_dict"]
-    shared = _shared_pointers(loaded)
-    for shared_weights in shared:
-        for name in shared_weights[1:]:
-            loaded.pop(name)
-
-    # For tensors to be contiguous
-    loaded = {k: v.contiguous() for k, v in loaded.items()}
-
-    dirname = os.path.dirname(sf_filename)
-    os.makedirs(dirname, exist_ok=True)
-    save_file(loaded, sf_filename, metadata={"format": "pt"})
-
-    # check file size
-    sf_size = os.stat(sf_filename).st_size
-    pt_size = os.stat(pt_filename).st_size
-    if (sf_size - pt_size) / pt_size > 0.01:
-        raise RuntimeError(f"""The file size different is more than 1%:
-         - {sf_filename}: {sf_size}
-         - {pt_filename}: {pt_size}
-         """)
-
-    # check if the tensors are the same
-    reloaded = load_file(sf_filename)
-    for k in loaded:
-        pt_tensor = loaded[k]
-        sf_tensor = reloaded[k]
-        if not torch.equal(pt_tensor, sf_tensor):
-            raise RuntimeError(f"The output tensors do not match for key {k}")
-
-
-# TODO: Move this to another place.
-def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
-    quant_cls = get_quantization_config(model_config.quantization)
-    # Read the quantization config from the HF model config, if available.
-    # if the quantization if "gguf", we skip and return quant_cls()
-    if model_config.quantization in ["exl2", "gguf"]:
-        return quant_cls()
-    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
-                              None)
-    if hf_quant_config is not None:
-        return quant_cls.from_config(hf_quant_config)
-    model_name_or_path = model_config.model
-    is_local = os.path.isdir(model_name_or_path)
-    if not is_local:
-        # Download the config files.
-        with get_lock(model_name_or_path, model_config.download_dir):
-            hf_folder = snapshot_download(
-                model_name_or_path,
-                revision=model_config.revision,
-                allow_patterns="*.json",
-                cache_dir=model_config.download_dir,
-                tqdm_class=Disabledtqdm,
-            )
-    else:
-        hf_folder = model_name_or_path
-    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
-
-    quant_config_files = [
-        f for f in config_files if any(
-            f.endswith(x) for x in quant_cls.get_config_filenames())
-    ]
-    if len(quant_config_files) == 0:
-        raise ValueError(
-            f"Cannot find the config file for {model_config.quantization}")
-    if len(quant_config_files) > 1:
-        raise ValueError(
-            f"Found multiple config files for {model_config.quantization}: "
-            f"{quant_config_files}")
-
-    quant_config_file = quant_config_files[0]
-    with open(quant_config_file, "r") as f:
-        config = json.load(f)
-    return quant_cls.from_config(config)
-
-
-def prepare_hf_model_weights(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    load_format: str = "auto",
-    fall_back_to_pt: bool = True,
-    revision: Optional[str] = None,
-) -> Tuple[str, List[str], bool]:
-    # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
-    use_safetensors = False
-    # Some quantized models use .pt files for storing the weights.
-    if load_format == "auto":
-        allow_patterns = ["*.safetensors", "*.bin"]
-    elif load_format == "safetensors":
-        use_safetensors = True
-        allow_patterns = ["*.safetensors"]
-    elif load_format == "pt":
-        allow_patterns = ["*.pt"]
-    elif load_format == "npcache":
-        allow_patterns = ["*.bin"]
-    else:
-        raise ValueError(f"Unknown load_format: {load_format}")
-
-    if fall_back_to_pt:
-        allow_patterns += ["*.pt"]
-
-    if not is_local:
-        # Before we download we look at that is available:
-        fs = HfFileSystem()
-        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
-
-        # depending on what is available we download different things
-        for pattern in allow_patterns:
-            matching = fnmatch.filter(file_list, pattern)
-            if len(matching) > 0:
-                allow_patterns = [pattern]
-                break
-
-        logger.info(f"Using model weights format {allow_patterns}")
-        # Use file lock to prevent multiple processes from
-        # downloading the same model weights at the same time.
-        with get_lock(model_name_or_path, cache_dir):
-            hf_folder = snapshot_download(
-                model_name_or_path,
-                allow_patterns=allow_patterns,
-                cache_dir=cache_dir,
-                tqdm_class=Disabledtqdm,
-                revision=revision,
-            )
-    else:
-        hf_folder = model_name_or_path
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-        if len(hf_weights_files) > 0:
-            if pattern == "*.safetensors":
-                use_safetensors = True
-            break
-    if not use_safetensors:
-        # Exclude files that are not needed for inference.
-        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
-        blacklist = [
-            "training_args.bin",
-            "optimizer.bin",
-            "optimizer.pt",
-            "scheduler.pt",
-            "scaler.pt",
-            "trainer_state.json",
-            "hidden_states.safetensors",  # exllamav2
-        ]
-        hf_weights_files = [
-            f for f in hf_weights_files
-            if not any(f.endswith(x) for x in blacklist)
-        ]
-
-    if len(hf_weights_files) == 0:
-        raise RuntimeError(
-            f"Cannot find any model weights with `{model_name_or_path}`")
-
-    return hf_folder, hf_weights_files, use_safetensors
-
-
-def convert_gguf_to_state_dict(checkpoint, config):
-    model_type = config.model_type
-    # hack: ggufs have a different name than transformers
-    if model_type == "cohere":
-        model_type = "command-r"
-    arch = None
-    for key, value in MODEL_ARCH_NAMES.items():
-        if value == model_type:
-            arch = key
-            break
-    if arch is None:
-        raise RuntimeError(f"Unknown model_type: {model_type}")
-    num_layers = config.num_hidden_layers
-    name_map = get_tensor_name_map(arch, num_layers)
-    with torch.device("meta"):
-        dummy_model = AutoModelForCausalLM.from_config(config)
-    state_dict = dummy_model.state_dict()
-
-    gguf_to_hf_name_map = {}
-    keys_to_remove = []
-    for hf_name in state_dict:
-        name, suffix = hf_name.rsplit(".", 1)
-        gguf_name = name_map.get_name(name)
-        if gguf_name:
-            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
-        elif name == "lm_head":
-            keys_to_remove.append(hf_name)
-            logger.warning(
-                f"GGUF tensor name for {hf_name} not found, "
-                "this is normal if the model uses tie word embeddings.")
-        else:
-            logger.warning(
-                f"GGUF tensor name for {hf_name} in hf state_dict not found.")
-    for key in keys_to_remove:
-        state_dict.pop(key)
-
-    if os.path.isfile(checkpoint):
-        results = [GGUFReader(checkpoint)]
-    elif os.path.isdir(checkpoint):
-        results = [
-            GGUFReader(os.path.join(checkpoint, file))
-            for file in os.listdir(checkpoint)
-            if os.path.splitext(file)[-1].lower() == ".gguf"
-        ]
-    else:
-        raise RuntimeError(
-            f"Cannot find any model weights with `{checkpoint}`")
-
-    with get_loading_progress_bar() as progress:
-        task = progress.add_task(
-            "[cyan]Converting GGUF tensors to PyTorch...",
-            total=sum([len(result.tensors) for result in results]),
-        )
-        for result in results:
-            for ts in result.tensors:
-                try:
-                    hf_name = gguf_to_hf_name_map[ts.name]
-                except KeyError:
-                    logger.warning(
-                        f"hf tensor name for {ts.name} in GGUF not found.")
-                    continue
-                data = torch.tensor(ts.data)
-                if state_dict[hf_name].dim() == 2:
-                    data = data.view(state_dict[hf_name].shape[0], -1)
-                state_dict[hf_name] = data
-                weight_type = torch.tensor(int(ts.tensor_type),
-                                           dtype=torch.int)
-                if weight_type > 1:
-                    state_dict[hf_name.replace("weight",
-                                               "weight_type")] = weight_type
-                progress.update(task, advance=1)
-    return state_dict
-
-
-def hf_model_weights_iterator(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    load_format: str = "auto",
-    revision: Optional[str] = None,
-    config: Optional[PretrainedConfig] = None,
-    fall_back_to_pt: Optional[bool] = True,
-) -> Iterator[Tuple[str, torch.Tensor]]:
-    if model_name_or_path.endswith("gguf"):
-        for name, param in convert_gguf_to_state_dict(model_name_or_path,
-                                                      config).items():
-            yield name, param
-        return
-
-    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
-        model_name_or_path,
-        cache_dir=cache_dir,
-        load_format=load_format,
-        fall_back_to_pt=fall_back_to_pt,
-        revision=revision,
-    )
-
-    if load_format == "npcache":
-        # Currently np_cache only support *.bin checkpoints
-        assert use_safetensors is False
-
-        # Convert the model weights from torch tensors to numpy arrays for
-        # faster loading.
-        np_folder = os.path.join(hf_folder, "np")
-        os.makedirs(np_folder, exist_ok=True)
-        weight_names_file = os.path.join(np_folder, "weight_names.json")
-        # Use file lock to prevent multiple processes from
-        # dumping the same model weights to numpy at the same time.
-        with get_lock(model_name_or_path, cache_dir):
-            if not os.path.exists(weight_names_file):
-                weight_names = []
-                for bin_file in hf_weights_files:
-                    state = torch.load(bin_file, map_location="cpu")
-                    for name, param in state.items():
-                        param_path = os.path.join(np_folder, name)
-                        with open(param_path, "wb") as f:
-                            np.save(f, param.cpu().detach().numpy())
-                        weight_names.append(name)
-                with open(weight_names_file, "w") as f:
-                    json.dump(weight_names, f)
-
-        with open(weight_names_file, "r") as f:
-            weight_names = json.load(f)
-
-        for name in weight_names:
-            param_path = os.path.join(np_folder, name)
-            with open(param_path, "rb") as f:
-                param = np.load(f)
-            yield name, torch.from_numpy(param)
-    elif use_safetensors:
-        for st_file in hf_weights_files:
-            with safe_open(st_file, framework="pt") as f:
-                for name in f.keys():  # noqa: SIM118
-                    param = f.get_tensor(name)
-                    yield name, param
-    else:
-        for bin_file in hf_weights_files:
-            state = torch.load(bin_file, map_location="cpu")
-            for name, param in state.items():
-                yield name, param
-            del state
-            torch.cuda.empty_cache()
-
-
-def kv_cache_scales_loader(
-        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
-        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
-    """
-    A simple utility to read in KV cache scaling factors that have been
-    previously serialized to disk. Used by the model to populate the appropriate
-    KV cache scaling factors. The serialization should represent a dictionary
-    whose keys are the TP ranks and values are another dictionary mapping layers
-    to their KV cache scaling factors.
-    Keep this function in sync with the output of examples/fp8/extract_scales.py
-    """
-    try:
-        with open(filename) as f:
-            context = {
-                "model_type": model_type,
-                "num_hidden_layers": num_hidden_layers,
-                "tp_rank": tp_rank,
-                "tp_size": tp_size,
-            }
-            schema_dct = json.load(f)
-            schema = QuantParamSchema.model_validate(schema_dct,
-                                                     context=context)
-            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
-            return layer_scales_map.items()
-
-    except FileNotFoundError:
-        logger.error(f"File or directory '{filename}' not found.")
-    except json.JSONDecodeError:
-        logger.error(f"Error decoding JSON in file '{filename}'.")
-    except Exception as e:
-        logger.error(f"An error occurred while reading '{filename}': {e}")
-    # This section is reached if and only if any of the excepts are hit
-    # Return an empty iterable (list) => no KV cache scales are loaded
-    # which ultimately defaults to 1.0 scales
-    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
-                   f"for all layers in TP rank {tp_rank} "
-                   "as an error occurred during loading.")
-    return []
-
-
-def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
-    """convert PySafeSlice object from safetensors to torch.Tensor
-
-    PySafeSlice object supports indexing, which is done before loading the
-    actual tensor and can reduce the amount of memory being read into the
-    memory. However, it does not support more advanced functionalities
-    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
-    tensor with these more complicated operators, we need to convert to
-    tensor first.
-    """
-    if not isinstance(x, torch.Tensor):
-        x = x[:]
-    return x
-
-
-def default_weight_loader(param: torch.Tensor,
-                          loaded_weight: torch.Tensor) -> None:
-    """Default weight loader."""
-    if isinstance(param, torch.nn.parameter.UninitializedParameter):
-        param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
-    assert param.size() == loaded_weight.size()
-    param.data.copy_(loaded_weight)
-
-
-def initialize_dummy_weights(
-    model: torch.nn.Module,
-    low: float = -1e-3,
-    high: float = 1e-3,
-) -> None:
-    """Initialize model weights with random values.
-
-    The model weights must be randomly initialized for accurate performance
-    measurements. Additionally, the model weights should not cause NaNs in the
-    forward pass. We empirically found that initializing the weights with
-    values between -1e-3 and 1e-3 works well for most models.
-    """
-    for param in model.state_dict().values():
-        if torch.is_floating_point(param):
-            param.data.uniform_(low, high)

+ 8 - 9
aphrodite/modeling/layers/activation.py

@@ -7,13 +7,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from aphrodite._C import ops
-from aphrodite.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from aphrodite.quantization import QuantizationConfig
+from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization import QuantizationConfig
 
 
 class SiluAndMul(nn.Module):
@@ -22,8 +19,8 @@ class SiluAndMul(nn.Module):
     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
 
     Shapes:
-        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
-        return: (batch_size, seq_len, d) or (num_tokens, d)
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
     """
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -41,7 +38,9 @@ class SiluAndMul(nn.Module):
 
 class GeluAndMul(nn.Module):
     """An activation function for GeGLU.
+
     The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
     Shapes:
         x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
         return: (batch_size, seq_len, d) or (num_tokens, d)
@@ -51,7 +50,7 @@ class GeluAndMul(nn.Module):
         super().__init__()
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
-            raise ValueError(f"Invalid approximate value: {approximate}")
+            raise ValueError(f"Unknown approximate mode: {approximate}")
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""

+ 3 - 2
aphrodite/modeling/layers/fused_moe/__init__.py

@@ -1,6 +1,7 @@
 from aphrodite.modeling.layers.fused_moe.fused_moe import (
-    fused_moe, moe_align_block_size, fused_topk, get_config_file_name)
+    fused_moe, get_config_file_name)
 
 __all__ = [
-    "fused_moe", "moe_align_block_size", "fused_topk", "get_config_file_name"
+    "fused_moe",
+    "get_config_file_name",
 ]

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
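
Each tuned fused-MoE JSON file added here maps a token count M (the string keys) to Triton tile parameters for a given expert count E, shard size N, and device named in the file name. A rough sketch of how such a file could be loaded and the nearest-M entry picked is shown below; the file name and the nearest-key heuristic are assumptions, and the real lookup lives in aphrodite/modeling/layers/fused_moe/fused_moe.py.

    import json
    from typing import Any, Dict

    def load_moe_config(path: str) -> Dict[int, Dict[str, Any]]:
        """Keys are token counts (M); values are Triton tuning parameters."""
        with open(path) as f:
            return {int(m): cfg for m, cfg in json.load(f).items()}

    def pick_config(configs: Dict[int, Dict[str, Any]],
                    num_tokens: int) -> Dict[str, Any]:
        """Fall back to the entry whose benchmarked M is closest to num_tokens."""
        return configs[min(configs, key=lambda m: abs(m - num_tokens))]

    configs = load_moe_config(
        "E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json")
    print(pick_config(configs, num_tokens=100))  # resolves to the "96" entry
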

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 128 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json

@@ -0,0 +1,128 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_stages": 1
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_stages": 1
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_stages": 1
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_stages": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    }
+}

+ 138 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json

@@ -0,0 +1,138 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}

+ 110 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json

@@ -0,0 +1,110 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 32,
+        "GROUP_SIZE_M": 32
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8
+    },
+    "48": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

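The fp8 and non-fp8 H100 tables above differ only by the ",dtype=float8" suffix in the filename. The get_config_file_name change in the fused_moe.py diff further below derives these names from the expert count E, the intermediate size N, the CUDA device name, and an optional dtype selector. A self-contained sketch of that naming scheme, reproduced here for illustration (requires a CUDA device for torch.cuda.get_device_name):

```python
from typing import Optional

import torch


def config_file_name(E: int, N: int, dtype: Optional[str] = None) -> str:
    # Spaces in device names (e.g. "NVIDIA H100 80GB HBM3") become underscores.
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"


# On an H100 this yields the two filenames above:
#   config_file_name(8, 2048)           -> "E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json"
#   config_file_name(8, 2048, "float8") -> "E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json"
```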
+ 128 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json

@@ -0,0 +1,128 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 32,
+        "GROUP_SIZE_M": 32,
+        "num_stages": 1
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_stages": 1
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_stages": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_stages": 1
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_stages": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_stages": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 128 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json

@@ -0,0 +1,128 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 1
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_stages": 1
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_stages": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_stages": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 1
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 8,
+        "num_stages": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 32,
+        "GROUP_SIZE_M": 1,
+        "num_stages": 0
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 108 - 67
aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -19,6 +19,8 @@ def fused_moe_kernel(
     a_ptr,
     b_ptr,
     c_ptr,
+    a_scale_ptr,
+    b_scale_ptr,
     topk_weights_ptr,
     sorted_token_ids_ptr,
     expert_ids_ptr,
@@ -47,6 +49,7 @@ def fused_moe_kernel(
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
+    use_fp8: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -109,6 +112,10 @@ def fused_moe_kernel(
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
 
+    if use_fp8:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr + off_experts)
+
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@@ -127,7 +134,10 @@ def fused_moe_kernel(
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        if use_fp8:
+            accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -138,7 +148,10 @@ def fused_moe_kernel(
                              other=0)
         accumulator = accumulator * moe_weight[:, None]
 
-    accumulator = accumulator.to(compute_type)
+    if use_fp8:
+        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -205,15 +218,24 @@ def moe_align_block_size(
 
 
 def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            B_scale: torch.Tensor, topk_weights: torch.Tensor,
+                            topk_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any]) -> None:
+                            config: Dict[str, Any], compute_type: tl.dtype,
+                            use_fp8: bool) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
+    if not use_fp8:
+        A_scale = None
+        assert B_scale is None
+    else:
+        A, A_scale = ops.scaled_fp8_quant(A)
+        assert B_scale is not None
+
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
 
@@ -221,6 +243,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A,
         B,
         C,
+        A_scale,
+        B_scale,
         topk_weights,
         sorted_token_ids,
         expert_ids,
@@ -238,65 +262,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         C.stride(2),
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
-        compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
         **config,
     )
 
 
-def fused_topk(
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    """Compute top-k indice and weights from gating logits
-
-    Args:
-        gating_output (torch.Tensor): The output of the gating operation
-          (before softmax).
-        topk (int): The number of top-k experts to select.
-        renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    """
-    M = gating_output.shape[0]
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import aphrodite._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=gating_output.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=gating_output.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=gating_output.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids
-
-
-def get_config_file_name(E: int, N: int) -> str:
+def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
-    return f"E={E},N={N},device_name={device_name}.json"
+    dtype_selector = "" if not dtype else f",dtype={dtype}"
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int,
+                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -308,7 +288,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
 
     # First look up if an optimized configuration is available in the configs
     # directory
-    json_file_name = get_config_file_name(E, N)
+    json_file_name = get_config_file_name(E, N, dtype)
 
     config_file_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
@@ -331,8 +311,11 @@ def fused_moe(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    inplace: bool = True,
+    inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -350,6 +333,12 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -360,19 +349,51 @@ def fused_moe(
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
-    topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import aphrodite._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=hidden_states.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=hidden_states.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=hidden_states.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO(woosuk): Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
     if override_config:
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2])
+        configs = get_moe_configs(E, w2.shape[2],
+                                  "float8" if use_fp8 else None)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -408,17 +429,37 @@ def fused_moe(
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
 
-    invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, False,
-                            topk_ids.shape[1], config)
+    invoke_fused_moe_kernel(hidden_states,
+                            w1,
+                            intermediate_cache1,
+                            w1_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            False,
+                            topk_ids.shape[1],
+                            config,
+                            compute_type=tl.float16,
+                            use_fp8=use_fp8)
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
-                            topk_weights, topk_ids, sorted_token_ids,
-                            expert_ids, num_tokens_post_padded, True, 1,
-                            config)
+    invoke_fused_moe_kernel(intermediate_cache2,
+                            w2,
+                            intermediate_cache3,
+                            w2_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            True,
+                            1,
+                            config,
+                            compute_type=tl.float16,
+                            use_fp8=use_fp8)
 
     if inplace:
         return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),

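With the fused_moe.py changes above, the entry point now also accepts fp8 weights and per-expert scales. A hedged usage sketch: the sizes are arbitrary, the layout assumes a Mixtral-style w1 that stacks the gate and up projections (consistent with the shape asserts in the diff), and the fp8 tensor and scale names in the commented call are hypothetical.

```python
import torch

from aphrodite.modeling.layers.fused_moe import fused_moe

# Hypothetical sizes: 8 experts, hidden size 1024, intermediate size 4096,
# a 32-token batch, top-2 routing.
E, hidden, inter, tokens = 8, 1024, 4096, 32
x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.float16, device="cuda")  # gate + up proj
w2 = torch.randn(E, hidden, inter, dtype=torch.float16, device="cuda")      # down proj
router_logits = torch.randn(tokens, E, dtype=torch.float32, device="cuda")

# Default (fp16) path; note that `inplace` now defaults to False.
out = fused_moe(x, w1, w2, router_logits, topk=2, renormalize=True)

# fp8 path (sketch): w1/w2 would be pre-quantized to fp8 with per-expert
# scales, while the activation scale is computed on the fly via
# ops.scaled_fp8_quant inside invoke_fused_moe_kernel.
# out = fused_moe(x, w1_fp8, w2_fp8, router_logits, topk=2, renormalize=True,
#                 use_fp8=True, w1_scale=w1_scales, w2_scale=w2_scales)
```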
+ 79 - 242
aphrodite/modeling/layers/linear.py

@@ -1,20 +1,17 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
 from loguru import logger
+from torch import nn
 from torch.nn.parameter import Parameter
 
-from aphrodite.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
-    tensor_model_parallel_all_reduce,
-)
-from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   split_tensor_along_last_dim,
+                                   tensor_model_parallel_all_gather,
+                                   tensor_model_parallel_all_reduce)
 from aphrodite.modeling.utils import set_weight_attrs
 
 
@@ -26,118 +23,46 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-# Split the exl2 weight at group boundary
-def post_init_exl2(layer):
-    q_groups = layer.q_groups
-    num_groups = q_groups.shape[0] // 2
-    input_size = layer.q_invperm.shape[0]
-    tp_rank = get_tensor_model_parallel_rank()
-    tp_size = get_tensor_model_parallel_world_size()
-    rows = 0
-    # row_number, qrow_number, group_number
-    splits = [(0, 0, 0)]
-    index = 1
-    for i in range(num_groups - 1):
-        bits = q_groups[i * 2].item()
-        qrows = (q_groups[i * 2 + 3] - q_groups[i * 2 + 1]).item()
-        rows += qrows * 32 // bits
-
-        q_rows_end = q_groups[i * 2 + 3].item()
-        if q_rows_end >= layer.q_weight.shape[0] // tp_size * index:
-            splits.append((rows, q_rows_end, i + 1))
-            index += 1
-    splits.append((input_size, layer.q_weight.shape[0], num_groups))
-
-    shard_qweight = torch.nn.Parameter(
-        layer.q_weight[splits[tp_rank][1]:splits[tp_rank + 1][1]].clone(),
-        requires_grad=False,
-    )
-    # the outer model.named_parameters() during loading still keeps a reference
-    # so we have to force release memory
-    layer.q_weight.data = torch.empty(0, device="cpu")
-    del layer.q_weight
-    layer.linear_weights["q_weight"] = shard_qweight.to(layer.device)
-
-    shard_qgroups = torch.nn.Parameter(
-        layer.q_groups[splits[tp_rank][2] * 2:splits[tp_rank + 1][2] *
-                       2].clone(),
-        requires_grad=False,
-    )
-    shard_qgroups[1::2] -= int(shard_qgroups[1])
-    layer.q_groups.data = torch.empty(0, device="cpu")
-    del layer.q_groups
-    layer.linear_weights["q_groups"] = shard_qgroups.to(layer.device)
-
-    shard_qscale = torch.nn.Parameter(
-        layer.q_scale[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
-        requires_grad=False,
-    )
-    layer.q_scale.data = torch.empty(0, device="cpu")
-    del layer.q_scale
-    layer.linear_weights["q_scale"] = shard_qscale.to(layer.device)
-
-    shard_qscale_max = torch.nn.Parameter(
-        layer.q_scale_max[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
-        requires_grad=False,
-    )
-    layer.q_scale_max.data = torch.empty(0, device="cpu")
-    del layer.q_scale_max
-    layer.linear_weights["q_scale_max"] = shard_qscale_max.to(layer.device)
-
-    q_perm = torch.argsort(layer.q_invperm).to(torch.short)
-    # q_invperm is not used later, probably we should remove it too
-    layer.linear_weights["q_perm"] = q_perm[
-        splits[tp_rank][0]:splits[tp_rank + 1][0]].to(layer.device)
-
-
 class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_partition_sizes: List[int], input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Create weights for a linear layer."""
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer. 
+           The weights will be set as attributes of the layer.
+        
+        Args:
+            layer: The layer that is using the LinearMethodBase factory.
+            input_size_per_partition: Size of the weight input dim on rank X.
+            output_partition_sizes: Sizes of the output dim of each logical
+                weight on rank X. E.g., output_partition_sizes for QKVLinear
+                is a list containing the widths of Wq, Wk, and Wv on rank X.
+            input_size: Size of the input dim of the weight across all ranks.
+            output_size: Size of the output dim of the weight across all ranks.
+            params_dtype: Datatype of the parameters.
+        """
         raise NotImplementedError
 
     @abstractmethod
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
 
-    def create_moe_weights(self, num_experts: int,
-                           input_size_per_partition: int,
-                           output_partition_sizes: List[int], input_size: int,
-                           output_size: int,
-                           params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Creating moe weights"""
-        linear_weights = self.create_weights(input_size_per_partition,
-                                             output_partition_sizes,
-                                             input_size, output_size,
-                                             params_dtype)
-        if num_experts == 1:
-            return linear_weights
-        for name, param in tuple(linear_weights.items()):
-            if isinstance(param, Parameter):
-                repeat_size = (num_experts, ) + (1, ) * param.dim()
-                new_param = Parameter(param.unsqueeze(0).repeat(*repeat_size),
-                                      requires_grad=False)
-                set_weight_attrs(new_param, param.__dict__)
-                linear_weights[name] = new_param
-        return linear_weights
+    def process_weights_after_loading(self, layer: nn.Module) -> None:
+        """Process the weight after loading.
 
-    @abstractmethod
-    def apply_moe_weights(self, w1: Dict[str,
-                                         torch.Tensor], w2: Dict[str,
-                                                                 torch.Tensor],
-                          x: torch.Tensor, gating_output: torch.Tensor,
-                          topk: int, renormalize: bool) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
-        raise NotImplementedError
+        This can be used for example, to transpose weights for computation.
+        """
+        return
 
 
 class UnquantizedLinearMethod(LinearMethodBase):
@@ -151,42 +76,31 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_partition_sizes: List[int], input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         output_size_per_partition = sum(output_partition_sizes)
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        return {"weight": weight}
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = weights["weight"]
+        weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias
             return F.linear(x, weight)
         return F.linear(x, weight, bias)
 
-    def apply_embedding(self, weights: Dict[str, torch.Tensor],
-                        x: torch.Tensor) -> torch.Tensor:
-        weight = weights["weight"]
-        return F.embedding(x, weight)
-
-    def apply_moe_weights(self, w1: Dict[str,
-                                         torch.Tensor], w2: Dict[str,
-                                                                 torch.Tensor],
-                          x: torch.Tensor, gating_output: torch.Tensor,
-                          topk: int, renormalize: bool) -> torch.Tensor:
-        return fused_moe(x, w1["weight"], w2["weight"], gating_output, topk,
-                         renormalize)
-
 
 class ReplicatedLinear(torch.nn.Module):
     """Replicated linear layer.
@@ -221,12 +135,9 @@ class ReplicatedLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, [self.output_size], self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.nn.parameter.Parameter):
-                self.register_parameter(name, weight)
+        self.linear_method.create_weights(self, self.input_size,
+                                          [self.output_size], self.input_size,
+                                          self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -236,7 +147,7 @@ class ReplicatedLinear(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+        output = self.linear_method.apply_weights(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
@@ -259,9 +170,8 @@ class ColumnParallelLinear(torch.nn.Module):
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         linear_method: (Maybe quantized) linear method.
-        output_sizes: list of output sizes packed into one output, like for
-                      QKV the list would be size 3.
-        num_experts: number of experts for sparse moe models.
+        output_sizes: list of output sizes packed into one output, like for QKV
+                       the list would be size 3.
     """
 
     def __init__(
@@ -274,7 +184,6 @@ class ColumnParallelLinear(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         linear_method: Optional[LinearMethodBase] = None,
         output_sizes: Optional[List[int]] = None,
-        num_experts: int = 1,
     ):
         super().__init__()
 
@@ -294,14 +203,13 @@ class ColumnParallelLinear(torch.nn.Module):
         if output_sizes is None:
             output_sizes = [output_size]
         self.linear_method = linear_method
-        self.num_experts = num_experts
-        self.linear_weights = self.linear_method.create_moe_weights(
-            num_experts, self.input_size, [x // tp_size for x in output_sizes],
-            self.input_size, self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.nn.parameter.Parameter):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size,
+                                          [x // tp_size for x in output_sizes],
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -313,32 +221,15 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter("bias", None)
 
-    def weight_loader(self,
-                      param: Parameter,
-                      loaded_weight: torch.Tensor,
-                      expert_id: int = -1):
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
-        if self.num_experts > 1:
-            if expert_id >= 0:
-                param_data = param_data[expert_id]
-            # Loaded weight is packed at expert dim
-            else:
-                output_dim = output_dim + 1
-
         if output_dim is not None:
-            if loaded_weight.shape[output_dim] % tp_size != 0:
-                raise ValueError(
-                    "Size is not aligned with the quantized weight shape")
-            shard_size = loaded_weight.shape[output_dim] // tp_size
+            shard_size = param_data.shape[output_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
-        if isinstance(param, torch.nn.parameter.UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
-            param_data = param.data
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -346,8 +237,7 @@ class ColumnParallelLinear(torch.nn.Module):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_, bias)
+        output_parallel = self.linear_method.apply_weights(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -387,29 +277,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         linear_method: Optional[LinearMethodBase] = None,
-        num_experts: int = 1,
     ):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias, gather_output,
                          skip_bias_add, params_dtype, linear_method,
-                         self.output_sizes, num_experts)
+                         self.output_sizes)
 
     def weight_loader(self,
                       param: Parameter,
                       loaded_weight: torch.Tensor,
-                      loaded_shard_id: Optional[int] = None,
-                      expert_id: int = -1):
+                      loaded_shard_id: Optional[int] = None):
+
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         is_metadata = getattr(param, "is_metadata", False)
-        if self.num_experts > 1:
-            if expert_id >= 0:
-                param_data = param_data[expert_id]
-            # Loaded weight is packed at expert dim
-            elif output_dim is not None:
-                output_dim = output_dim + 1
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -532,12 +415,14 @@ class QKVParallelLinear(ColumnParallelLinear):
         input_size = self.hidden_size
         output_size = (self.num_heads +
                        2 * self.num_kv_heads) * tp_size * self.head_size
+        output_sizes = [
+            self.num_heads * tp_size * self.head_size,
+            self.num_kv_heads * tp_size * self.head_size,
+            self.num_kv_heads * tp_size * self.head_size
+        ]
+
         super().__init__(input_size, output_size, bias, False, skip_bias_add,
-                         params_dtype, linear_method, [
-                             self.num_heads * tp_size * self.head_size,
-                             self.num_kv_heads * tp_size * self.head_size,
-                             self.num_kv_heads * tp_size * self.head_size
-                         ])
+                         params_dtype, linear_method, output_sizes)
 
     def weight_loader(self,
                       param: Parameter,
@@ -666,7 +551,6 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = True,
         linear_method: Optional[LinearMethodBase] = None,
-        num_experts: int = 1,
     ):
         super().__init__()
         # Keep input parameters
@@ -685,14 +569,13 @@ class RowParallelLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.num_experts = num_experts
-        self.linear_weights = self.linear_method.create_moe_weights(
-            num_experts, self.input_size_per_partition, [self.output_size],
-            self.input_size, self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.nn.parameter.Parameter):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size_per_partition,
+                                          [self.output_size],
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -708,67 +591,21 @@ class RowParallelLinear(torch.nn.Module):
         else:
             self.register_parameter("bias", None)
 
-        if self.is_exl2():
-            self.uninitialized = [
-                "q_invperm", "q_weight", "q_scale", "q_scale_max", "q_groups"
-            ]
-
-    def is_exl2(self) -> bool:
-        return hasattr(
-            self.linear_method, "quant_config"
-        ) and self.linear_method.quant_config.get_name() == "exl2"
-
-    def weight_loader(self,
-                      param: Parameter,
-                      loaded_weight: torch.Tensor,
-                      expert_id: int = -1):
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
-        if self.num_experts > 1:
-            if expert_id >= 0:
-                param_data = param_data[expert_id]
-            # Loaded weight is packed at expert dim
-            elif input_dim is not None:
-                input_dim = input_dim + 1
         if input_dim is not None:
-            if loaded_weight.shape[input_dim] % tp_size != 0:
-                raise ValueError(
-                    "Size is not aligned with the quantized weight shape")
-
-            shard_size = loaded_weight.shape[input_dim] // tp_size
+            shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
-        if isinstance(param, torch.nn.parameter.UninitializedParameter):
-            if self.is_exl2():
-                self.device = param.device
-                param.materialize(loaded_weight.shape,
-                                  dtype=loaded_weight.dtype,
-                                  device="cpu")
-            else:
-                param.materialize(loaded_weight.shape,
-                                  dtype=loaded_weight.dtype)
-            param_data = param.data
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-        if self.is_exl2():
-            weights = self.linear_weights
-            self.uninitialized = [
-                key for key in self.uninitialized if isinstance(
-                    weights[key], torch.nn.parameter.UninitializedParameter)
-            ]
-            if not self.uninitialized:
-                # shard and release memory ASAP to reduce peak memory usage
-                post_init_exl2(self)
-
     def forward(self, input_):
-        is_exl2 = self.is_exl2()
-        if self.input_is_parallel and is_exl2:
-            input_parallel = tensor_model_parallel_all_gather(input_)
-        elif self.input_is_parallel or is_exl2:
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
             input_parallel = input_
         else:
             tp_rank = get_tensor_model_parallel_rank()
@@ -778,7 +615,7 @@ class RowParallelLinear(torch.nn.Module):
 
         # Matrix multiply.
         output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_parallel)
+            self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:

+ 8 - 9
aphrodite/modeling/layers/logits_processor.py

@@ -32,14 +32,11 @@ class LogitsProcessor(nn.Module):
         # Whether the input is logits (default is hidden states).
         self.logits_as_input = logits_as_input
         # original vocabulary size (without LoRA).
-        if org_vocab_size is not None:
-            self.org_vocab_size = min(org_vocab_size, vocab_size)
-        else:
-            self.org_vocab_size = vocab_size
+        self.org_vocab_size = org_vocab_size or vocab_size
 
     def forward(
         self,
-        lm_head: nn.Module,
+        embedding: torch.Tensor,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
@@ -49,8 +46,9 @@ class LogitsProcessor(nn.Module):
         else:
             hidden_states = _prune_hidden_states(hidden_states,
                                                  sampling_metadata)
+
             # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
         if logits is not None:
             logits *= self.scale
@@ -60,10 +58,10 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    def _get_logits(self, hidden_states: torch.Tensor, lm_head: nn.Module,
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                     embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
         # Get the logits for the next tokens.
-        logits = lm_head(hidden_states)
+        logits = torch.matmul(hidden_states, embedding.t())
         if embedding_bias is not None:
             logits += embedding_bias
         logits = tensor_model_parallel_gather(logits)
@@ -96,6 +94,7 @@ def _apply_logits_processors(
                 and sampling_params.prompt_logprobs is not None):
             assert len(seq_ids) == 1
             logits_row_idx += sampling_metadata.prompt_lens[i] - 1
+
         if logits_processors:
             found_logits_processors = True
             for seq_id in seq_ids:
@@ -108,6 +107,6 @@ def _apply_logits_processors(
         else:
             logits_row_idx += len(seq_ids)
     if found_logits_processors:
-        # Ensure that no rows in logits were unexpectedly skipped.
+        # verifies that no rows in logits were missed unexpectedly
         assert logits_row_idx == logits.shape[0]
     return logits
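
With the change above, `LogitsProcessor._get_logits` takes the (vocab-parallel) embedding matrix instead of an `lm_head` module and computes the logits with a plain matmul before the tensor-parallel gather. A small shape sketch with illustrative values:

```python
# Shape sketch for the embedding-based logits computation above.
import torch

hidden_states = torch.randn(4, 512)      # [num_tokens, hidden_size]
embedding = torch.randn(32000, 512)      # [vocab_size(_per_rank), hidden_size]

logits = torch.matmul(hidden_states, embedding.t())  # [num_tokens, vocab_size]
# Equivalent to F.linear(hidden_states, embedding); under tensor parallelism
# each rank holds a vocab shard and the shards are gathered afterwards.
```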

+ 8 - 2
aphrodite/modeling/layers/ops/sample.py

@@ -1,5 +1,5 @@
 import math
-from typing import Tuple, Optional
+from typing import Optional, Tuple
 
 import torch
 import triton
@@ -15,6 +15,7 @@ MAX_TRITON_N_COLS = 131072
 
 def get_num_triton_sampler_splits(n_cols: int) -> int:
     """Get the number of splits to use for Triton sampling.
+
     Triton has a limit on the number of columns it can handle, so we need to
     split the tensor and call the kernel multiple times if it's too large.
     """
@@ -28,8 +29,8 @@ def _multi_split_sample(
     sampled_tokens_size: Tuple[int, int],
     sampled_logprobs_size: Tuple[int, int],
     sample_indices: torch.Tensor,
+    logprobs: torch.Tensor,
     *,
-    logprobs: Optional[torch.Tensor] = None,
     modify_greedy_probs: bool = False,
     save_logprobs: bool = False,
 ):
@@ -118,7 +119,9 @@ def sample(
     _save_modified_probs: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
     """Sample tokens from probs. with per-sequence seeds.
+
     Can sample from a subset of sequences through sample_indices.
+
     Args:
         probs: Probabilities to sample from.
             shape = [batch_size, vocab_size]
@@ -143,6 +146,7 @@ def sample(
             (because we want to use the unmodified probs to pick the best
             split in case of multi-split sampling).
             This is exposed only for testing.
+
     Returns:
         sampled_tokens: shape = [n, max_best_of]
         sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
@@ -163,6 +167,7 @@ def sample(
         sampled_logprobs_size = (0, 0)
         logprobs = probs
 
+    assert logprobs is not None
     if _save_modified_probs:
         sampled_modified_probs_size = sampled_tokens_size
     else:
@@ -234,6 +239,7 @@ def _sample(probs: torch.Tensor,
             save_logprobs: bool = True,
             save_modified_probs: bool = False) -> torch.Tensor:
     """Sample tokens from probs.
+
     Args:
         probs [batch_size, vocab_size]: probs to sample from.
         logprobs [batch_size, vocab_size]: logprobs (used when
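
The body of `get_num_triton_sampler_splits` lies outside this hunk; given the `MAX_TRITON_N_COLS = 131072` limit quoted earlier in the file, the docstring implies a simple ceiling division over the vocab dimension. An assumed sketch:

```python
# Assumed sketch: split the vocab columns so each Triton call stays under
# the MAX_TRITON_N_COLS limit quoted earlier in this file.
import math

MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    return math.ceil(n_cols / MAX_TRITON_N_COLS)


# e.g. a 200_000-column probability tensor is sampled in 2 splits.
```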

+ 16 - 21
aphrodite/modeling/layers/rotary_embedding.py

@@ -2,6 +2,8 @@
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -32,14 +34,12 @@ from aphrodite._C import ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    """PyTorch-native implementation."""
     x1 = x[..., :x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
 def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    """PyTorch-native implementation."""
     x1 = x[..., ::2]
     x2 = x[..., 1::2]
     x = torch.stack((-x2, x1), dim=-1)
@@ -111,7 +111,8 @@ class RotaryEmbedding(nn.Module):
             query_pass = query[..., self.rotary_dim:]
             key_pass = key[..., self.rotary_dim:]
 
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -145,11 +146,9 @@ class RotaryEmbedding(nn.Module):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # ops.rotary_embedding() is an in-place operation that
-        # updates the query and key tensors.
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
-        # ops.rotary_embedding()/batched_rotary_embedding() are in-place ops
-        # that update the qk tensors
+        # ops.rotary_embedding()/batched_rotary_embedding()
+        # are in-place operations that update the query and key tensors.
         if offsets is not None:
             ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                          self.cos_sin_cache,
@@ -252,11 +251,12 @@ def _yarn_find_correction_dim(num_rotations: int,
 
 
 # Find dim range bounds based on rotations
-def _yarn_find_correction_range(low_rot: int,
-                                high_rot: int,
-                                dim: int,
-                                base: float = 10000,
-                                max_position_embeddings: int = 2048) -> int:
+def _yarn_find_correction_range(
+        low_rot: int,
+        high_rot: int,
+        dim: int,
+        base: float = 10000,
+        max_position_embeddings: int = 2048) -> Tuple[int, int]:
     low = math.floor(
         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
     high = math.ceil(
@@ -266,13 +266,11 @@ def _yarn_find_correction_range(low_rot: int,
 
 
 def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
-                           dtype: torch.dtype,
-                           device: torch.device) -> torch.Tensor:
+                           dtype: torch.dtype) -> torch.Tensor:
     if low == high:
         high += 0.001  # Prevent singularity
 
-    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
-                   low) / (high - low)
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func
 
@@ -300,8 +298,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
-        beta_fast: float = 32,
-        beta_slow: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -325,8 +323,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
                                                 self.rotary_dim, self.base,
                                                 self.max_position_embeddings)
         # Get n-d rotational scaling corrected for extrapolation
-        # FIXME: Add device here.
-        # pylint: disable=no-value-for-parameter
         inv_freq_mask = (1 - _yarn_linear_ramp_mask(
             low, high, self.rotary_dim // 2,
             dtype=torch.float)) * self.extrapolation_factor
@@ -379,7 +375,6 @@ def get_rope(
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
-            assert max_position == original_max_position * scaling_factor
             extra_kwargs = {
                 k: v
                 for k, v in rope_scaling.items()
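
`_rotate_neox` and `_rotate_gptj` above implement the two rotary conventions (half-split halves vs. interleaved pairs). A tiny illustration of what each produces (the GPT-J variant flattens the stacked pairs back into a flat vector, as the original code does):

```python
# Illustration of the two rotation conventions on a toy 4-element vector.
import torch

x = torch.tensor([1., 2., 3., 4.])

# NeoX style: split in half, negate the second half, swap the halves.
x1, x2 = x[..., :2], x[..., 2:]
print(torch.cat((-x2, x1), dim=-1))                 # tensor([-3., -4.,  1.,  2.])

# GPT-J style: treat (even, odd) elements as interleaved pairs.
x1, x2 = x[..., ::2], x[..., 1::2]
print(torch.stack((-x2, x1), dim=-1).flatten(-2))   # tensor([-2.,  1., -4.,  3.])
```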

+ 6 - 6
aphrodite/modeling/layers/sampler.py

@@ -64,12 +64,6 @@ class Sampler(nn.Module):
                                       sampling_tensors.freq_penalties,
                                       sampling_tensors.rep_penalties)
 
-        if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
-            logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                        sampling_tensors.dynatemp_mins,
-                                        sampling_tensors.dynatemp_maxs,
-                                        sampling_tensors.dynatemp_exps)
-
         if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
                 or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
             logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
@@ -93,6 +87,12 @@ class Sampler(nn.Module):
                 sampling_tensors.smoothing_factors,
                 sampling_tensors.smoothing_curves)
 
+        if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
+            logits = _apply_temperature(logits, sampling_tensors.temperatures,
+                                        sampling_tensors.dynatemp_mins,
+                                        sampling_tensors.dynatemp_maxs,
+                                        sampling_tensors.dynatemp_exps)
+
         banned_tokens = _get_custom_token_bans(sampling_metadata)
         assert len(banned_tokens) == logits.shape[0]
         logits = _apply_token_bans(logits, banned_tokens)

+ 21 - 55
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -1,15 +1,12 @@
 from typing import Optional, Sequence
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
-from aphrodite.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
-from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
+from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
 from aphrodite.modeling.utils import set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -53,10 +50,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                  num_embeddings: int,
                  embedding_dim: int,
                  params_dtype: Optional[torch.dtype] = None,
-                 linear_method=None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 is_input_emb: bool = True):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__()
 
         # Keep the input dimensions.
@@ -75,46 +70,21 @@ class VocabParallelEmbedding(torch.nn.Module):
                 self.tp_size))
         self.num_embeddings_per_partition = (self.vocab_end_index -
                                              self.vocab_start_index)
-        idx = 0 if is_input_emb else 1
-        if linear_method is None or not linear_method.quant_config.quant_vocab(
-        )[idx]:
-            linear_method = UnquantizedLinearMethod()
-        self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.embedding_dim, [self.num_embeddings_per_partition],
-            self.embedding_dim, self.num_embeddings_padded, params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.nn.parameter.Parameter):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.weight = Parameter(
+            torch.empty(self.num_embeddings_per_partition,
+                        self.embedding_dim,
+                        dtype=params_dtype))
+        set_weight_attrs(self.weight, {
+            "parallel_dim": 0,
+            "weight_loader": self.weight_loader
+        })
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        output_dim = getattr(param, "output_dim", None)
-        packed_dim = getattr(param, "packed_dim", None)
-        if output_dim is not None:
-            shard_offset = self.vocab_start_index
-            shard_size = min(self.vocab_end_index,
-                             self.org_vocab_size) - shard_offset
-            if packed_dim == output_dim:
-                shard_size = shard_size // param.pack_factor
-                shard_offset = shard_offset // param.pack_factor
-            loaded_weight = loaded_weight.narrow(output_dim, shard_offset,
-                                                 shard_size)
-        if isinstance(param, torch.nn.parameter.UninitializedParameter):
-            vocab_shape = list(loaded_weight.shape)
-            if output_dim is not None:
-                if packed_dim == output_dim:
-                    vocab_shape[output_dim] = (
-                        self.num_embeddings_per_partition // param.pack_factor)
-                else:
-                    vocab_shape[output_dim] = self.num_embeddings_per_partition
-            param.materialize(vocab_shape, dtype=loaded_weight.dtype)
-        if output_dim is not None:
-            param.data.narrow(
-                output_dim, 0,
-                loaded_weight.shape[output_dim]).copy_(loaded_weight)
-        else:
-            param.data.copy_(loaded_weight)
+        parallel_dim = param.parallel_dim
+        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
+        loaded_weight = loaded_weight[self.vocab_start_index:self.
+                                      vocab_end_index]
+        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
 
     def forward(self, input_):
         if self.tp_size > 1:
@@ -127,9 +97,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
             # Get the embeddings.
-        output_parallel = self.linear_method.apply_embedding(
-            self.linear_weights, masked_input)
-        # output_parallel = F.embedding(masked_input, self.weight)
+        output_parallel = F.embedding(masked_input, self.weight)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel[input_mask, :] = 0.0
@@ -159,18 +127,16 @@ class ParallelLMHead(VocabParallelEmbedding):
                  embedding_dim: int,
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
-                 linear_method=None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         linear_method, org_num_embeddings, padding_size,
-                         False)
+                         org_num_embeddings, padding_size)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
-                "output_dim": 0,
+                "parallel_dim": 0,
                 "weight_loader": self.weight_loader
             })
         else:
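
`VocabParallelEmbedding.forward` above masks token ids that fall outside this rank's vocab shard, looks the rest up in the local `F.embedding` weight, zeroes the masked rows, and (for `tp_size > 1`) all-reduces the partial results. A single-process sketch of that pattern; the mask arithmetic is partly outside the shown hunk, so treat the exact details as an assumption:

```python
# Single-process sketch of the vocab-parallel lookup-and-mask pattern.
import torch
import torch.nn.functional as F

vocab_start_index, vocab_end_index = 100, 200        # this rank's shard
weight = torch.randn(vocab_end_index - vocab_start_index, 16)

input_ = torch.tensor([5, 120, 150, 300])
input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_ - vocab_start_index
masked_input[input_mask] = 0                          # any in-range dummy id

output_parallel = F.embedding(masked_input, weight)
output_parallel[input_mask, :] = 0.0
# With tp_size > 1 the per-rank partial embeddings are then summed via
# tensor_model_parallel_all_reduce(output_parallel).
```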

+ 0 - 150
aphrodite/modeling/loader.py

@@ -1,150 +0,0 @@
-"""Utilities for selecting and loading models."""
-import contextlib
-import gc
-from contextlib import nullcontext
-from typing import Type
-from loguru import logger
-
-import torch
-import torch.nn as nn
-
-from aphrodite.common.config import DeviceConfig, ModelConfig
-from aphrodite.modeling.models import ModelRegistry
-from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
-from aphrodite.modeling.hf_downloader import (
-    get_quant_config,
-    initialize_dummy_weights,
-)
-from aphrodite.quantization.bitsandbytes import (
-    BNBLinearMethod,
-    replace_quant_params,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
-
-_VISION_MODEL_CLASSES = [
-    LlavaForConditionalGeneration,
-]
-
-
-@contextlib.contextmanager
-def _set_default_torch_dtype(dtype: torch.dtype):
-    """Sets the default torch dtype to the given dtype."""
-    old_dtype = torch.get_default_dtype()
-    torch.set_default_dtype(dtype)
-    yield
-    torch.set_default_dtype(old_dtype)
-
-
-def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
-    architectures = getattr(model_config.hf_config, "architectures", [])
-
-    for arch in architectures:
-        model_cls = ModelRegistry.load_model_cls(arch)
-        if model_cls is not None:
-            return model_cls
-    raise ValueError(
-        f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
-
-
-def get_model(model_config: ModelConfig, device_config: DeviceConfig,
-              **kwargs) -> nn.Module:
-    lora_config = kwargs.get("lora_config", None)
-    vision_language_config = kwargs.get("vision_language_config", None)
-    model_class = _get_model_architecture(model_config)
-
-    # Get the (maybe quantized) linear method.
-    linear_method = None
-    if model_config.quantization is not None:
-        quant_config = get_quant_config(model_config)
-        capability = torch.cuda.get_device_capability()
-        capability = capability[0] * 10 + capability[1]
-        if capability < quant_config.get_min_capability():
-            raise ValueError(
-                f"The quantization method {model_config.quantization} is not "
-                "supported for the current GPU. "
-                f"Minimum capability: {quant_config.get_min_capability()}. "
-                f"Current capability: {capability}.")
-        supported_dtypes = quant_config.get_supported_act_dtypes()
-        if model_config.dtype not in supported_dtypes:
-            # set the dtype to float16 for quantized models
-            model_config.dtype = torch.float16
-            logger.warning("Model is quantized. Forcing float16 datatype.")
-        linear_method = quant_config.get_linear_method()
-
-    with _set_default_torch_dtype(model_config.dtype):
-        # Create a model instance.
-        # The weights will be initialized as empty tensors.
-        with torch.device(device_config.device) if not (
-                isinstance(linear_method, BNBLinearMethod)
-                and linear_method.quant_config.from_float) else nullcontext():
-            if hasattr(model_class, "supported_lora_modules"):
-                model = model_class(model_config.hf_config, linear_method,
-                                    lora_config)
-            elif lora_config:
-                raise ValueError(
-                    f"Model {model_class.__name__} does not support LoRA, "
-                    "but LoRA is enabled. Support for this model may "
-                    "be added in the future. If this is important to you, "
-                    "please open an issue on github.")
-            else:
-                if model_class not in _VISION_MODEL_CLASSES:
-                    model = model_class(model_config.hf_config, linear_method)
-                else:
-                    model = model_class(model_config.hf_config,
-                                        vision_language_config, linear_method)
-        if model_config.load_format == "dummy":
-            # NOTE: For accurate performance evaluation, we assign
-            # random values to the weights.
-            initialize_dummy_weights(model)
-        else:
-            # Load the weights from the cached or downloaded files.
-            model.load_weights(model_config.model, model_config.download_dir,
-                               model_config.load_format, model_config.revision)
-
-        if isinstance(linear_method, BNBLinearMethod):
-            replace_quant_params(
-                model,
-                quant_config=linear_method.quant_config,
-                modules_to_not_convert="lm_head",
-            )
-            torch.cuda.synchronize()
-            if linear_method.quant_config.from_float:
-                model = model.cuda()
-            gc.collect()
-            torch.cuda.empty_cache()
-            tp = get_tensor_model_parallel_world_size()
-            logger.info(
-                "Memory allocated for converted model: {} GiB x {} = {} "
-                "GiB".format(
-                    round(
-                        torch.cuda.memory_allocated(
-                            torch.cuda.current_device()) /
-                        (1024 * 1024 * 1024),
-                        2,
-                    ),
-                    tp,
-                    round(
-                        torch.cuda.memory_allocated(
-                            torch.cuda.current_device()) * tp /
-                        (1024 * 1024 * 1024),
-                        2,
-                    ),
-                ))
-            logger.info(
-                "Memory reserved for converted model: {} GiB x {} = {} "
-                "GiB".format(
-                    round(
-                        torch.cuda.memory_reserved(torch.cuda.current_device())
-                        / (1024 * 1024 * 1024),
-                        2,
-                    ),
-                    tp,
-                    round(
-                        torch.cuda.memory_reserved(torch.cuda.current_device())
-                        * tp / (1024 * 1024 * 1024),
-                        2,
-                    ),
-                ))
-    return model.eval()

+ 31 - 0
aphrodite/modeling/model_loader/__init__.py

@@ -0,0 +1,31 @@
+from typing import Optional
+
+from torch import nn
+
+from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
+                                     ModelConfig, ParallelConfig,
+                                     SchedulerConfig, VisionLanguageConfig)
+from aphrodite.modeling.model_loader.loader import (BaseModelLoader,
+                                                    get_model_loader)
+from aphrodite.modeling.model_loader.utils import (get_architecture_class_name,
+                                                   get_model_architecture)
+
+
+def get_model(
+        *, model_config: ModelConfig, load_config: LoadConfig,
+        device_config: DeviceConfig, parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
+    loader = get_model_loader(load_config)
+    return loader.load_model(model_config=model_config,
+                             device_config=device_config,
+                             lora_config=lora_config,
+                             vision_language_config=vision_language_config,
+                             parallel_config=parallel_config,
+                             scheduler_config=scheduler_config)
+
+
+__all__ = [
+    "get_model", "get_model_loader", "BaseModelLoader",
+    "get_architecture_class_name", "get_model_architecture"
+]

+ 383 - 0
aphrodite/modeling/model_loader/loader.py

@@ -0,0 +1,383 @@
+# ruff: noqa: SIM117
+import copy
+import glob
+import os
+from abc import ABC, abstractmethod
+import gc
+from contextlib import nullcontext
+from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
+                    Type)
+
+import torch
+from torch import nn
+
+from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, DeviceConfig,
+                                     LoadConfig, LoadFormat, LoRAConfig,
+                                     ModelConfig, ParallelConfig,
+                                     SchedulerConfig, VisionLanguageConfig)
+from aphrodite.modeling.model_loader.tensorizer import (
+    TensorizerConfig, is_aphrodite_serialized_tensorizer, load_with_tensorizer,
+    tensorizer_weights_iterator)
+from aphrodite.modeling.model_loader.utils import (get_model_architecture,
+                                                   set_default_torch_dtype)
+from aphrodite.modeling.model_loader.weight_utils import (
+    download_weights_from_hf, filter_files_not_needed_for_inference,
+    get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
+    pt_weights_iterator, safetensors_weights_iterator)
+from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
+from aphrodite.quantization.bitsandbytes import (BNBLinearMethod,
+                                                 replace_quant_params)
+
+if TYPE_CHECKING:
+    from aphrodite.modeling.layers.linear import LinearMethodBase
+
+_VISION_MODEL_CLASSES = [
+    LlavaForConditionalGeneration,
+]
+
+
+def _get_linear_method(
+        model_config: ModelConfig,
+        load_config: LoadConfig) -> Optional["LinearMethodBase"]:
+    """Get the (maybe quantized) linear method."""
+    linear_method = None
+    if model_config.quantization is not None:
+        quant_config = get_quant_config(model_config, load_config)
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < quant_config.get_min_capability():
+            raise ValueError(
+                f"The quantization method {model_config.quantization} is not "
+                "supported for the current GPU. "
+                f"Minimum capability: {quant_config.get_min_capability()}. "
+                f"Current capability: {capability}.")
+        supported_dtypes = quant_config.get_supported_act_dtypes()
+        if model_config.dtype not in supported_dtypes:
+            raise ValueError(
+                f"{model_config.dtype} is not supported for quantization "
+                f"method {model_config.quantization}. Supported dtypes: "
+                f"{supported_dtypes}")
+
+        linear_method = quant_config.get_linear_method()
+    return linear_method
+
+
+def _get_model_initialization_kwargs(
+        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig]
+) -> Dict[str, Any]:
+    """Get extra kwargs for model initialization."""
+    extra_kwargs = {}
+    if hasattr(model_class, "supported_lora_modules"):
+        extra_kwargs["lora_config"] = lora_config
+    elif lora_config:
+        raise ValueError(
+            f"Model {model_class.__name__} does not support LoRA, "
+            "but LoRA is enabled. Support for this model may "
+            "be added in the future. If this is important to you, "
+            "please open an issue on github.")
+    elif model_class in _VISION_MODEL_CLASSES:
+        extra_kwargs["vision_language_config"] = vision_language_config
+    return extra_kwargs
+
+
+def _initialize_model(
+        model_config: ModelConfig, load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
+    """Initialize a model with the given configurations."""
+    model_class = get_model_architecture(model_config)[0]
+    linear_method = _get_linear_method(model_config, load_config)
+
+    return model_class(config=model_config.hf_config,
+                       linear_method=linear_method,
+                       **_get_model_initialization_kwargs(
+                           model_class, lora_config, vision_language_config))
+
+
+class BaseModelLoader(ABC):
+    """Base class for model loaders."""
+
+    def __init__(self, load_config: LoadConfig):
+        self.load_config = load_config
+
+    @abstractmethod
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig) -> nn.Module:
+        """Load a model with the given configurations."""
+        ...
+
+
+class DefaultModelLoader(BaseModelLoader):
+    """Model loader that can load different file types from disk."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        if load_config.model_loader_extra_config:
+            raise ValueError(f"Model loader extra config is not supported for "
+                             f"load format {load_config.load_format}")
+
+    def _maybe_download_from_modelscope(
+            self, model: str, revision: Optional[str]) -> Optional[str]:
+        """Download model from ModelScope hub if APHRODITE_USE_MODELSCOPE is
+        True.
+        
+        Returns the path to the downloaded model, or None if the model is not
+        downloaded from ModelScope."""
+        if APHRODITE_USE_MODELSCOPE:
+            # download model from ModelScope hub,
+            # lazy import so that modelscope is not required for normal use.
+            # pylint: disable=C.
+            from modelscope.hub.snapshot_download import snapshot_download
+
+            if not os.path.exists(model):
+                model_path = snapshot_download(
+                    model_id=model,
+                    cache_dir=self.load_config.download_dir,
+                    revision=revision)
+            else:
+                model_path = model
+            return model_path
+        return None
+
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str],
+                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
+        """Prepare weights for the model.
+
+        If the model is not local, it will be downloaded."""
+        model_name_or_path = self._maybe_download_from_modelscope(
+            model_name_or_path, revision) or model_name_or_path
+
+        is_local = os.path.isdir(model_name_or_path)
+        load_format = self.load_config.load_format
+        use_safetensors = False
+        # Some quantized models use .pt files for storing the weights.
+        if load_format == LoadFormat.AUTO:
+            allow_patterns = ["*.safetensors", "*.bin"]
+        elif load_format == LoadFormat.SAFETENSORS:
+            use_safetensors = True
+            allow_patterns = ["*.safetensors"]
+        elif load_format == LoadFormat.PT:
+            allow_patterns = ["*.pt"]
+        elif load_format == LoadFormat.NPCACHE:
+            allow_patterns = ["*.bin"]
+        else:
+            raise ValueError(f"Unknown load_format: {load_format}")
+
+        if fall_back_to_pt:
+            allow_patterns += ["*.pt"]
+
+        if not is_local:
+            hf_folder = download_weights_from_hf(model_name_or_path,
+                                                 self.load_config.download_dir,
+                                                 allow_patterns, revision)
+        else:
+            hf_folder = model_name_or_path
+
+        hf_weights_files: List[str] = []
+        for pattern in allow_patterns:
+            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+            if len(hf_weights_files) > 0:
+                if pattern == "*.safetensors":
+                    use_safetensors = True
+                break
+
+        if not use_safetensors:
+            hf_weights_files = filter_files_not_needed_for_inference(
+                hf_weights_files)
+
+        if len(hf_weights_files) == 0:
+            raise RuntimeError(
+                f"Cannot find any model weights with `{model_name_or_path}`")
+
+        return hf_folder, hf_weights_files, use_safetensors
+
+    def _get_weights_iterator(
+        self, model_name_or_path: str, revision: Optional[str],
+        fall_back_to_pt: bool
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights based on the load format."""
+        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
+            model_name_or_path, revision, fall_back_to_pt)
+        if self.load_config.load_format == LoadFormat.NPCACHE:
+            # Currently np_cache only support *.bin checkpoints
+            assert use_safetensors is False
+            return np_cache_weights_iterator(model_name_or_path,
+                                             self.load_config.download_dir,
+                                             hf_folder, hf_weights_files)
+        if use_safetensors:
+            return safetensors_weights_iterator(hf_weights_files)
+        return pt_weights_iterator(hf_weights_files)
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            linear_method = _get_linear_method(model_config, self.load_config)
+
+            context = torch.device(device_config.device) if not (
+                isinstance(linear_method, BNBLinearMethod)
+                and linear_method.quant_config.from_float) else nullcontext()
+
+            with context:
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config)
+            model.load_weights(
+                self._get_weights_iterator(model_config.model,
+                                           model_config.revision,
+                                           fall_back_to_pt=getattr(
+                                               model,
+                                               "fall_back_to_pt_during_load",
+                                               True)), )
+            for _, module in model.named_modules():
+                linear_method = getattr(module, "linear_method", None)
+                if linear_method is not None:
+                    linear_method.process_weights_after_loading(module)
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
+        
+            if isinstance(linear_method, BNBLinearMethod):
+                replace_quant_params(
+                    model,
+                    quant_config=linear_method.quant_config,
+                    modules_to_not_convert="lm_head",
+                )
+                torch.cuda.synchronize()
+                if linear_method.quant_config.from_float:
+                    model = model.cuda()
+                gc.collect()
+                torch.cuda.empty_cache()
+
+        return model.eval()
+
+
+class DummyModelLoader(BaseModelLoader):
+    """Model loader that will set model weights to random values."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        if load_config.model_loader_extra_config:
+            raise ValueError(f"Model loader extra config is not supported for "
+                             f"load format {load_config.load_format}")
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config)
+            # NOTE(woosuk): For accurate performance evaluation, we assign
+            # random values to the weights.
+            initialize_dummy_weights(model)
+        return model.eval()
+
+
+class TensorizerLoader(BaseModelLoader):
+    """Model loader using CoreWeave's tensorizer library."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
+            self.tensorizer_config = load_config.model_loader_extra_config
+        else:
+            self.tensorizer_config = TensorizerConfig(
+                **load_config.model_loader_extra_config)
+
+    def _verify_config(self, model_config: ModelConfig,
+                       parallel_config: ParallelConfig):
+        self.tensorizer_config.verify_with_model_config(model_config)
+        self.tensorizer_config.verify_with_parallel_config(parallel_config)
+
+    def _get_weights_iterator(
+            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
+        return tensorizer_weights_iterator(tensorizer_args)
+
+    def _load_model_unserialized(
+            self, model_config: ModelConfig, device_config: DeviceConfig,
+            lora_config: Optional[LoRAConfig],
+            vision_language_config: Optional[VisionLanguageConfig]
+    ) -> nn.Module:
+        """Load an unserialized model with tensorizer.
+
+        Unserialized here means "not serialized with tensorizer". This
+        should still be faster than default HuggingFace loading, but will
+        be slower than loading a tensorizer-serialized model.
+        """
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config)
+
+            model.load_weights(self._get_weights_iterator())
+        return model.eval()
+
+    def _load_model_serialized(
+            self, model_config: ModelConfig, device_config: DeviceConfig,
+            lora_config: Optional[LoRAConfig],
+            vision_language_config: Optional[VisionLanguageConfig]
+    ) -> nn.Module:
+        """Load a serialized model with tensorizer.
+
+        See the examples/tensorize_aphrodite_model.py example
+        script for serializing Aphrodite models."""
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model_class = get_model_architecture(model_config)[0]
+                linear_method = _get_linear_method(model_config,
+                                                   self.load_config)
+                extra_kwargs = _get_model_initialization_kwargs(
+                    model_class, lora_config, vision_language_config)
+                extra_kwargs["linear_method"] = linear_method
+
+                tensorizer_config = copy.copy(self.tensorizer_config)
+                tensorizer_config.model_class = model_class
+                tensorizer_config.hf_config = model_config.hf_config
+                tensorizer_config.dtype = model_config.dtype
+
+                model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
+        return model.eval()
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig) -> nn.Module:
+        self._verify_config(model_config, parallel_config)
+
+        if is_aphrodite_serialized_tensorizer(self.tensorizer_config):
+            return self._load_model_serialized(model_config, device_config,
+                                               lora_config,
+                                               vision_language_config)
+        return self._load_model_unserialized(model_config, device_config,
+                                             lora_config,
+                                             vision_language_config)
+
+
+def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+    """Get a model loader based on the load format."""
+
+    if isinstance(load_config.load_format, type):
+        return load_config.load_format(load_config)
+
+    if load_config.load_format == LoadFormat.DUMMY:
+        return DummyModelLoader(load_config)
+
+    if load_config.load_format == LoadFormat.TENSORIZER:
+        return TensorizerLoader(load_config)
+
+    return DefaultModelLoader(load_config)
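
For context, here is a hedged sketch of how the format-based dispatch in get_model_loader above is expected to be driven. The loader module path, the import locations, and the assumption that LoadConfig can be constructed with just a load_format keyword are illustrative and not confirmed by this diff.

# Hedged sketch only: import locations and the LoadConfig constructor
# arguments shown here are assumptions.
from aphrodite.common.config import LoadConfig, LoadFormat
from aphrodite.modeling.model_loader.loader import (DummyModelLoader,
                                                    get_model_loader)

# Enum members select a concrete loader (DUMMY, TENSORIZER, else default)...
loader = get_model_loader(LoadConfig(load_format=LoadFormat.TENSORIZER))

# ...while passing a loader *class* bypasses the enum entirely, because
# get_model_loader instantiates any type it receives with the LoadConfig.
loader = get_model_loader(LoadConfig(load_format=DummyModelLoader))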

+ 15 - 13
aphrodite/modeling/neuron_loader.py → aphrodite/modeling/model_loader/neuron.py

@@ -1,18 +1,19 @@
 """Utilities for selecting and loading neuron models."""
 import importlib
 import os
-from typing import Optional, Type
+from typing import Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import transformers
 from transformers import PretrainedConfig
 
-from aphrodite.common.config import ModelConfig, ParallelConfig, SchedulerConfig
+from aphrodite.common.config import (ModelConfig, ParallelConfig,
+                                     SchedulerConfig)
+from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.common.sequence import SamplerOutput
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
@@ -27,7 +28,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
 }
 
 # Models supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {
+_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
     "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                          "LlamaForSampling", "LlamaForCausalLM"),
     "MistralForCausalLM": ("transformers_neuronx.mistral.model",
@@ -43,11 +44,13 @@ class NeuronCasualLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.model = None
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
         self.sampler = Sampler()
 
+        # Lazy initialized
+        self.model: nn.Module
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -74,17 +77,17 @@ class NeuronCasualLM(nn.Module):
 
     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
-        neuronx_module_path, neuronx_model_cls, hf_model_cls = (
+        neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
             _NEURON_SUPPORTED_MODELS[arch])
         neuronx_module = importlib.import_module(neuronx_module_path)
-        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
+        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
 
         split_model_dir = f"{model_name_or_path}-split"
         if os.path.isdir(os.path.join(model_name_or_path,
                                       "pytorch_model.bin")):
             split_model_dir = model_name_or_path
         elif not os.path.exists(f"{model_name_or_path}-split"):
-            hf_model_cls = getattr(transformers, hf_model_cls)
+            hf_model_cls = getattr(transformers, hf_model_cls_name)
             from transformers_neuronx.module import save_pretrained_split
 
             hf_model = hf_model_cls.from_pretrained(model_name_or_path,
@@ -96,7 +99,7 @@ class NeuronCasualLM(nn.Module):
         self.model.to_neuron()
 
 
-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
         if arch in _NEURON_SUPPORTED_MODELS:
@@ -110,8 +113,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 def get_neuron_model(model_config: ModelConfig,
                      parallel_config: ParallelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
-    from transformers_neuronx.config import (NeuronConfig,
-                                             ContinuousBatchingConfig)
+    from transformers_neuronx.config import (ContinuousBatchingConfig,
+                                             NeuronConfig)
 
     # Create a model instance.
     model = NeuronCasualLM(model_config.hf_config)
@@ -129,7 +132,6 @@ def get_neuron_model(model_config: ModelConfig,
         neuron_config=neuron_config,
         context_length_estimate=[scheduler_config.max_model_len],
         n_positions=[scheduler_config.max_model_len],
-        batch_size=scheduler_config.max_num_seqs,
-    )
+        batch_size=scheduler_config.max_num_seqs)
 
     return model.eval()
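
The Neuron path resolves model classes lazily from the _NEURON_SUPPORTED_MODELS table via importlib, so the transformers_neuronx package is only imported when a supported architecture is actually requested. A self-contained sketch of that lazy (module, class) registry pattern, with a single illustrative entry:

# Sketch of the lazy registry pattern used by _NEURON_SUPPORTED_MODELS above.
import importlib
from typing import Dict, Tuple, Type

_REGISTRY: Dict[str, Tuple[str, str]] = {
    # architecture name -> (backing module path, class name)
    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                         "LlamaForSampling"),
}


def resolve(arch: str) -> Type:
    """Import the backing module only when the architecture is requested."""
    module_path, class_name = _REGISTRY[arch]
    module = importlib.import_module(module_path)
    return getattr(module, class_name)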

+ 363 - 0
aphrodite/modeling/model_loader/tensorizer.py

@@ -0,0 +1,363 @@
+import argparse
+import dataclasses
+import io
+import os
+import time
+import typing
+from dataclasses import dataclass
+from typing import Generator, Optional, Tuple, Type, Union
+
+import torch
+from loguru import logger
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.common.config import ModelConfig, ParallelConfig
+from aphrodite.modeling.layers.linear import LinearMethodBase
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+
+tensorizer_load_fail = None
+
+try:
+    from tensorizer import (DecryptionParams, EncryptionParams,
+                            TensorDeserializer, TensorSerializer)
+    from tensorizer.stream_io import open_stream
+    from tensorizer.utils import (convert_bytes, get_mem_usage,
+                                  no_init_or_tensor)
+except ImportError as e:
+    tensorizer_load_fail = e
+
+__all__ = [
+    'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
+    'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
+    'no_init_or_tensor', 'TensorizerConfig'
+]
+
+
+@dataclass
+class TensorizerConfig:
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+                          str, bytes, os.PathLike, int]
+    aphrodite_tensorized: bool
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = 1
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+    model_class: Optional[Type[torch.nn.Module]] = None
+    hf_config: Optional[PretrainedConfig] = None
+    dtype: Optional[Union[str, torch.dtype]] = None
+
+    def _construct_tensorizer_args(self) -> "TensorizerArgs":
+        tensorizer_args = {
+            "tensorizer_uri": self.tensorizer_uri,
+            "aphrodite_tensorized": self.aphrodite_tensorized,
+            "verify_hash": self.verify_hash,
+            "num_readers": self.num_readers,
+            "encryption_keyfile": self.encryption_keyfile,
+            "s3_access_key_id": self.s3_access_key_id,
+            "s3_secret_access_key": self.s3_secret_access_key,
+            "s3_endpoint": self.s3_endpoint,
+        }
+        return TensorizerArgs(**tensorizer_args)
+
+    def verify_with_parallel_config(
+        self,
+        parallel_config: "ParallelConfig",
+    ) -> None:
+        if (parallel_config.tensor_parallel_size > 1
+                and self.tensorizer_uri is not None):
+            raise ValueError(
+                "Loading to multiple GPUs is not currently supported with "
+                "aphrodite-serialized models. Please set tensor_parallel_size=1."
+                " or use a non-aphrodite-serialized model, such as a "
+                "serialized Hugging Face `PretrainedModel`.")
+
+    def verify_with_model_config(self, model_config: "ModelConfig") -> None:
+        if (model_config.quantization is not None
+                and self.tensorizer_uri is not None):
+            logger.warning(
+                "Loading a model using Tensorizer with quantization on aphrodite"
+                " is unstable and may lead to errors.")
+
+
+def load_with_tensorizer(tensorizer_config: TensorizerConfig,
+                         **extra_kwargs) -> nn.Module:
+    tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
+    return tensorizer.deserialize()
+
+
+def is_aphrodite_serialized_tensorizer(
+        tensorizer_config: TensorizerConfig) -> bool:
+    if tensorizer_config is None:
+        return False
+    return tensorizer_config.aphrodite_tensorized
+
+
+@dataclass
+class TensorizerArgs:
+    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+                          str, bytes, os.PathLike, int]
+    aphrodite_tensorized: bool
+    verify_hash: Optional[bool] = False
+    num_readers: Optional[int] = 1
+    encryption_keyfile: Optional[str] = None
+    s3_access_key_id: Optional[str] = None
+    s3_secret_access_key: Optional[str] = None
+    s3_endpoint: Optional[str] = None
+    """
+  Args for the TensorizerAgent class. These are used to configure the behavior 
+  of the TensorDeserializer when loading tensors from a serialized model.
+  
+  Args:
+      tensorizer_uri: Path to serialized model tensors. Can be a local file 
+          path or a S3 URI.
+      aphrodite_tensorized: If True, indicates that the serialized model is a 
+          aphrodite model. This is used to determine the behavior of the 
+          TensorDeserializer when loading tensors from a serialized model.
+          It is far faster to deserialize a aphrodite model as it utilizes
+          tensorizer's optimized GPU loading.
+      verify_hash: If True, the hashes of each tensor will be verified against 
+          the hashes stored in the metadata. A `HashMismatchError` will be 
+          raised if any of the hashes do not match.
+      num_readers: Controls how many threads are allowed to read concurrently
+          from the source file. Default is 1. This greatly increases
+          performance.
+      encryption_keyfile: File path to a binary file containing a  
+          binary key to use for decryption. `None` (the default) means 
+          no decryption. See the example script in 
+          examples/tensorize_aphrodite_model.py. 
+      s3_access_key_id: The access key for the S3 bucket. Can also be set via
+          the S3_ACCESS_KEY_ID environment variable.
+      s3_secret_access_key: The secret access key for the S3 bucket. Can also
+          be set via the S3_SECRET_ACCESS_KEY environment variable.
+      s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
+          S3_ENDPOINT_URL environment variable.
+  """
+
+    def __post_init__(self):
+        self.file_obj = self.tensorizer_uri
+        self.s3_access_key_id = (self.s3_access_key_id
+                                 or os.environ.get("S3_ACCESS_KEY_ID")) or None
+        self.s3_secret_access_key = (
+            self.s3_secret_access_key
+            or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
+        self.s3_endpoint = (self.s3_endpoint
+                            or os.environ.get("S3_ENDPOINT_URL")) or None
+        self.stream_params = {
+            "s3_access_key_id": self.s3_access_key_id,
+            "s3_secret_access_key": self.s3_secret_access_key,
+            "s3_endpoint": self.s3_endpoint,
+        }
+
+        self.deserializer_params = {
+            "verify_hash": self.verify_hash,
+            "encryption": self.encryption_keyfile,
+            "num_readers": self.num_readers
+        }
+        if self.encryption_keyfile:
+            with open_stream(
+                    self.encryption_keyfile,
+                    **self.stream_params,
+            ) as stream:
+                key = stream.read()
+                decryption_params = DecryptionParams.from_key(key)
+                self.deserializer_params['encryption'] = decryption_params
+
+    @staticmethod
+    def add_cli_args(
+            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        """Tensorizer CLI arguments"""
+
+        # Tensorizer options arg group
+        group = parser.add_argument_group(
+            'tensorizer options',
+            description=('Options for configuring the behavior of the'
+                         ' tensorizer deserializer when '
+                         '--load-format=tensorizer'))
+
+        group.add_argument(
+            "--tensorizer-uri",
+            help="Path to serialized model tensors. Can be a local file path,"
+            " or an HTTP(S) or S3 URI.",
+        )
+        group.add_argument(
+            "--verify-hash",
+            action="store_true",
+            help="If enabled, the hashes of each tensor will be verified"
+            " against the hashes stored in the file metadata. An exception"
+            " will be raised if any of the hashes do not match.",
+        )
+        group.add_argument(
+            "--encryption-keyfile",
+            default=None,
+            help="The file path to a binary file containing a binary key to "
+            "use for decryption. Can be a file path or S3 network URI.")
+        group.add_argument(
+            "--num-readers",
+            default=1,
+            type=int,
+            help="Controls how many threads are allowed to read concurrently "
+            "from the source file.")
+        group.add_argument(
+            "--s3-access-key-id",
+            default=None,
+            help="The access key for the S3 bucket. Can also be set via the "
+            "S3_ACCESS_KEY_ID environment variable.",
+        )
+        group.add_argument(
+            "--s3-secret-access-key",
+            default=None,
+            help="The secret access key for the S3 bucket. Can also be set via "
+            "the S3_SECRET_ACCESS_KEY environment variable.",
+        )
+        group.add_argument(
+            "--s3-endpoint",
+            default=None,
+            help="The endpoint for the S3 bucket. Can also be set via the "
+            "S3_ENDPOINT_URL environment variable.",
+        )
+        group.add_argument(
+            "--aphrodite-tensorized",
+            action="store_true",
+            help="If enabled, indicates that the serialized model is a aphrodite "
+            "model. This is used to determine the behavior of the "
+            "TensorDeserializer when loading tensors from a "
+            "serialized model.")
+
+        return parser
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        tensorizer_args = cls(**{
+            attr: getattr(args, attr)
+            for attr in attrs if hasattr(args, attr)
+        })
+        return tensorizer_args
+
+
+class TensorizerAgent:
+    """
+    A class for performing tensorizer deserializations specifically for
+    aphrodite models using plaid_mode. Uses TensorizerArgs to configure the
+    behavior of the TensorDeserializer when loading tensors from a serialized
+    model. For deserializations of HuggingFace models, TensorDeserializer is
+    instead used as an iterator directly, via tensorizer_weights_iterator
+    defined below in this file.
+    """
+
+    def __init__(self, tensorizer_config: TensorizerConfig,
+                 linear_method: LinearMethodBase, **extra_kwargs):
+        if tensorizer_load_fail is not None:
+            raise ImportError(
+                "Tensorizer is not installed. Please install tensorizer "
+                "to use this feature with `pip install aphrodite-engine[tensorizer]`."
+            ) from tensorizer_load_fail
+
+        self.tensorizer_config = tensorizer_config
+        self.tensorizer_args = (
+            self.tensorizer_config._construct_tensorizer_args())
+        self.extra_kwargs = extra_kwargs
+        if extra_kwargs.get("linear_method", None) is not None:
+            self.linear_method = extra_kwargs["linear_method"]
+        else:
+            self.linear_method = linear_method
+        self.model = self._init_model()
+
+    def _init_model(self):
+        model_args = self.tensorizer_config.hf_config
+        model_args.torch_dtype = self.tensorizer_config.dtype
+        with no_init_or_tensor():
+            return self.tensorizer_config.model_class(
+                config=model_args,
+                linear_method=self.linear_method,
+                **self.extra_kwargs)
+
+    def _resize_lora_embeddings(self):
+        """Modify LoRA embedding layers to use bigger tensors
+        to allow for adapter added tokens."""
+        for child in self.model.modules():
+            if (isinstance(child, VocabParallelEmbedding)
+                    and child.weight.shape[0] <
+                    child.num_embeddings_per_partition):
+                new_weight = torch.empty(child.num_embeddings_per_partition,
+                                         child.embedding_dim,
+                                         dtype=child.weight.dtype,
+                                         device=child.weight.device)
+                new_weight[:child.weight.shape[0]].copy_(child.weight.data)
+                new_weight[child.weight.shape[0]:].fill_(0)
+                child.weight.data = new_weight
+
+    def _check_tensors_on_meta_device(self):
+        for tensor in self.model.state_dict().values():
+            if tensor.device.type == 'meta':
+                raise ValueError(
+                    "The serialized model contains tensors on the meta device,"
+                    " indicating that some tensors were not loaded properly."
+                    " Please check that the parameters of the model being"
+                    " specified match that of the serialized model, such as"
+                    " its quantization.")
+
+    def deserialize(self):
+        """
+        Deserialize the model using the TensorDeserializer. This method is
+        specifically for Aphrodite models using tensorizer's plaid_mode.
+
+        The deserializer makes use of tensorizer_args.stream_params
+        to configure the behavior of the stream when loading tensors from a
+        serialized model. The deserializer_params are used to configure the
+        behavior of the TensorDeserializer when loading tensors themselves.
+        Documentation on these params can be found in TensorizerArgs
+
+        Returns:
+            nn.Module: The deserialized model.
+        """
+        before_mem = get_mem_usage()
+        start = time.perf_counter()
+        with open_stream(
+                self.tensorizer_args.tensorizer_uri,
+                mode="rb",
+                **self.tensorizer_args.stream_params,
+        ) as stream, TensorDeserializer(
+                stream,
+                dtype=self.tensorizer_config.dtype,
+                **self.tensorizer_args.deserializer_params) as deserializer:
+            deserializer.load_into_module(self.model)
+            end = time.perf_counter()
+
+        total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+        duration = end - start
+        per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+        after_mem = get_mem_usage()
+        deserializer.close()
+        logger.info(f"Deserialized {total_bytes_str} in "
+                    f"{end - start:0.2f}s, {per_second}/s")
+        logger.info(f"Memory usage before: {before_mem}")
+        logger.info(f"Memory usage after: {after_mem}")
+
+        self._check_tensors_on_meta_device()
+        self._resize_lora_embeddings()
+        return self.model.eval()
+
+
+def tensorizer_weights_iterator(
+    tensorizer_args: "TensorizerArgs"
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    logger.warning(
+        "Deserializing HuggingFace models is not optimized for "
+        "loading on Aphrodite, as tensorizer is forced to load to CPU. "
+        "Consider deserializing a Aphrodite model instead for faster "
+        "load times. See the examples/tensorize_aphrodite_model.py example "
+        "script for serializing Aphrodite models.")
+
+    deserializer_args = tensorizer_args.deserializer_params
+    stream_params = tensorizer_args.stream_params
+    stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
+    with TensorDeserializer(stream, **deserializer_args,
+                            device="cpu") as state:
+        for name, param in state.items():
+            yield name, param
+    del state
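
A hedged usage sketch of the pieces defined above, using only names from this file. The S3 URI and keyfile are placeholders, tensorizer must be installed, and calling the private _construct_tensorizer_args helper is purely illustrative.

# Hedged usage sketch; URI, credentials and keyfile are placeholders.
from aphrodite.modeling.model_loader.tensorizer import (
    TensorizerConfig, is_aphrodite_serialized_tensorizer,
    tensorizer_weights_iterator)

config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/model.tensors",  # placeholder URI
    aphrodite_tensorized=False,   # plain HF-serialized tensors
    num_readers=4,                # more reader threads -> faster loads
    encryption_keyfile=None,      # or a path/S3 URI to a binary key
)

# S3 credentials not passed explicitly are picked up from S3_ACCESS_KEY_ID,
# S3_SECRET_ACCESS_KEY and S3_ENDPOINT_URL in TensorizerArgs.__post_init__.
args = config._construct_tensorizer_args()

if not is_aphrodite_serialized_tensorizer(config):
    # Non-aphrodite-serialized models stream weights name-by-name on CPU.
    for name, tensor in tensorizer_weights_iterator(args):
        print(name, tuple(tensor.shape))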

+ 41 - 0
aphrodite/modeling/model_loader/utils.py

@@ -0,0 +1,41 @@
+"""Utilities for selecting and loading models."""
+import contextlib
+from typing import Tuple, Type
+
+import torch
+from torch import nn
+
+from aphrodite.common.config import ModelConfig
+from aphrodite.modeling.models import ModelRegistry
+
+
+@contextlib.contextmanager
+def set_default_torch_dtype(dtype: torch.dtype):
+    """Sets the default torch dtype to the given dtype."""
+    old_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(old_dtype)
+
+
+def get_model_architecture(
+        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
+    architectures = getattr(model_config.hf_config, "architectures", [])
+    # Special handling for quantized Mixtral.
+    # FIXME: This is a temporary hack.
+    if (model_config.quantization is not None
+            and model_config.quantization != "fp8"
+            and "MixtralForCausalLM" in architectures):
+        architectures = ["QuantMixtralForCausalLM"]
+
+    for arch in architectures:
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return (model_cls, arch)
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def get_architecture_class_name(model_config: ModelConfig) -> str:
+    return get_model_architecture(model_config)[1]
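
A brief usage sketch of set_default_torch_dtype, matching how the loaders above wrap model construction in it: parameters created inside the block default to the requested dtype, and the previous default is restored afterwards.

# Minimal usage sketch of set_default_torch_dtype defined above.
import torch

from aphrodite.modeling.model_loader.utils import set_default_torch_dtype

with set_default_torch_dtype(torch.float16):
    layer = torch.nn.Linear(8, 8)   # weights are created in fp16

assert layer.weight.dtype == torch.float16
assert torch.get_default_dtype() == torch.float32  # assuming fp32 was default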

+ 362 - 0
aphrodite/modeling/model_loader/weight_utils.py

@@ -0,0 +1,362 @@
+"""Utilities for downloading and initializing model weights."""
+import fnmatch
+import glob
+import hashlib
+import json
+import os
+import tempfile
+from collections import defaultdict
+from typing import Any, Generator, Iterable, List, Optional, Tuple
+
+import filelock
+import huggingface_hub.constants
+import numpy as np
+import torch
+from huggingface_hub import HfFileSystem, snapshot_download
+from loguru import logger
+from safetensors.torch import load_file, safe_open, save_file
+from tqdm.auto import tqdm
+
+from aphrodite.common.config import LoadConfig, ModelConfig
+from aphrodite.quantization import (QuantizationConfig,
+                                    get_quantization_config)
+from aphrodite.quantization.schema import QuantParamSchema
+
+
+# use system-level temp directory for file locks, so that multiple users
+# can share the same lock without error.
+# lock files in the temp directory will be automatically deleted when the
+# system reboots, so users will not complain about annoying lock files
+temp_dir = tempfile.gettempdir()
+
+
+def enable_hf_transfer():
+    """automatically activates hf_transfer
+    """
+    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
+        try:
+            # enable hf hub transfer if available
+            import hf_transfer  # type: ignore # noqa
+            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
+        except ImportError:
+            pass
+
+
+enable_hf_transfer()
+
+
+class DisabledTqdm(tqdm):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, disable=True)
+
+
+def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
+    lock_dir = cache_dir or temp_dir
+    # Make sure the lock directory itself exists (not just its parent).
+    os.makedirs(lock_dir, exist_ok=True)
+    model_name = model_name_or_path.replace("/", "-")
+    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
+    # add hash to avoid conflict with old users' lock files
+    lock_file_name = hash_name + model_name + ".lock"
+    # mode 0o666 is required for the filelock to be shared across users
+    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
+                             mode=0o666)
+    return lock
+
+
+def _shared_pointers(tensors):
+    ptrs = defaultdict(list)
+    for k, v in tensors.items():
+        ptrs[v.data_ptr()].append(k)
+    failing = []
+    for _, names in ptrs.items():
+        if len(names) > 1:
+            failing.append(names)
+    return failing
+
+
+def convert_bin_to_safetensor_file(
+    pt_filename: str,
+    sf_filename: str,
+) -> None:
+    loaded = torch.load(pt_filename, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    shared = _shared_pointers(loaded)
+    for shared_weights in shared:
+        for name in shared_weights[1:]:
+            loaded.pop(name)
+
+    # For tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_filename)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_filename, metadata={"format": "pt"})
+
+    # check file size
+    sf_size = os.stat(sf_filename).st_size
+    pt_size = os.stat(pt_filename).st_size
+    if (sf_size - pt_size) / pt_size > 0.01:
+        raise RuntimeError(f"""The file size different is more than 1%:
+         - {sf_filename}: {sf_size}
+         - {pt_filename}: {pt_size}
+         """)
+
+    # check if the tensors are the same
+    reloaded = load_file(sf_filename)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+# TODO: Move this to another place.
+def get_quant_config(model_config: ModelConfig,
+                     load_config: LoadConfig) -> QuantizationConfig:
+    quant_cls = get_quantization_config(model_config.quantization)
+    # Read the quantization config from the HF model config, if available.
+    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
+                              None)
+    if hf_quant_config is not None:
+        return quant_cls.from_config(hf_quant_config)
+    model_name_or_path = model_config.model
+    is_local = os.path.isdir(model_name_or_path)
+    if not is_local:
+        # Download the config files.
+        with get_lock(model_name_or_path, load_config.download_dir):
+            hf_folder = snapshot_download(model_name_or_path,
+                                          revision=model_config.revision,
+                                          allow_patterns="*.json",
+                                          cache_dir=load_config.download_dir,
+                                          tqdm_class=DisabledTqdm)
+    else:
+        hf_folder = model_name_or_path
+
+    possible_config_filenames = quant_cls.get_config_filenames()
+
+    # If the quantization config is not found, use the default config.
+    if not possible_config_filenames:
+        return quant_cls()
+
+    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
+
+    quant_config_files = [
+        f for f in config_files if any(
+            f.endswith(x) for x in possible_config_filenames)
+    ]
+    if len(quant_config_files) == 0:
+        raise ValueError(
+            f"Cannot find the config file for {model_config.quantization}")
+    if len(quant_config_files) > 1:
+        raise ValueError(
+            f"Found multiple config files for {model_config.quantization}: "
+            f"{quant_config_files}")
+
+    quant_config_file = quant_config_files[0]
+    with open(quant_config_file, "r") as f:
+        config = json.load(f)
+    return quant_cls.from_config(config)
+
+
+def download_weights_from_hf(model_name_or_path: str,
+                             cache_dir: Optional[str],
+                             allow_patterns: List[str],
+                             revision: Optional[str] = None) -> str:
+    """Download model weights from Hugging Face Hub.
+
+    Args:
+        model_name_or_path (str): The model name or path.
+        cache_dir (Optional[str]): The cache directory to store the model
+            weights. If None, will use HF defaults.
+        allow_patterns (List[str]): The allowed patterns for the
+            weight files. Files matched by any of the patterns will be
+            downloaded.
+        revision (Optional[str]): The revision of the model.
+
+    Returns:
+        str: The path to the downloaded model weights.
+    """
+    # Before we download, we look at what is available:
+    fs = HfFileSystem()
+    file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+
+    # depending on what is available we download different things
+    for pattern in allow_patterns:
+        matching = fnmatch.filter(file_list, pattern)
+        if len(matching) > 0:
+            allow_patterns = [pattern]
+            break
+
+    logger.info(f"Using model weights format {allow_patterns}")
+    # Use file lock to prevent multiple processes from
+    # downloading the same model weights at the same time.
+    with get_lock(model_name_or_path, cache_dir):
+        hf_folder = snapshot_download(model_name_or_path,
+                                      allow_patterns=allow_patterns,
+                                      cache_dir=cache_dir,
+                                      tqdm_class=DisabledTqdm,
+                                      revision=revision)
+    return hf_folder
+
+
+def filter_files_not_needed_for_inference(
+        hf_weights_files: List[str]) -> List[str]:
+    """
+    Exclude files that are not needed for inference.
+
+    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
+    """
+    blacklist = [
+        "training_args.bin",
+        "optimizer.bin",
+        "optimizer.pt",
+        "scheduler.pt",
+        "scaler.pt",
+    ]
+    hf_weights_files = [
+        f for f in hf_weights_files
+        if not any(f.endswith(x) for x in blacklist)
+    ]
+    return hf_weights_files
+
+
+def np_cache_weights_iterator(
+    model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
+    hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model np files.
+
+    Will dump the model weights to numpy files if they are not already dumped.
+    """
+    # Convert the model weights from torch tensors to numpy arrays for
+    # faster loading.
+    np_folder = os.path.join(hf_folder, "np")
+    os.makedirs(np_folder, exist_ok=True)
+    weight_names_file = os.path.join(np_folder, "weight_names.json")
+    # Use file lock to prevent multiple processes from
+    # dumping the same model weights to numpy at the same time.
+    with get_lock(model_name_or_path, cache_dir):
+        if not os.path.exists(weight_names_file):
+            weight_names = []
+            for bin_file in hf_weights_files:
+                state = torch.load(bin_file, map_location="cpu")
+                for name, param in state.items():
+                    param_path = os.path.join(np_folder, name)
+                    with open(param_path, "wb") as f:
+                        np.save(f, param.cpu().detach().numpy())
+                    weight_names.append(name)
+            with open(weight_names_file, "w") as f:
+                json.dump(weight_names, f)
+
+    with open(weight_names_file, "r") as f:
+        weight_names = json.load(f)
+
+    for name in weight_names:
+        param_path = os.path.join(np_folder, name)
+        with open(param_path, "rb") as f:
+            param = np.load(f)
+        yield name, torch.from_numpy(param)
+
+
+def safetensors_weights_iterator(
+    hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    for st_file in hf_weights_files:
+        with safe_open(st_file, framework="pt") as f:
+            for name in f.keys():  # noqa: SIM118
+                param = f.get_tensor(name)
+                yield name, param
+
+
+def pt_weights_iterator(
+    hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model bin/pt files."""
+    for bin_file in hf_weights_files:
+        state = torch.load(bin_file, map_location="cpu")
+        for name, param in state.items():
+            yield name, param
+        del state
+        torch.cuda.empty_cache()
+
+
+def kv_cache_scales_loader(
+        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
+        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    Keep this function in sync with the output of examples/fp8/extract_scales.py
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct,
+                                                     context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+
+    except FileNotFoundError:
+        logger.error(f"File or directory '{filename}' not found.")
+    except json.JSONDecodeError:
+        logger.error(f"Error decoding JSON in file '{filename}'.")
+    except Exception as e:
+        logger.error(f"An error occurred while reading '{filename}': {e}")
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
+                   f"for all layers in TP rank {tp_rank} "
+                   "as an error occurred during loading.")
+    return []
+
+
+def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
+    """convert PySafeSlice object from safetensors to torch.Tensor
+
+    PySafeSlice object supports indexing, which is done before loading the
+    actual tensor and can reduce the amount of memory being read into the
+    memory. However, it does not support more advanced functionalities
+    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
+    tensor with these more complicated operators, we need to convert to
+    tensor first.
+    """
+    if not isinstance(x, torch.Tensor):
+        x = x[:]
+    return x
+
+
+def default_weight_loader(param: torch.Tensor,
+                          loaded_weight: torch.Tensor) -> None:
+    """Default weight loader."""
+    assert param.size() == loaded_weight.size()
+    param.data.copy_(loaded_weight)
+
+
+def initialize_dummy_weights(
+    model: torch.nn.Module,
+    low: float = -1e-3,
+    high: float = 1e-3,
+) -> None:
+    """Initialize model weights with random values.
+
+    The model weights must be randomly initialized for accurate performance
+    measurements. Additionally, the model weights should not cause NaNs in the
+    forward pass. We empirically found that initializing the weights with
+    values between -1e-3 and 1e-3 works well for most models.
+    """
+    for param in model.state_dict().values():
+        if torch.is_floating_point(param):
+            param.data.uniform_(low, high)
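
A hedged sketch tying the pieces above together: stream tensors from safetensors files and hand each one to the parameter's weight_loader, falling back to default_weight_loader. This mirrors the load_weights loops in the model files further down; the checkpoint path and the Linear stand-in are placeholders.

# Hedged sketch; the checkpoint path is a placeholder and the Linear module
# stands in for a real model with matching parameter names/shapes.
import glob

import torch

from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, safetensors_weights_iterator)

model = torch.nn.Linear(4, 4)  # stand-in for a real model
files = glob.glob("/path/to/checkpoint/*.safetensors")  # placeholder path

params = dict(model.named_parameters())
for name, loaded_weight in safetensors_weights_iterator(files):
    if name in params:
        # Models may attach a custom `weight_loader` attribute to a
        # parameter; otherwise fall back to the shape-checked default copy.
        weight_loader = getattr(params[name], "weight_loader",
                                default_weight_loader)
        weight_loader(params[name], loaded_weight)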

+ 10 - 5
aphrodite/modeling/models/__init__.py

@@ -1,8 +1,8 @@
 import importlib
 from typing import Dict, List, Optional, Type
-from loguru import logger
 
 import torch.nn as nn
+from loguru import logger
 
 from aphrodite.common.utils import is_hip
 
@@ -15,8 +15,8 @@ _MODELS = {
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
     "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
-    "CohereForCausalLM": ("cohere", "CohereForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
@@ -27,19 +27,22 @@ _MODELS = {
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
+    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
-    # For decapoda-research/llama-*
-    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "LlavaForConditionalGeneration":
     ("llava", "LlavaForConditionalGeneration"),
+    # For decapoda-research/llama-*
+    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
-    "YiForCausalLM": ("llama", "LlamaForCausalLM"),
+    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
+    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
     "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
@@ -47,6 +50,8 @@ _MODELS = {
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
+    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
 }
 
 # Architecture -> type.
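
For reference, the registry above is consumed through ModelRegistry, e.g. by get_model_architecture in aphrodite/modeling/model_loader/utils.py; a brief usage sketch, assuming load_model_cls returns None for unknown architectures as the utils.py code above implies:

# Brief sketch of looking up an architecture in the registry above.
from aphrodite.modeling.models import ModelRegistry

arch = "CohereForCausalLM"
model_cls = ModelRegistry.load_model_cls(arch)   # lazily imports "commandr"
if model_cls is None:
    raise ValueError(f"Unsupported architecture {arch}; supported: "
                     f"{ModelRegistry.get_supported_archs()}")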

+ 79 - 114
aphrodite/modeling/models/baichuan.py

@@ -18,41 +18,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
-
 import math
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
+from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
-from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
 
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -90,47 +80,21 @@ class BaiChuanMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.up_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size,
-                [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
-        )
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -150,8 +114,8 @@ class BaiChuanAttention(nn.Module):
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
         self.total_num_heads = num_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = (self.total_num_heads //
@@ -185,22 +149,16 @@ class BaiChuanAttention(nn.Module):
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
             scaling = self.head_dim**-0.5
-            self.attn = Attention(
-                self.num_heads,
-                self.head_dim,
-                scaling,
-                alibi_slopes=alibi_slopes,
-            )
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  scaling,
+                                  alibi_slopes=alibi_slopes)
         else:
-            is_neox_style = (True if linear_method is None
-                             or linear_method.quant_config.rope_style() is None
-                             else linear_method.quant_config.rope_style())
             self.rotary_emb = get_rope(
                 self.head_dim,
                 rotary_dim=self.head_dim,
                 max_position=self.max_position_embeddings,
                 base=self.rope_theta,
-                is_neox_style=is_neox_style,
             )
             self.scaling = self.head_dim**-0.5
             self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
@@ -223,12 +181,10 @@ class BaiChuanAttention(nn.Module):
 
 class BaiChuanDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: BaiChuanConfig,
-        position_embedding: str,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: PretrainedConfig,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -284,20 +240,19 @@ class BaiChuanDecoderLayer(nn.Module):
 
 class BaiChuanModel(nn.Module):
 
-    def __init__(
-        self,
-        config: BaiChuanConfig,
-        position_embedding: str,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: PretrainedConfig,
+                 position_embedding: str,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size,
-                                                   linear_method=linear_method)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             BaiChuanDecoderLayer(config, position_embedding, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -327,22 +282,36 @@ class BaiChuanModel(nn.Module):
 
 
 class BaiChuanBaseForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "W_pack": ["W_pack"],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "W_pack",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config,
         position_embedding: str,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
         self.model = BaiChuanModel(config, position_embedding, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -358,7 +327,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -370,29 +339,19 @@ class BaiChuanBaseForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if name == "lm_head.weight":
-                # Unlike Baichuan, Baichuan2 normalizes the head weights. Ref.:
+                # Unlike Baichuan, Baichuan2 normalizes the head weights.
+                # Refer to:
                 # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
                 # Distinguish between Baichuan and Baichuan2 by checking the
                 # vocab size.
@@ -401,7 +360,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
                     loaded_weight = torch.nn.functional.normalize(
                         loaded_weight)
 
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -425,19 +384,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     """Baichuan 13B and Baichuan2 7B/13B."""
 
-    def __init__(self,
-                 config,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", linear_method)
+            super().__init__(config, "ROPE", linear_method, lora_config)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", linear_method)
+            super().__init__(config, "ALIBI", linear_method, lora_config)
 
 
 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
     """Baichuan 7B."""
 
-    def __init__(self,
-                 config,
-                 linear_method: Optional[LinearMethodBase] = None):
-        super().__init__(config, "ROPE", linear_method)
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        super().__init__(config, "ROPE", linear_method, lora_config)

+ 16 - 35
aphrodite/modeling/models/bloom.py

@@ -2,7 +2,7 @@
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
 # Copyright 2023 The PygmalionAI team.
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,13 +18,16 @@
 # limitations under the License.
 """Inference-only BLOOM model compatible with HuggingFace weights."""
 import math
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import BloomConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
@@ -32,14 +35,10 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -223,7 +222,9 @@ class BloomModel(nn.Module):
 
         # Embedding + LN Embedding
         self.word_embeddings = VocabParallelEmbedding(
-            config.vocab_size, self.embed_dim, linear_method=linear_method)
+            config.vocab_size,
+            self.embed_dim,
+        )
         self.word_embeddings_layernorm = nn.LayerNorm(
             self.embed_dim, eps=config.layer_norm_epsilon)
 
@@ -269,11 +270,7 @@ class BloomForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.transformer = BloomModel(config, linear_method)
         self.lm_head_weight = self.transformer.word_embeddings.weight
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -289,7 +286,7 @@ class BloomForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -301,31 +298,15 @@ class BloomForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
-            if "lm_head" in name and name not in params_dict:
+        for name, loaded_weight in weights:
+            if name == "lm_head.weight":
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
             param = params_dict[name]
 
-            if "word_embeddings" in name:
-                # Copy word embedding to lm_head
-                head_name = name.replace("transformer.word_embeddings",
-                                         "lm_head")
-                if head_name in params_dict:
-                    lm_head_param = params_dict[head_name]
-                    weight_loader = getattr(lm_head_param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(lm_head_param, loaded_weight)
-
             if "query_key_value" in name:
                 # NOTE: BLOOM's fused QKV's output_dim has the shape of
                 # (num_heads * 3 * head_size), while the
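
The BLOOM hunk above ends on the note about its fused-QKV layout. As a worked illustration of that layout difference (a sketch, not code from this diff; the target ordering of 3 * num_heads * head_size contiguous Q/K/V blocks is an assumption about what the truncated note goes on to describe), the conversion is just a transpose of the head and Q/K/V axes:

import torch

# Toy sizes, purely illustrative.
num_heads, head_size, hidden_size = 4, 8, 32

# BLOOM checkpoint layout: output rows grouped as (num_heads, 3, head_size).
w_bloom = torch.randn(num_heads * 3 * head_size, hidden_size)

# Assumed target layout: rows grouped as (3, num_heads, head_size),
# i.e. contiguous Q, K and V blocks.
w_converted = (w_bloom
               .view(num_heads, 3, head_size, hidden_size)
               .transpose(0, 1)
               .reshape(3 * num_heads * head_size, hidden_size))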

+ 12 - 21
aphrodite/modeling/models/chatglm.py

@@ -2,7 +2,7 @@
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -10,6 +10,8 @@ from torch.nn import LayerNorm
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
@@ -21,11 +23,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.distributed import (get_tensor_model_parallel_world_size)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs import ChatGLMConfig
 
 
@@ -287,8 +286,7 @@ class ChatGLMModel(nn.Module):
         super().__init__()
 
         self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
-                                                config.hidden_size,
-                                                linear_method=linear_method)
+                                                config.hidden_size)
 
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
@@ -296,8 +294,7 @@ class ChatGLMModel(nn.Module):
         self.encoder = GLMTransformer(config, linear_method)
 
         self.output_layer = ParallelLMHead(config.padded_vocab_size,
-                                           config.hidden_size,
-                                           linear_method=linear_method)
+                                           config.hidden_size)
 
     def forward(
         self,
@@ -343,8 +340,8 @@ class ChatGLMForCausalLM(nn.Module):
         self.config: ChatGLMConfig = config
         self.linear_method = linear_method
         self.transformer = ChatGLMModel(config, linear_method)
-        self.logits_processor = LogitsProcessor(config.padded_vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head_weight = self.transformer.output_layer.weight
+        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -360,8 +357,8 @@ class ChatGLMForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.transformer.output_layer,
-                                       hidden_states, sampling_metadata)
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(
@@ -372,15 +369,9 @@ class ChatGLMForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "rotary_pos_emb.inv_freq" in name:
                 continue
             if "word_embeddings" in name:
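
Across every model touched by this sync, load_weights now takes an iterator of (name, tensor) pairs instead of opening checkpoints itself; the checkpoint iteration lives in the new model_loader/weight_utils module imported at the top of each file. A minimal sketch of how such an interface is driven, with a stand-in weight source and hypothetical tensor names and shapes:

from typing import Iterable, Tuple

import torch

def fake_checkpoint_stream() -> Iterable[Tuple[str, torch.Tensor]]:
    # Stand-in for whatever the real loader yields from checkpoint files;
    # the names and shapes here are hypothetical.
    yield "embedding.weight", torch.empty(65024, 4096)
    yield "encoder.layers.0.self_attention.query_key_value.weight", torch.empty(4608, 4096)

# With a constructed model, loading becomes:
#   model.load_weights(fake_checkpoint_stream())
# replacing the old load_weights(model_name_or_path, cache_dir, load_format, revision).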

+ 47 - 129
aphrodite/modeling/models/cohere.py → aphrodite/modeling/models/commandr.py

@@ -20,7 +20,7 @@
 
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
@@ -29,33 +29,22 @@ from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.linear import (
-    ColumnParallelLinear,
-    LinearMethodBase,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    ParallelTWEHead,
-    VocabParallelEmbedding,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 @torch.compile
@@ -72,9 +61,9 @@ def layer_norm_func(hidden_states, weight, variance_epsilon):
 
 class LayerNorm(nn.Module):
 
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, param_shape=None, eps=1e-5):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
@@ -108,29 +97,12 @@ class CohereMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.up_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                self.hidden_size,
-                [self.intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
+        self.gate_up_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
@@ -140,12 +112,7 @@ class CohereMLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -185,37 +152,14 @@ class CohereAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
         self.use_qk_norm = getattr(config, "use_qk_norm", False)
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.total_num_heads * self.head_dim,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.k_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.v_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                self.hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_kv_heads,
-                bias=False,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
@@ -237,10 +181,10 @@ class CohereAttention(nn.Module):
             num_kv_heads=self.num_kv_heads,
         )
         if self.use_qk_norm:
-            self.q_norm = LayerNorm(hidden_size=(self.num_heads,
+            self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                  self.head_dim),
                                     eps=config.layer_norm_eps)
-            self.k_norm = LayerNorm(hidden_size=(self.num_kv_heads,
+            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
                                                  self.head_dim),
                                     eps=config.layer_norm_eps)
 
@@ -260,14 +204,8 @@ class CohereAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
@@ -278,18 +216,16 @@ class CohereAttention(nn.Module):
 
 class CohereDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: CohereConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: CohereConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
 
         self.self_attn = CohereAttention(config, linear_method=linear_method)
 
         self.mlp = CohereMLP(config, linear_method=linear_method)
-        self.input_layernorm = LayerNorm(config.hidden_size,
+        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)
 
     def forward(
@@ -327,13 +263,13 @@ class CohereModel(nn.Module):
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size,
-                                                   linear_method=linear_method)
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             CohereDecoderLayer(config, linear_method=linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = LayerNorm(param_shape=(config.hidden_size),
+                              eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -367,16 +303,9 @@ class CohereForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = CohereModel(config, linear_method)
-        if self.config.tie_word_embeddings:
-            self.lm_head = ParallelTWEHead(self.model.embed_tokens)
-        else:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          linear_method=linear_method)
         self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size,
                                                 scale=config.logit_scale)
+        self.model = CohereModel(config, linear_method)
         self.sampler = Sampler()
 
     @torch.no_grad()
@@ -393,8 +322,8 @@ class CohereForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens.weight,
+                                       hidden_states, sampling_metadata)
         return logits
 
     def sample(
@@ -405,13 +334,7 @@ class CohereForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -420,14 +343,9 @@ class CohereForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -440,9 +358,9 @@ class CohereForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # lm_head is not used as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.
-                if self.config.tie_word_embeddings and "lm_head" in name:
+                # lm_head is not used in Aphrodite as it is tied with
+                # embed_token. To prevent errors, skip loading lm_head.weight.
+                if "lm_head.weight" in name:
                     continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
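
Command-R ties its output projection to the input embedding, so the diff drops the ParallelLMHead/ParallelTWEHead branches, hands self.model.embed_tokens.weight to the LogitsProcessor, and skips lm_head.weight entirely at load time. A minimal sketch of what logits-from-a-tied-embedding-weight amounts to (the matmul-and-scale form is an assumption about LogitsProcessor internals, which in the real code also handle gathering and vocab padding):

import torch

num_tokens, hidden_size, vocab_size = 3, 64, 1000
embed_weight = torch.randn(vocab_size, hidden_size)   # shared with embed_tokens
hidden_states = torch.randn(num_tokens, hidden_size)
logit_scale = 0.0625                                   # stands in for config.logit_scale

# Tied head: project hidden states with the embedding matrix itself.
logits = hidden_states @ embed_weight.t() * logit_scale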

+ 22 - 33
aphrodite/modeling/models/dbrx.py

@@ -1,10 +1,14 @@
 # coding=utf-8
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
@@ -15,14 +19,9 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs.dbrx import DbrxConfig
 
 
@@ -192,15 +191,12 @@ class DbrxAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
-        is_neox_style = (True if linear_method is None
-                         or linear_method.quant_config.rope_style() is None
-                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=self.max_position,
             base=int(self.rope_theta),
-            is_neox_style=is_neox_style,
+            is_neox_style=True,
         )
 
         tp_world_size = get_tensor_model_parallel_world_size()
@@ -314,9 +310,10 @@ class DbrxModel(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          config.d_model,
-                                          linear_method=linear_method)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.d_model,
+        )
         self.blocks = nn.ModuleList(
             [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
@@ -358,14 +355,14 @@ class DbrxForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.unpadded_vocab_size = config.vocab_size
         self.transformer = DbrxModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.d_model,
-                                      org_num_embeddings=config.vocab_size,
-                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(
-            self.unpadded_vocab_size,
-            min(config.vocab_size, config.tokenizer_vocab_size))
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.d_model,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+        )
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -381,7 +378,7 @@ class DbrxForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -393,21 +390,13 @@ class DbrxForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [(
             "ws" if weight_name in ["w1", "v1"] else "w2s",
             f"experts.mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
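
For reference, the three-entry comprehension above expands to the mapping below: checkpoint tensors named ...experts.mlp.w1 and ...experts.mlp.v1 both load into the fused "ws" parameter, while ...experts.mlp.w2 loads into "w2s".

expert_params_mapping = [
    ("ws", "experts.mlp.w1"),
    ("ws", "experts.mlp.v1"),
    ("w2s", "experts.mlp.w2"),
]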

+ 5 - 12
aphrodite/modeling/models/decilm.py

@@ -1,8 +1,8 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 DeciAI Research Team. All rights reserved.
+# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -24,16 +24,15 @@
 # limitations under the License.
 """Inference-only DeciLM model compatible with HuggingFace weights."""
 
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 from transformers import PretrainedConfig
 
 from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.layers.linear import LinearMethodBase
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaForCausalLM
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
 
 
 class DeciLMForCausalLM(LlamaForCausalLM):
@@ -66,11 +65,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
                          linear_method=linear_method,
                          lora_config=lora_config)
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -80,9 +75,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
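
DeciLM reuses the Llama loading path; stacked_params_mapping only renames per-projection checkpoint tensors onto the fused qkv_proj / gate_up_proj parameters before each parameter's weight_loader places the shard. A tiny worked example of that renaming step (the checkpoint name is illustrative):

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

name = "model.layers.0.self_attn.k_proj.weight"   # illustrative checkpoint name
for param_name, shard_name, shard_id in stacked_params_mapping:
    if shard_name in name:
        name = name.replace(shard_name, param_name)
        # -> "model.layers.0.self_attn.qkv_proj.weight", loaded with shard_id "k"
        break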
 

+ 75 - 178
aphrodite/modeling/models/deepseek.py

@@ -22,32 +22,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Deepseek model."""
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.fused_moe import fused_topk
+from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear,
-    ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class DeepseekMLP(nn.Module):
@@ -82,39 +82,6 @@ class DeepseekMLP(nn.Module):
         return x
 
 
-class DeepseekExpertMLP(nn.Module):
-
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
-        super().__init__()
-        self.gate_proj = ReplicatedLinear(hidden_size,
-                                          intermediate_size,
-                                          bias=False,
-                                          linear_method=linear_method)
-        self.up_proj = ReplicatedLinear(hidden_size,
-                                        intermediate_size,
-                                        bias=False,
-                                        linear_method=linear_method)
-        self.down_proj = ReplicatedLinear(intermediate_size,
-                                          hidden_size,
-                                          bias=False,
-                                          linear_method=linear_method)
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states):
-        gate_out, _ = self.gate_proj(hidden_states)
-        gate_out = self.act_fn(gate_out)
-        up_out, _ = self.up_proj(hidden_states)
-        current_hidden_states = gate_out * up_out
-        current_hidden_states, _ = self.down_proj(current_hidden_states)
-        return current_hidden_states
-
-
 class DeepseekMoE(nn.Module):
 
     def __init__(
@@ -128,44 +95,20 @@ class DeepseekMoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.n_routed_experts = config.n_routed_experts
         self.top_k = config.num_experts_per_tok
-        self.linear_method = linear_method
-        if self.linear_method is None:
-            self.linear_method = UnquantizedLinearMethod()
-
-        if not isinstance(
-                self.linear_method, UnquantizedLinearMethod
-        ) and not self.linear_method.quant_config.support_fused_moe():
-            if self.tp_size > self.n_routed_experts:
-                raise ValueError(
-                    f"Tensor parallel size {self.tp_size} is greater than "
-                    f"the number of experts {self.n_routed_experts}.")
-            # Split experts equally between ranks
-            self.expert_indicies = np.array_split(range(
-                self.n_routed_experts), self.tp_size)[self.rank].tolist()
-            if not self.expert_indicies:
-                raise ValueError(
-                    f"Rank {self.rank} has no experts assigned to it.")
-
-            self.experts = nn.ModuleList([
-                DeepseekExpertMLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=config.moe_intermediate_size,
-                    hidden_act=config.hidden_act,
-                    linear_method=linear_method,
-                ) if idx in self.expert_indicies else None
-                for idx in range(self.n_routed_experts)
-            ])
-        else:
-            self.w1 = MergedColumnParallelLinear(
-                config.hidden_size, [config.moe_intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-                num_experts=self.n_routed_experts)
-            self.w2 = RowParallelLinear(config.moe_intermediate_size,
-                                        config.hidden_size,
-                                        bias=False,
-                                        linear_method=linear_method,
-                                        num_experts=self.n_routed_experts)
+        if self.tp_size > self.n_routed_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.n_routed_experts}.")
+
+        self.experts = nn.ModuleList([
+            DeepseekMLP(hidden_size=config.hidden_size,
+                        intermediate_size=config.moe_intermediate_size,
+                        hidden_act=config.hidden_act,
+                        linear_method=linear_method,
+                        reduce_results=False)
+            for idx in range(self.n_routed_experts)
+        ])
+        self.pack_params()
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.n_routed_experts,
@@ -183,6 +126,25 @@ class DeepseekMoE(nn.Module):
                 reduce_results=False,
             )
 
+    def pack_params(self):
+        w1 = []
+        w2 = []
+        for expert in self.experts:
+            w1.append(expert.gate_up_proj.weight)
+            w2.append(expert.down_proj.weight)
+        self.w1 = torch._utils._flatten_dense_tensors(w1)
+        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
+        for data, param in zip(w1s, w1):
+            param.data = data
+        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
+
+        self.w2 = torch._utils._flatten_dense_tensors(w2)
+        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
+        for data, param in zip(w2s, w2):
+            param.data = data
+
+        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -190,36 +152,13 @@ class DeepseekMoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-        if not isinstance(
-                self.linear_method, UnquantizedLinearMethod
-        ) and not self.linear_method.quant_config.support_fused_moe():
-            routing_weights, selected_experts = fused_topk(
-                router_logits,
-                self.top_k,
-                renormalize=self.config.norm_topk_prob)
-            final_hidden_states = None
-            for expert_idx in self.expert_indicies:
-                expert_layer = self.experts[expert_idx]
-                expert_mask = (selected_experts == expert_idx)
-                expert_weights = (routing_weights * expert_mask).sum(
-                    dim=-1, keepdim=True)
-
-                current_hidden_states = expert_layer(hidden_states).mul_(
-                    expert_weights)
-                if final_hidden_states is None:
-                    final_hidden_states = current_hidden_states
-                else:
-                    final_hidden_states.add_(current_hidden_states)
-        else:
-            final_hidden_states = self.linear_method.apply_moe_weights(
-                self.w1.linear_weights,
-                self.w2.linear_weights,
-                hidden_states,
-                router_logits,
-                self.top_k,
-                renormalize=self.config.norm_topk_prob,
-            )
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.w1,
+                                        self.w2,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=self.config.norm_topk_prob,
+                                        inplace=True)
 
         if self.config.n_shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -377,6 +316,8 @@ class DeepseekDecoderLayer(nn.Module):
 
 class DeepseekModel(nn.Module):
 
+    fall_back_to_pt_during_load = False
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -386,9 +327,10 @@ class DeepseekModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size,
-                                                   linear_method=linear_method)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             DeepseekDecoderLayer(config,
                                  layer_idx,
@@ -427,8 +369,7 @@ class DeepseekForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.model = DeepseekModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -444,7 +385,7 @@ class DeepseekForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -456,41 +397,18 @@ class DeepseekForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-            ("mlp.gate_up_proj", "mlp.gate_proj", 0),
-            ("mlp.gate_up_proj", "mlp.up_proj", 1),
-            ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0),
-            ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = [
-            # (param_name, weight_name, shard_id, expert_id)
-            ("w1" if weight_name in ["gate_proj", "up_proj"] else "w2",
-             f"experts.{expert_id}.{weight_name}", shard_id, expert_id)
-            for expert_id in range(self.config.n_routed_experts)
-            for weight_name, shard_id in [("gate_proj",
-                                           0), ("up_proj",
-                                                1), ("down_proj", None)]
-        ] if self.linear_method is None or (
-            self.linear_method.quant_config.support_fused_moe()) else []
-
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path,
-                cache_dir,
-                load_format,
-                revision,
-                self.config,
-                fall_back_to_pt=False):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
@@ -509,35 +427,14 @@ class DeepseekForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for (param_name, weight_name, shard_id,
-                     expert_id) in expert_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    if shard_id is None:
-                        weight_loader(param,
-                                      loaded_weight,
-                                      expert_id=expert_id)
-                    else:
-                        weight_loader(param,
-                                      loaded_weight,
-                                      shard_id,
-                                      expert_id=expert_id)
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    # Skip experts that are not assigned to this worker.
-                    if (("mlp.experts." in name
-                         or "mlp.shared_experts." in name)
-                            and name not in params_dict):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if (("mlp.experts." in name or "mlp.shared_experts." in name)
+                        and name not in params_dict):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
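
The new pack_params above flattens the per-expert gate_up/down weights into one contiguous buffer, re-points each expert's parameter at its slice of that buffer, and keeps a stacked (num_experts, ...) view to feed fused_moe. A standalone sketch of that aliasing trick with toy tensors (the real code applies it to the nn.Parameter weights of the expert MLPs):

import torch

num_experts, out_features, in_features = 4, 16, 8
expert_ws = [torch.randn(out_features, in_features) for _ in range(num_experts)]

flat = torch._utils._flatten_dense_tensors(expert_ws)            # one contiguous buffer
views = torch._utils._unflatten_dense_tensors(flat, expert_ws)   # per-expert slices of it
for w, view in zip(expert_ws, views):
    w.data = view                                                # originals now alias the buffer

stacked = flat.view(num_experts, out_features, in_features)      # layout fused_moe expects
assert stacked[2].data_ptr() == expert_ws[2].data_ptr()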

+ 17 - 24
aphrodite/modeling/models/falcon.py

@@ -1,6 +1,7 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
+# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 the Falcon authors and HuggingFace Inc. team.  All rights
 # reserved.
@@ -19,7 +20,7 @@
 """PyTorch Falcon model."""
 
 import math
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -27,6 +28,10 @@ from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
@@ -35,15 +40,10 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size,
-                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs import RWConfig
 
 FalconConfig = Union[HF_FalconConfig, RWConfig]
@@ -322,7 +322,9 @@ class FalconModel(nn.Module):
 
         # Embedding + LN Embedding
         self.word_embeddings = VocabParallelEmbedding(
-            config.vocab_size, self.embed_dim, linear_method=linear_method)
+            config.vocab_size,
+            self.embed_dim,
+        )
 
         # Transformer blocks
         self.h = nn.ModuleList([
@@ -364,11 +366,8 @@ class FalconForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.transformer = FalconModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -388,7 +387,7 @@ class FalconForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -400,11 +399,7 @@ class FalconForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         total_num_heads = self.config.num_attention_heads
         if self.config.new_decoder_architecture:
             total_num_kv_heads = self.config.num_kv_heads
@@ -414,9 +409,7 @@ class FalconForCausalLM(nn.Module):
             total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if name == "lm_head.weight":
                 # Falcon uses tied embeddings.
                 continue

+ 95 - 154
aphrodite/modeling/models/gemma.py

@@ -16,37 +16,30 @@
 # limitations under the License.
 """Inference-only Gemma model compatible with HuggingFace weights."""
 from functools import lru_cache
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
+from loguru import logger
 from torch import nn
 from transformers import GemmaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 @lru_cache(maxsize=None)
@@ -56,8 +49,16 @@ def _get_gemma_act_fn(
 ) -> nn.Module:
     if hidden_activation is None:
         if hidden_act is not None:
-            hidden_activation = hidden_act
-        return GeluAndMul(approximate="none")
+            logger.warning(
+                "Gemma's activation function was incorrectly set to exact GeLU "
+                "in the config JSON file when it was initially released. "
+                "Changing the activation function to approximate GeLU "
+                "(`gelu_pytorch_tanh`). If you want to use the legacy "
+                f"`{hidden_act}`, edit the config JSON to set "
+                f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
+                "See https://github.com/huggingface/transformers/pull/29402 "
+                "for more details.")
+        return GeluAndMul(approximate="tanh")
     elif hidden_activation == "gelu_pytorch_tanh":
         return GeluAndMul(approximate="tanh")
     elif hidden_activation == "gelu":
@@ -78,44 +79,18 @@ class GemmaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.up_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size,
-                [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
-        )
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
         self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -123,16 +98,14 @@ class GemmaMLP(nn.Module):
 
 class GemmaAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        head_dim: int,
-        max_position_embeddings: int = 8192,
-        rope_theta: float = 10000,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 head_dim: int,
+                 max_position_embeddings: int = 8192,
+                 rope_theta: float = 10000,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -155,59 +128,32 @@ class GemmaAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
 
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_heads * self.head_dim,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.k_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.v_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_kv_heads,
-                bias=False,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             linear_method=linear_method,
         )
-        is_neox_style = (True if linear_method is None
-                         or linear_method.quant_config.rope_style() is None
-                         else linear_method.quant_config.rope_style())
+
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=self.rope_theta,
-            is_neox_style=is_neox_style,
-        )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
+            is_neox_style=True,
         )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -216,14 +162,8 @@ class GemmaAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -299,9 +239,10 @@ class GemmaModel(nn.Module):
         super().__init__()
         self.config = config
 
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size,
-                                                   linear_method=linear_method)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -340,21 +281,41 @@ class GemmaModel(nn.Module):
 
 
 class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.linear_method = linear_method
         self.model = GemmaModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     @torch.no_grad()
@@ -371,8 +332,8 @@ class GemmaForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
+        logits = self.logits_processor(self.model.embed_tokens.weight,
+                                       hidden_states, sampling_metadata)
         return logits
 
     def sample(
@@ -383,13 +344,7 @@ class GemmaForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -398,29 +353,13 @@ class GemmaForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if "embed_tokens" in name:
-                # Copy word embedding to lm_head
-                head_name = name.replace("model.embed_tokens", "lm_head")
-                if head_name in params_dict:
-                    loaded_params.add(head_name)
-                    lm_head_param = params_dict[head_name]
-                    weight_loader = getattr(lm_head_param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(lm_head_param, loaded_weight)
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
+        for name, loaded_weight in weights:
+            for (param_name, shard_name, shard_id) in stacked_params_mapping:
+                if shard_name not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                name = name.replace(shard_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
@@ -429,10 +368,12 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # lm_head is not used in Aphrodite as it is tied with
+                # embed_tokens. To prevent errors, skip loading lm_head.weight.
                 if "lm_head.weight" in name:
                     continue
-                # Skip loading extra layer for lora models.
-                if "lm_head" in name and name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.

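The rewritten Gemma load_weights above no longer fetches checkpoints itself; it consumes an iterator of (name, tensor) pairs and folds the per-projection checkpoint tensors into the fused qkv_proj / gate_up_proj modules via stacked_params_mapping. A minimal sketch of just the name-mapping step (the parameter names below are illustrative, not taken from a real checkpoint):

    stacked_params_mapping = [
        # (fused param name, checkpoint shard name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def map_name(name):
        for fused, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, fused), shard_id
        return name, None

    print(map_name("model.layers.0.self_attn.k_proj.weight"))
    # -> ('model.layers.0.self_attn.qkv_proj.weight', 'k')
    print(map_name("model.layers.0.mlp.up_proj.weight"))
    # -> ('model.layers.0.mlp.gate_up_proj.weight', 1)

Each mapped tensor is then handed to the fused parameter's weight_loader together with its shard id, which writes it into the correct slice of the merged weight.
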
+ 13 - 34
aphrodite/modeling/models/gpt2.py

@@ -18,13 +18,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPT2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
@@ -32,13 +34,10 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.distributed import (get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class GPT2Attention(nn.Module):
@@ -173,9 +172,7 @@ class GPT2Model(nn.Module):
         assert not config.scale_attn_by_inverse_layer_idx
         assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          self.embed_dim,
-                                          linear_method=linear_method)
+        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
             GPT2Block(config, linear_method)
@@ -213,12 +210,8 @@ class GPT2LMHeadModel(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.transformer = GPT2Model(config, linear_method)
-        # self.lm_head_weight = self.transformer.wte.weight
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head_weight = self.transformer.wte.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -234,7 +227,7 @@ class GPT2LMHeadModel(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -246,27 +239,13 @@ class GPT2LMHeadModel(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
-            if "lm_head" in name and name not in params_dict:
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
-            if "wte" in name:
-                # Copy word embedding to lm_head
-                head_name = name.replace("transformer.wte", "lm_head")
-                if head_name in params_dict:
-                    lm_head_param = params_dict[head_name]
-                    weight_loader = getattr(lm_head_param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(lm_head_param, loaded_weight)
             if ".attn.bias" in name or ".attn.masked_bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.

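As in the Gemma change, GPT-2 now reuses the input embedding matrix as the output projection: lm_head_weight simply aliases transformer.wte.weight, and checkpoint entries named lm_head.weight are skipped during loading. A small sketch of what the tied logits computation boils down to (assuming LogitsProcessor ultimately performs a matmul against the weight it is handed; all shapes are toy values):

    import torch

    vocab, hidden = 8, 4
    wte = torch.nn.Embedding(vocab, hidden)      # stands in for transformer.wte
    hidden_states = torch.randn(3, hidden)       # three token positions

    lm_head_weight = wte.weight                  # [vocab, hidden], no new parameters
    logits = hidden_states @ lm_head_weight.t()  # [3, vocab]
    print(logits.shape)                          # torch.Size([3, 8])
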
+ 14 - 34
aphrodite/modeling/models/gpt_bigcode.py

@@ -1,6 +1,7 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 CTranslate2, and Michael Feil
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
@@ -18,13 +19,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPTBigCode model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
@@ -32,13 +35,10 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.distributed import (get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -193,9 +193,7 @@ class GPTBigCodeModel(nn.Module):
 
         self.embed_dim = config.hidden_size
 
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          self.embed_dim,
-                                          linear_method=linear_method)
+        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
             GPTBigCodeBlock(config, linear_method)
@@ -233,12 +231,8 @@ class GPTBigCodeForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.transformer = GPTBigCodeModel(config, linear_method)
-        # self.lm_head_weight = self.transformer.wte.weight
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head_weight = self.transformer.wte.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -254,7 +248,7 @@ class GPTBigCodeForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -266,25 +260,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
-            if "lm_head" in name and name not in params_dict:
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
                 continue
-            if "wte" in name:
-                # Copy word embedding to lm_head
-                head_name = name.replace("transformer.wte", "lm_head")
-                if head_name in params_dict:
-                    lm_head_param = params_dict[head_name]
-                    weight_loader = getattr(lm_head_param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(lm_head_param, loaded_weight)
             if ".attn.bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.

+ 28 - 75
aphrodite/modeling/models/gpt_j.py

@@ -1,6 +1,7 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
+# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
 #
@@ -16,35 +17,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-J model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPTJConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.linear import (
-    ColumnParallelLinear,
-    LinearMethodBase,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class GPTJAttention(nn.Module):
@@ -59,36 +52,13 @@ class GPTJAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
 
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(
-                config.hidden_size,
-                config.hidden_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.k_proj = ColumnParallelLinear(
-                config.hidden_size,
-                config.hidden_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.v_proj = ColumnParallelLinear(
-                config.hidden_size,
-                config.hidden_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                config.hidden_size,
-                self.head_size,
-                self.total_num_heads,
-                bias=False,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_size,
+            self.total_num_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
         self.out_proj = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
@@ -122,13 +92,8 @@ class GPTJAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.chunk(chunks=3, dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         attn_output, _ = self.out_proj(attn_output)
@@ -210,9 +175,10 @@ class GPTJModel(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.n_embd
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          self.embed_dim,
-                                          linear_method=linear_method)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            self.embed_dim,
+        )
         self.h = nn.ModuleList(
             [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -253,10 +219,8 @@ class GPTJForCausalLM(nn.Module):
             config.vocab_size,
             config.n_embd,
             bias=True,
-            linear_method=linear_method,
         )
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -272,7 +236,7 @@ class GPTJForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits
 
@@ -284,13 +248,7 @@ class GPTJForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -299,16 +257,11 @@ class GPTJForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

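GPT-J's attention now always uses a single fused QKVParallelLinear and splits its output with chunk(3). This is valid because GPT-J has no separate KV head count, so q, k and v occupy equal widths. A toy illustration (sizes made up):

    import torch

    width = 12                                # per-projection width on this rank
    qkv = torch.randn(2, 3 * width)           # fused projection output for 2 tokens
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    print(q.shape, k.shape, v.shape)          # each torch.Size([2, 12])

Models with grouped-query attention (such as the Gemma and Llama changes in this diff) instead use split([q_size, kv_size, kv_size]), since the three slices differ in width.
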
+ 22 - 42
aphrodite/modeling/models/gpt_neox.py

@@ -17,36 +17,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
-
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPTNeoXConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.linear import (
-    ColumnParallelLinear,
-    LinearMethodBase,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class GPTNeoXAttention(nn.Module):
@@ -87,15 +78,11 @@ class GPTNeoXAttention(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        is_neox_style = (True if linear_method is None
-                         or linear_method.quant_config.rope_style() is None
-                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_size,
             rotary_dim=rotary_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
-            is_neox_style=is_neox_style,
         )
         self.attn = Attention(self.num_heads, self.head_size, scaling)
 
@@ -201,9 +188,10 @@ class GPTNeoXModel(nn.Module):
         super().__init__()
         self.config = config
 
-        self.embed_in = VocabParallelEmbedding(config.vocab_size,
-                                               config.hidden_size,
-                                               linear_method=linear_method)
+        self.embed_in = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             GPTNeoXLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -242,11 +230,11 @@ class GPTNeoXForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.gpt_neox = GPTNeoXModel(config, linear_method)
-        self.embed_out = ParallelLMHead(config.vocab_size,
-                                        config.hidden_size,
-                                        linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.embed_out = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -262,7 +250,7 @@ class GPTNeoXForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.embed_out, hidden_states,
+        logits = self.logits_processor(self.embed_out.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -274,17 +262,9 @@ class GPTNeoXForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue

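The GPT-NeoX hunk drops the quant-config-driven is_neox_style override and passes no flag at all, leaving get_rope on its default NeoX-style pairing. For reference, a generic sketch of the two pairing conventions the flag selects between (this illustrates the conventions only, not the kernel code): NeoX style pairs element i with element i + d/2, while GPT-J style pairs adjacent elements (2i, 2i+1).

    import torch

    def rotate_neox(x):                        # pair (i, i + d/2)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rotate_gptj(x):                        # pair (2i, 2i + 1)
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    x = torch.arange(8.0)
    print(rotate_neox(x))   # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
    print(rotate_gptj(x))   # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])
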
+ 31 - 84
aphrodite/modeling/models/internlm2.py

@@ -1,35 +1,26 @@
 # -*- coding: utf-8 -*-
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase,
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class InternLM2MLP(nn.Module):
@@ -42,47 +33,21 @@ class InternLM2MLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.w1 = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.w3 = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size,
-                [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
-        self.w2 = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
-        )
+            linear_method=linear_method)
+        self.w2 = RowParallelLinear(intermediate_size,
+                                    hidden_size,
+                                    bias=False,
+                                    linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.w2(x)
         return x
@@ -145,12 +110,10 @@ class InternLM2Attention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -241,7 +204,6 @@ class InternLM2Model(nn.Module):
         self.tok_embeddings = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
         )
         self.layers = nn.ModuleList([
             InternLMDecoderLayer(config, linear_method)
@@ -282,13 +244,8 @@ class InternLM2ForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = InternLM2Model(config, linear_method)
-        self.output = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            linear_method=linear_method,
-        )
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -304,7 +261,7 @@ class InternLM2ForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.output, hidden_states,
+        logits = self.logits_processor(self.output.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -316,27 +273,17 @@ class InternLM2ForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w1", 0),
             ("gate_up_proj", "w3", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -366,9 +313,9 @@ class InternLM2ForCausalLM(nn.Module):
                     wk = wk.reshape(-1, wk.shape[-1])
                     wv = wv.reshape(-1, wv.shape[-1])
                     weight_loader = param.weight_loader
-                    weight_loader(param, wq, "q")
-                    weight_loader(param, wk, "k")
-                    weight_loader(param, wv, "v")
+                    weight_loader(param, wq, 'q')
+                    weight_loader(param, wk, 'k')
+                    weight_loader(param, wv, 'v')
                 else:
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)

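The tail of the InternLM2 loader above splits the packed wqkv checkpoint tensor into wq/wk/wv and feeds each slice to the fused qkv_proj weight loader under the 'q'/'k'/'v' shard ids. A hedged sketch of that split, assuming the checkpoint packs, for every KV head, its query heads followed by one key and one value head (the layout and all sizes below are assumptions for illustration; only the reshape tail is visible in the hunk):

    import torch

    num_kv_heads, kv_groups, head_dim, hidden = 2, 4, 8, 64
    wqkv = torch.randn(num_kv_heads * (kv_groups + 2) * head_dim, hidden)

    w = wqkv.view(num_kv_heads, kv_groups + 2, head_dim, hidden)
    wq, wk, wv = torch.split(w, [kv_groups, 1, 1], dim=1)
    wq = wq.reshape(-1, wq.shape[-1])   # [num_kv_heads * kv_groups * head_dim, hidden]
    wk = wk.reshape(-1, wk.shape[-1])   # [num_kv_heads * head_dim, hidden]
    wv = wv.reshape(-1, wv.shape[-1])
    print(wq.shape, wk.shape, wv.shape)
    # torch.Size([64, 64]) torch.Size([16, 64]) torch.Size([16, 64])
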
+ 333 - 0
aphrodite/modeling/models/jais.py

@@ -0,0 +1,333 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
+# reserved.
+# Copyright 2023 Cerebras Systems.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Jais model compatible with HuggingFace weights."""
+
+import math
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.transformers_utils.configs import JAISConfig
+
+
+class SwiGLUActivation(nn.Module):
+
+    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        return x1 * nn.functional.silu(x2)
+
+
+def _get_alibi_slopes(n):
+
+    def get_slopes_power_of_2(n):
+        start = 2**(-(2**-(math.log2(n) - 3)))
+        ratio = start
+        return [start * ratio**i for i in range(n)]
+
+    if math.log2(n).is_integer():
+        return get_slopes_power_of_2(n)
+    else:
+        closest_power_of_2 = 2**math.floor(math.log2(n))
+        return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
+            2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
+
+
+class JAISAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: JAISConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        if hasattr(config, "scale_qk_dot_by_d"):
+            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
+        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
+        self.scale = self.head_dim**-self.attn_scale_power
+
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+
+        tp_rank = get_tensor_model_parallel_rank()
+        head_start = tp_rank * self.num_heads
+        head_end = (tp_rank + 1) * self.num_heads
+        alibi_slopes = _get_alibi_slopes(total_num_heads)
+        alibi_slopes = alibi_slopes[head_start:head_end]
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            scale=self.scale,
+            alibi_slopes=alibi_slopes,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class JAISMLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: JAISConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.swiglu = config.activation_function == "swiglu"
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+        self.c_fc2 = (ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            linear_method=linear_method,
+        ) if self.swiglu else None)
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            linear_method=linear_method,
+        )
+
+        self.act = SwiGLUActivation()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.swiglu:
+            hidden_states2, _ = self.c_fc2(hidden_states)
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = (self.act(hidden_states, hidden_states2)
+                         if self.swiglu else self.act(hidden_states))
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class JAISBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: JAISConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = JAISAttention(config, linear_method)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = JAISMLP(inner_dim, config, linear_method)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class JAISModel(nn.Module):
+
+    def __init__(
+        self,
+        config: JAISConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
+        self.embed_dim = config.hidden_size
+        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wpe = (nn.Embedding(config.max_position_embeddings,
+                                 self.embed_dim)
+                    if config.position_embedding_type != "alibi" else None)
+        if hasattr(config, "embeddings_scale"):
+            self.embeddings_scale = config.embeddings_scale
+        else:
+            self.embeddings_scale = config.mup_embeddings_scale
+        self.h = nn.ModuleList([
+            JAISBlock(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        if self.wpe is not None:
+            position_embeds = self.wpe(position_ids)
+            hidden_states = inputs_embeds + position_embeds
+        else:
+            hidden_states = inputs_embeds
+        hidden_states *= torch.tensor(float(self.embeddings_scale),
+                                      dtype=hidden_states.dtype)
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class JAISLMHeadModel(nn.Module):
+
+    def __init__(
+        self,
+        config: JAISConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.transformer = JAISModel(config, linear_method)
+        self.lm_head_weight = self.transformer.wte.weight
+        if hasattr(config, "width_scale"):
+            self.output_logits_scale = config.width_scale
+        else:
+            self.output_logits_scale = (config.mup_output_alpha *
+                                        config.mup_width_scale)
+        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
+                                                scale=self.output_logits_scale)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                # GPT-2 ties the weights of the embedding layer and the final
+                # linear layer.
+                continue
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            if "relative_pe" in name:
+                continue
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+            param = params_dict[name]
+            # HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)

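The new jais.py computes ALiBi slopes with _get_alibi_slopes and keeps only the heads owned by the current tensor-parallel rank. A usage sketch with a compact copy of the same slope formula (tp_rank, tp_size and the head count are made-up example values):

    import math

    def get_alibi_slopes(n):            # mirrors _get_alibi_slopes above
        def power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            return [start * start ** i for i in range(n)]
        if math.log2(n).is_integer():
            return power_of_2(n)
        closest = 2 ** math.floor(math.log2(n))
        return (power_of_2(closest) +
                get_alibi_slopes(2 * closest)[0::2][:n - closest])

    total_heads, tp_size, tp_rank = 8, 2, 1
    num_heads = total_heads // tp_size
    slopes = get_alibi_slopes(total_heads)[tp_rank * num_heads:(tp_rank + 1) * num_heads]
    print(slopes)   # [0.03125, 0.015625, 0.0078125, 0.00390625]
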
+ 39 - 124
aphrodite/modeling/models/llama.py

@@ -1,6 +1,7 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -21,8 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -32,13 +32,11 @@ from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.common.utils import is_hip
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator,
-                                              kv_cache_scales_loader)
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
@@ -47,8 +45,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, kv_cache_scales_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 
 
@@ -62,47 +60,21 @@ class LlamaMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.up_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size,
-                [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
-        )
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -154,37 +126,14 @@ class LlamaAttention(nn.Module):
         # scaling_factor = tensor_amax / FPtype_max
         self.kv_scale = 1.0
 
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_heads * self.head_dim,
-                bias=bias,
-                linear_method=linear_method,
-            )
-            self.k_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=bias,
-                linear_method=linear_method,
-            )
-            self.v_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=bias,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_kv_heads,
-                bias=bias,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            linear_method=linear_method,
+        )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -192,24 +141,18 @@ class LlamaAttention(nn.Module):
             linear_method=linear_method,
         )
 
-        is_neox_style = (True if linear_method is None
-                         or linear_method.quant_config.rope_style() is None
-                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
-        )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=sliding_window,
         )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=sliding_window)
 
     def forward(
         self,
@@ -217,16 +160,9 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        # kv_quant_param: List[float],
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
                                 self.kv_scale)
@@ -274,10 +210,6 @@ class LlamaDecoderLayer(nn.Module):
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
-        if config.model_type == "Yi":
-            # Some old Yi finetunes and quants have not been llama-fied
-            self.ln1 = self.input_layernorm
-            self.ln2 = self.post_attention_layernorm
 
     def forward(
         self,
@@ -286,7 +218,6 @@ class LlamaDecoderLayer(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        # kv_quant_param: List[float],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
@@ -300,7 +231,6 @@ class LlamaDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
-            # kv_quant_param=kv_quant_param,
         )
 
         # Fully Connected
@@ -321,14 +251,13 @@ class LlamaModel(nn.Module):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        lora_vocab = ((lora_config.lora_extra_vocab_size *
-                       (lora_config.max_loras or 1)) if lora_config else 0)
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
@@ -361,8 +290,6 @@ class LlamaModel(nn.Module):
                 kv_caches[i],
                 attn_metadata,
                 residual,
-                # attn_metadata.kv_quant_params[i]
-                # if attn_metadata.kv_quant_params is not None else None,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -413,16 +340,15 @@ class LlamaForCausalLM(nn.Module):
             self.unpadded_vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
-            linear_method=linear_method,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
+
         logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(
-            self.unpadded_vocab_size,
-            min(config.vocab_size, config.tokenizer_vocab_size), logit_scale)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
     def forward(
@@ -438,7 +364,7 @@ class LlamaForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -450,13 +376,7 @@ class LlamaForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -465,13 +385,8 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -479,7 +394,7 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

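The load_weights() signature above now takes an iterable of (name, tensor) pairs instead of downloading the checkpoint itself, and stacked_params_mapping folds the per-checkpoint q_proj/k_proj/v_proj shards into the fused qkv_proj. A minimal, self-contained sketch of that convention with a toy module (illustrative only, not Aphrodite's classes or loader API):

```python
from typing import Iterable, Tuple

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Toy stand-in for a model with a fused QKV parameter."""

    def __init__(self, hidden: int = 8) -> None:
        super().__init__()
        # Rows 0..h hold Q, h..2h hold K, 2h..3h hold V.
        self.qkv_proj = nn.Parameter(torch.zeros(3 * hidden, hidden))
        self.hidden = hidden

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        shard_index = {"q_proj": 0, "k_proj": 1, "v_proj": 2}
        for name, tensor in weights:
            idx = shard_index[name.split(".")[0]]
            h = self.hidden
            self.qkv_proj.data[idx * h:(idx + 1) * h].copy_(tensor)


model = ToyAttention()
ckpt = {f"{k}.weight": torch.randn(8, 8) for k in ("q_proj", "k_proj", "v_proj")}
model.load_weights(iter(ckpt.items()))  # the caller supplies (name, tensor) pairs
```
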
+ 16 - 19
aphrodite/modeling/models/llava.py

@@ -1,23 +1,22 @@
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
-# TODO: We should port CLIPVisionModel's code over to not depend on
+# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
 # transformers' impl.
 from transformers import CLIPVisionModel, LlavaConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import VisionLanguageConfig
+from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import LinearMethodBase
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -25,7 +24,7 @@ _KEYS_TO_MODIFY_MAPPING = {
 }
 
 
-# TODO: Run benchmark and decide if TP.
+# TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
     def __init__(self, vision_hidden_size: int, text_hidden_size: int,
@@ -92,9 +91,8 @@ class LlavaForConditionalGeneration(nn.Module):
             config.text_config.hidden_size,
             org_num_embeddings=self.language_model.org_vocab_size)
         logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(
-            self.unpadded_vocab_size,
-            min(config.vocab_size, config.tokenizer_vocab_size), logit_scale)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
     def forward(
@@ -106,6 +104,7 @@ class LlavaForConditionalGeneration(nn.Module):
         image_input: Optional[torch.Tensor] = None
     ) -> SamplerOutput:  # noqa: E501
         """Run forward pass for Llava 1.5.
+
         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
         Concretely, consider a text prompt:
@@ -120,8 +119,10 @@ class LlavaForConditionalGeneration(nn.Module):
         9047, 13566, 29901].
         There will be 576 `32000` in the `input_ids`.
         (32000 is the token id for `<image>`.)
+
         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
+
         The model takes two types of image inputs: 
         PIXEL_VALUES and IMAGE_FEATURES.
         The following shows how each maps to huggingface implementation.
@@ -130,6 +131,7 @@ class LlavaForConditionalGeneration(nn.Module):
         IMAGE_FEATURES:
         - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
         before going through the multi modal projector.
+
         Args:
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
@@ -145,11 +147,11 @@ class LlavaForConditionalGeneration(nn.Module):
                     f"plus "
                     f"{self.vision_language_config.image_input_shape[1:]}."
                     f" You supplied {image_input.shape}. "
-                    f"If you are using Aphrodite's endpoint, make sure your "
+                    f"If you are using vLLM's entrypoint, make sure your "
                     f"supplied image input is consistent with "
                     f"image_input_shape in engine args.")
             if self.vision_tower is not None:
-                # TODO: Maybe port minimal CLIPVisionModel over.
+                # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
                 image_outputs = self.vision_tower(image_input,
                                                   output_hidden_states=True)
                 image_features = image_outputs.hidden_states[
@@ -183,7 +185,7 @@ class LlavaForConditionalGeneration(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -195,11 +197,7 @@ class LlavaForConditionalGeneration(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # only doing this for language model part for now.
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -210,8 +208,7 @@ class LlavaForConditionalGeneration(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():

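The forward() docstring above describes how every `<image>` placeholder token in input_ids is later replaced by one projected image-patch embedding. A toy sketch of that substitution (token id 32000 comes from the docstring; all shapes are illustrative, and this is not the model's actual merge code):

```python
import torch

IMAGE_TOKEN_ID = 32000  # token id of "<image>" per the docstring above
input_ids = torch.tensor([1, 32000, 32000, 32000, 3148, 1001])  # toy flattened prompt
text_embeds = torch.randn(input_ids.shape[0], 16)   # stand-in for embed_tokens output
image_embeds = torch.randn(3, 16)                   # stand-in for projected patch features

inputs_embeds = text_embeds.clone()
# One projected patch embedding is written into each placeholder position.
inputs_embeds[input_ids == IMAGE_TOKEN_ID] = image_embeds
```
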
+ 531 - 0
aphrodite/modeling/models/minicpm.py

@@ -0,0 +1,531 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only MiniCPM model compatible with HuggingFace weights."""
+import math
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.utils import set_weight_attrs
+
+
+class MiniCPMMoE(nn.Module):
+    """A tensor-parallel MoE implementation that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks, a fused MoE
+    kernel is used for the forward pass, and the outputs are reduced
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+    ):
+        super().__init__()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // self.tp_size
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.gate = ReplicatedLinear(self.hidden_size,
+                                     self.num_total_experts,
+                                     bias=False,
+                                     params_dtype=self.params_dtype,
+                                     linear_method=None)
+
+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=True,
+                                        inplace=True)
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_size)
+
+
+class MiniCPMMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MiniCPMAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than or equal to TP size, so we
+            # partition the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        # Recompute the RoPE cos/sin cache in fp32 instead of bf16.
+        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
+        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        orig_dtype = q.dtype
+        q, k = q.float(), k.float()
+        q, k = self.rotary_emb(positions, q, k)
+        q, k = q.to(orig_dtype), k.to(orig_dtype)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MiniCPMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.self_attn = MiniCPMAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
+        )
+        self.num_experts = getattr(self.config, "num_experts", 0)
+        if self.num_experts == 0:
+            self.mlp = MiniCPMMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                linear_method=linear_method,
+            )
+        else:
+            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
+                                  top_k=config.num_experts_per_tok,
+                                  hidden_size=config.hidden_size,
+                                  intermediate_size=config.intermediate_size)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = residual + hidden_states * \
+            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states * \
+            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
+
+        return hidden_states, None
+
+
+class MiniCPMModel(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.layers = nn.ModuleList([
+            MiniCPMDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        embedding = self.embed_tokens(input_ids)
+        return embedding * self.config.scale_emb
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
+        residual = None
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                attn_metadata,
+                residual,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class MiniCPMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(
+        self,
+        config,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.num_experts = getattr(self.config, "num_experts", 0)
+        self.linear_method = linear_method
+        self.model = MiniCPMModel(config,
+                                  linear_method,
+                                  lora_config=lora_config)
+        unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        if not self.config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(
+                unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+            )
+        self.scale_width = self.config.hidden_size / self.config.dim_model_base
+
+        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        hidden_states = hidden_states / self.scale_width
+        if self.config.tie_word_embeddings:
+            lm_head_weight = self.model.embed_tokens.weight
+        else:
+            lm_head_weight = self.lm_head.weight
+        logits = self.logits_processor(lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.num_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)

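MiniCPMMoE.weight_loader above keeps only this rank's slice of each expert's FFN: w1 and w3 are stacked row-wise into ws, while w2 is sliced column-wise into w2s, so a single fused_moe kernel can consume them. A toy sketch of that layout (expert count, sizes, and the rank are illustrative):

```python
import torch

# Illustrative sizes: 2 experts, hidden 4, expert FFN 8, tensor-parallel size 2, rank 0.
num_experts, hidden, inter, tp_size, tp_rank = 2, 4, 8, 2, 0
shard = inter // tp_size                   # per-rank slice of the expert FFN dimension

ws = torch.zeros(num_experts, 2 * shard, hidden)   # [w1; w3] stacked per expert
w2s = torch.zeros(num_experts, hidden, shard)

w1 = torch.randn(inter, hidden)            # checkpoint shapes for one expert
w3 = torch.randn(inter, hidden)
w2 = torch.randn(hidden, inter)

rows = slice(tp_rank * shard, (tp_rank + 1) * shard)
ws[0, 0:shard, :] = w1[rows, :]            # "w1.weight" -> first half of ws
ws[0, shard:2 * shard, :] = w3[rows, :]    # "w3.weight" -> second half of ws
w2s[0, :, :] = w2[:, rows]                 # "w2.weight" -> column slice of w2s
```
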
+ 121 - 199
aphrodite/modeling/models/mixtral.py

@@ -22,72 +22,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
-import numpy as np
 import torch
 from torch import nn
 from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import LoRAConfig
-from aphrodite.modeling.layers.fused_moe import fused_topk
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.utils import print_warning_once
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear,
-    QKVParallelLinear, ReplicatedLinear, RowParallelLinear,
-    UnquantizedLinearMethod)
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.quantization.fp8 import (Fp8LinearMethod,
+                                        per_tensor_quantize)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.distributed import (tensor_model_parallel_all_reduce)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
-
-
-class MixtralMLP(nn.Module):
-
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
-        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-
-        # TODO: Use aphrodite's SiluAndMul
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
+from aphrodite.modeling.utils import set_weight_attrs
 
 
 class MixtralMoE(nn.Module):
@@ -105,94 +68,101 @@ class MixtralMoE(nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
         tp_size: Optional[int] = None,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
-        self.linear_method = linear_method
-        if self.linear_method is None:
-            self.linear_method = UnquantizedLinearMethod()
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
 
         self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
+                                     params_dtype=self.params_dtype,
                                      linear_method=None)
 
-        if not isinstance(
-                self.linear_method, UnquantizedLinearMethod
-        ) and not self.linear_method.quant_config.support_fused_moe():
-            if self.tp_size > self.num_total_experts:
-                raise ValueError(
-                    f"Tensor parallel size {self.tp_size} is greater than "
-                    f"the number of experts {self.num_total_experts}.")
-            # Split experts equally between ranks
-            self.expert_indicies = np.array_split(
-                range(self.num_total_experts),
-                self.tp_size)[self.rank].tolist()
-            if not self.expert_indicies:
-                raise ValueError(
-                    f"Rank {self.rank} has no experts assigned to it.")
-
-            self.experts = nn.ModuleList([
-                MixtralMLP(self.num_total_experts,
-                           hidden_size,
-                           intermediate_size,
-                           linear_method=linear_method)
-                if idx in self.expert_indicies else None
-                for idx in range(self.num_total_experts)
-            ])
-        else:
-            self.ws = MergedColumnParallelLinear(hidden_size,
-                                                 [intermediate_size] * 2,
-                                                 bias=False,
-                                                 linear_method=linear_method,
-                                                 num_experts=num_experts)
-            self.w2s = RowParallelLinear(intermediate_size,
-                                         hidden_size,
-                                         bias=False,
-                                         linear_method=linear_method,
-                                         num_experts=num_experts)
+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+
+        # Scaling factors for FP8 weights
+        self.ws_scale = nn.Parameter(
+            torch.ones(
+                self.num_total_experts, device="cuda", dtype=torch.float32),
+            requires_grad=False) if self.use_fp8 else None
+        self.w2s_scale = nn.Parameter(
+            torch.ones(
+                self.num_total_experts, device="cuda", dtype=torch.float32),
+            requires_grad=False) if self.use_fp8 else None
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
+    def process_weights_after_loading(self):
+        if self.use_fp8:
+            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
+            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
+                    self.ws.data[expert, :, :])
+                w2s[expert, :, :], self.w2s_scale[
+                    expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
+            self.ws = nn.Parameter(ws, requires_grad=False)
+            self.w2s = nn.Parameter(w2s, requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-        if not isinstance(
-                self.linear_method, UnquantizedLinearMethod
-        ) and not self.linear_method.quant_config.support_fused_moe():
-            routing_weights, selected_experts = fused_topk(router_logits,
-                                                           self.top_k,
-                                                           renormalize=True)
-            final_hidden_states = None
-            for expert_idx in self.expert_indicies:
-                expert_layer = self.experts[expert_idx]
-                expert_mask = (selected_experts == expert_idx)
-                expert_weights = (routing_weights * expert_mask).sum(
-                    dim=-1, keepdim=True)
-
-                current_hidden_states = expert_layer(hidden_states).mul_(
-                    expert_weights)
-                if final_hidden_states is None:
-                    final_hidden_states = current_hidden_states
-                else:
-                    final_hidden_states.add_(current_hidden_states)
-        else:
-            final_hidden_states = self.linear_method.apply_moe_weights(
-                self.ws.linear_weights,
-                self.w2s.linear_weights,
-                hidden_states,
-                router_logits,
-                self.top_k,
-                renormalize=True,
-            )
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        router_logits,
+                                        self.top_k,
+                                        renormalize=True,
+                                        inplace=True,
+                                        use_fp8=self.use_fp8,
+                                        w1_scale=self.ws_scale,
+                                        w2_scale=self.w2s_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -234,49 +204,33 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(hidden_size,
-                                               self.total_num_heads *
-                                               self.head_dim,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(hidden_size,
-                                               self.total_num_kv_heads *
-                                               self.head_dim,
-                                               bias=False,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(hidden_size,
-                                               self.total_num_kv_heads *
-                                               self.head_dim,
-                                               bias=False,
-                                               linear_method=linear_method)
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_kv_heads,
-                bias=False,
-                linear_method=linear_method,
+        if isinstance(linear_method, Fp8LinearMethod):
+            print_warning_once(
+                "For Mixtral FP8 quantization, we currently do not quantize "
+                "the attention layers until their FP8 performance is improved."
             )
+            linear_method = None
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             linear_method=linear_method,
         )
-        is_neox_style = (True if linear_method is None
-                         or linear_method.quant_config.rope_style() is None
-                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=int(self.rope_theta),
-            is_neox_style=is_neox_style,
+            is_neox_style=True,
         )
         self.attn = Attention(
             self.num_heads,
@@ -293,14 +247,8 @@ class MixtralAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -384,7 +332,6 @@ class MixtralModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
@@ -412,6 +359,8 @@ class MixtralModel(nn.Module):
 
 
 class MixtralForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -451,7 +400,6 @@ class MixtralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
@@ -459,7 +407,7 @@ class MixtralForCausalLM(nn.Module):
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.tokenizer_vocab_size)
+                                                config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -475,7 +423,7 @@ class MixtralForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -487,38 +435,24 @@ class MixtralForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
 
         expert_params_mapping = [
-            # (param_name, weight_name, shard_id, expert_id)
+            # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
-             f"experts.{expert_id}.{weight_name}", shard_id, expert_id)
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)]
-        ] if self.linear_method is None or (
-            self.linear_method.quant_config.support_fused_moe()) else []
+            for weight_name in ["w1", "w2", "w3"]
+        ]
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path,
-                cache_dir,
-                load_format,
-                revision,
-                self.config,
-                fall_back_to_pt=False):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -534,33 +468,21 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for (param_name, weight_name, shard_id,
-                     expert_id) in expert_params_mapping:
+                for param_name, weight_name, expert_id in expert_params_mapping:
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                    if shard_id is None:
-                        weight_loader(param,
-                                      loaded_weight,
-                                      expert_id=expert_id)
-                    else:
-                        weight_loader(param,
-                                      loaded_weight,
-                                      shard_id,
-                                      expert_id=expert_id)
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    # Skip experts that are not assigned to this worker.
-                    if ("block_sparse_moe.experts." in name
-                            and name not in params_dict):
-                        continue
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)

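For the FP8 path, process_weights_after_loading() above converts the stacked expert weights to float8_e4m3fn and stores one scale per expert, with scale = amax / fp8_max as in the kv_scale comment earlier in this diff. A hedged sketch of that per-tensor scheme (this mirrors the idea, not Aphrodite's per_tensor_quantize helper, and it needs a PyTorch build that ships the float8 dtypes):

```python
import torch


def per_tensor_quantize_sketch(w: torch.Tensor):
    """Quantize a weight tensor to FP8 e4m3 with a single per-tensor scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448.0 for e4m3fn
    scale = w.abs().max().clamp(min=1e-12) / fp8_max          # scale = amax / fp8_max
    w_fp8 = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8, scale


w = torch.randn(16, 16)
w_fp8, scale = per_tensor_quantize_sketch(w)
w_approx = w_fp8.to(torch.float32) * scale                    # dequantized approximation
```
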
+ 404 - 0
aphrodite/modeling/models/mixtral_quant.py

@@ -0,0 +1,404 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Mixtral model."""
+from typing import Iterable, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import MixtralConfig
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+
+
+class MixtralMLP(nn.Module):
+
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   linear_method=linear_method)
+        self.w2 = ReplicatedLinear(self.ffn_dim,
+                                   self.hidden_dim,
+                                   bias=False,
+                                   linear_method=linear_method)
+        self.w3 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   linear_method=linear_method)
+
+        # TODO: Use Aphrodite's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class MixtralMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: MixtralConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.num_total_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_total_experts}.")
+        # Split experts equally between ranks
+        self.expert_indicies = np.array_split(range(
+            self.num_total_experts), self.tp_size)[self.rank].tolist()
+        if not self.expert_indicies:
+            raise ValueError(
+                f"Rank {self.rank} has no experts assigned to it.")
+
+        self.experts = nn.ModuleList([
+            MixtralMLP(self.num_total_experts,
+                       config.hidden_size,
+                       config.intermediate_size,
+                       linear_method=linear_method)
+            if idx in self.expert_indicies else None
+            for idx in range(self.num_total_experts)
+        ])
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     self.num_total_experts,
+                                     bias=False,
+                                     linear_method=None)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights,
+                                                       self.top_k,
+                                                       dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = (selected_experts == expert_idx)
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
+                                                                 keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(
+                expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states).view(
+            num_tokens, hidden_dim)
+
+
+class MixtralAttention(nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 max_position: int = 4096 * 32,
+                 rope_theta: float = 10000,
+                 linear_method: Optional[LinearMethodBase] = None,
+                 sliding_window: Optional[int] = None) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.sliding_window = sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            sliding_window=self.sliding_window,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MixtralDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: MixtralConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = MixtralAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            sliding_window=config.sliding_window,
+            linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(config=config,
+                                           linear_method=linear_method)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        return hidden_states, residual
+
+
+class MixtralModel(nn.Module):
+
+    def __init__(
+        self,
+        config: MixtralConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList([
+            MixtralDecoderLayer(config, linear_method=linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(positions, hidden_states,
+                                            kv_caches[i], attn_metadata,
+                                            residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class MixtralForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: MixtralConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = MixtralModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

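A minimal single-process sketch of the expert-parallel routing MixtralMoE implements above: each tensor-parallel rank evaluates only its np.array_split slice of experts and masks the top-k routing weights, so summing the per-rank partial outputs (what tensor_model_parallel_all_reduce does) recovers the dense mixture. The tiny linear experts and shapes are illustrative only, not the SwiGLU experts of the real model.

import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k, tp_size, hidden = 4, 2, 2, 8
expert_w = [torch.randn(hidden, hidden) for _ in range(num_experts)]
gate_w = torch.randn(num_experts, hidden)
x = torch.randn(5, hidden)

router_logits = x @ gate_w.t()                      # (num_tokens, n_experts)
weights = F.softmax(router_logits, dim=-1)
weights, selected = torch.topk(weights, top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)

def partial_output(rank: int) -> torch.Tensor:
    """Contribution of one tensor-parallel rank's expert slice."""
    out = torch.zeros_like(x)
    for idx in np.array_split(range(num_experts), tp_size)[rank].tolist():
        mask = (selected == idx)
        w = (weights * mask).sum(dim=-1, keepdim=True)
        out += (x @ expert_w[idx].t()) * w
    return out

# Summing the rank-local partials reproduces the full top-k mixture.
full = sum(partial_output(r) for r in range(tp_size))
dense = sum((x @ expert_w[i].t()) *
            (weights * (selected == i)).sum(dim=-1, keepdim=True)
            for i in range(num_experts))
assert torch.allclose(full, dense, atol=1e-5)
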
+ 16 - 34
aphrodite/modeling/models/mpt.py

@@ -1,12 +1,15 @@
 # coding=utf-8
 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
@@ -14,14 +17,10 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs.mpt import MPTConfig
 
 
@@ -208,9 +207,10 @@ class MPTModel(nn.Module):
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"
 
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          config.d_model,
-                                          linear_method=linear_method)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.d_model,
+        )
         self.blocks = nn.ModuleList(
             [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
@@ -254,12 +254,8 @@ class MPTForCausalLM(nn.Module):
         self.linear_method = linear_method
 
         self.transformer = MPTModel(config, linear_method)
-        # self.lm_head_weight = self.transformer.wte.weight
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head_weight = self.transformer.wte.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -275,7 +271,7 @@ class MPTForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -287,26 +283,12 @@ class MPTForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-            if "wte" in name:
-                # Copy word embedding to lm_head
-                head_name = name.replace("transformer.wte", "lm_head")
-                if head_name in params_dict:
-                    lm_head_param = params_dict[head_name]
-                    weight_loader = getattr(lm_head_param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(lm_head_param, loaded_weight)
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)

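The MPT change above follows the new load_weights contract used throughout this sync: the model loader streams (name, tensor) pairs and each parameter may carry its own weight_loader. A self-contained sketch of that contract, with a toy module and a stand-in default_weight_loader (both hypothetical, not the Aphrodite implementations):

from typing import Iterable, Tuple
import torch
from torch import nn

def default_weight_loader(param: nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # Fallback loader: plain copy when the parameter defines no custom loader.
    assert param.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)

class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.wte = nn.Embedding(16, 4)
        self.norm_f = nn.LayerNorm(4)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # Skip extra biases some quantized checkpoints emit.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

model = TinyModel()
checkpoint = {k: torch.randn_like(v) for k, v in model.state_dict().items()}
model.load_weights(iter(checkpoint.items()))
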
+ 55 - 86
aphrodite/modeling/models/olmo.py

@@ -37,47 +37,29 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """Inference-only OLMo model compatible with HuggingFace weights."""
-
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
+# this model requires the hf_olmo dependency for its config class
+from hf_olmo import OLMoConfig
 from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.modeling.layers.linear import (
-    ColumnParallelLinear,
-    LinearMethodBase,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
-from aphrodite.transformers_utils.configs.olmo import OLMoConfig
-
-
-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        return F.silu(gate) * x
-
-    @property
-    def output_multiplier(self) -> float:
-        return 0.5
 
 
 class OlmoAttention(nn.Module):
@@ -181,17 +163,16 @@ class OlmoMLP(nn.Module):
                                     bias=False)
 
         # Feed-forward input projection.
-        self.ff_proj = ColumnParallelLinear(
+        self.ff_proj = MergedColumnParallelLinear(
             config.d_model,
-            self.hidden_size,
+            [self.hidden_size // 2] * 2,
             bias=config.include_bias,
             linear_method=linear_method,
         )
 
         # Activation function.
-        # self.act = SiluAndMul()
-        # self.act.output_multiplier = 0.5
-        self.act = SwiGLU()
+        self.act = SiluAndMul()
+        self.act.output_multiplier = 0.5
         assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
 
         # Feed-forward output projection.
@@ -220,16 +201,14 @@ class OlmoMLP(nn.Module):
 
 class OlmoBlock(nn.Module):
     """
-    This is a typical transformer block where the output is computed as
-    ``MLP(LN(x + Attention(LN(x))))``
+    This is a typical transformer block where the output is
+    computed as ``MLP(LN(x + Attention(LN(x))))``
     (plus another skip connection).
     """
 
-    def __init__(
-        self,
-        config: OLMoConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         # Attention block.
         self.attn = OlmoAttention(config, linear_method)
@@ -256,11 +235,9 @@ class OlmoBlock(nn.Module):
 
 class OlmoModel(nn.Module):
 
-    def __init__(
-        self,
-        config: OLMoConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.config = config
 
@@ -269,17 +246,10 @@ class OlmoModel(nn.Module):
                 wte=VocabParallelEmbedding(
                     config.embedding_size or config.vocab_size,
                     config.d_model,
-                    linear_method=linear_method,
                 ),
                 ln_f=nn.LayerNorm(config.d_model,
                                   elementwise_affine=False,
                                   bias=False),
-                ff_out=ParallelLMHead(
-                    config.embedding_size or config.vocab_size,
-                    config.d_model,
-                    bias=config.include_bias,
-                    linear_method=linear_method,
-                ),
             ))
 
         blocks = [
@@ -290,6 +260,17 @@ class OlmoModel(nn.Module):
         else:
             self.transformer.update({"blocks": nn.ModuleList(blocks)})
 
+        if not config.weight_tying:
+            self.transformer.update({
+                "ff_out":
+                ColumnParallelLinear(
+                    config.d_model,
+                    config.embedding_size or config.vocab_size,
+                    bias=config.include_bias,
+                    linear_method=linear_method,
+                )
+            })
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -325,17 +306,17 @@ class OLMoForCausalLM(nn.Module):
     Extremely barebones HF model wrapper.
     """
 
-    def __init__(
-        self,
-        config: OLMoConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: OLMoConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
         self.model = OlmoModel(config, linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head_weight = (self.model.transformer.wte.weight
+                               if config.weight_tying else
+                               self.model.transformer.ff_out.weight)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -355,8 +336,8 @@ class OLMoForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.transformer.ff_out,
-                                       hidden_states, sampling_metadata)
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(
@@ -367,31 +348,19 @@ class OLMoForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
-            if "wte" in name and self.config.weight_tying:
-                # Copy word embedding to lm_head
-                head_name = name.replace("model.transformer.wte",
-                                         "model.transformer.ff_out")
-                if head_name in params_dict:
-                    lm_head_param = params_dict[head_name]
-                    weight_loader = getattr(lm_head_param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(lm_head_param, loaded_weight)
+        for name, loaded_weight in weights:
             # attention
             if ".att" in name:
                 name = name.replace(".att", ".attn.att")
             # mlp
-            if ".ff" in name and "transformer.ff_out" not in name:
-                name = name.replace(".ff", ".mlp.ff")
+            if ".ff_proj" in name:
+                name = name.replace(".ff_proj", ".mlp.ff_proj")
+                # Reverse the weight halves for MergedColumnParallelLinear
+                loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
+            if ".ff_out" in name and "transformer.ff_out" not in name:
+                name = name.replace(".ff_out", ".mlp.ff_out")
             # there is no bias in olmo
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",

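A small numerical check of why the OLMo loader above reorders the ff_proj halves: the removed SwiGLU split the fused projection as x, gate = chunk(2) and returned silu(gate) * x, while SiluAndMul computes silu(first_half) * second_half, so swapping the two weight halves makes the two layouts agree. SiluAndMul is emulated here with plain torch operations; the shapes are toy values.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, hidden = 6, 8            # hidden is the fused (2 * ffn) dimension
w = torch.randn(hidden, d_model)  # checkpoint layout: [x_half; gate_half]
inp = torch.randn(3, d_model)

# Original OLMo SwiGLU on the unmodified weight.
proj = inp @ w.t()
x_half, gate = proj.chunk(2, dim=-1)
ref = F.silu(gate) * x_half

# SiluAndMul semantics on the reordered weight (what the loader now stores).
w_reordered = torch.concat(w.chunk(2)[::-1])
proj2 = inp @ w_reordered.t()
out = F.silu(proj2[..., :hidden // 2]) * proj2[..., hidden // 2:]

assert torch.allclose(ref, out, atol=1e-6)
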
+ 41 - 107
aphrodite/modeling/models/opt.py

@@ -18,36 +18,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OPT model compatible with HuggingFace weights."""
-
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import OPTConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.linear import (
-    ColumnParallelLinear,
-    LinearMethodBase,
-    QKVParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+from aphrodite.modeling.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -82,30 +73,13 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim**-0.5
 
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(embed_dim,
-                                               embed_dim,
-                                               bias=bias,
-                                               linear_method=linear_method)
-            self.k_proj = ColumnParallelLinear(embed_dim,
-                                               embed_dim,
-                                               bias=bias,
-                                               linear_method=linear_method)
-            self.v_proj = ColumnParallelLinear(embed_dim,
-                                               embed_dim,
-                                               bias=bias,
-                                               linear_method=linear_method)
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                embed_dim,
-                self.head_dim,
-                total_num_heads,
-                bias=bias,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            embed_dim,
+            self.head_dim,
+            total_num_heads,
+            bias=bias,
+            linear_method=linear_method,
+        )
         self.out_proj = RowParallelLinear(
             embed_dim,
             embed_dim,
@@ -122,13 +96,8 @@ class OPTAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.chunk(chunks=3, dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.out_proj(attn_output)
         return output
@@ -154,8 +123,7 @@ class OPTDecoderLayer(nn.Module):
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
-            elementwise_affine=config.layer_norm_elementwise_affine,
-        )
+            elementwise_affine=config.layer_norm_elementwise_affine)
         self.fc1 = ColumnParallelLinear(
             self.embed_dim,
             config.ffn_dim,
@@ -173,8 +141,7 @@ class OPTDecoderLayer(nn.Module):
         )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
-            elementwise_affine=config.layer_norm_elementwise_affine,
-        )
+            elementwise_affine=config.layer_norm_elementwise_affine)
 
     def forward(
         self,
@@ -187,11 +154,9 @@ class OPTDecoderLayer(nn.Module):
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(
-            hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-        )
+        hidden_states = self.self_attn(hidden_states=hidden_states,
+                                       kv_cache=kv_cache,
+                                       attn_metadata=attn_metadata)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -228,7 +193,6 @@ class OPTDecoder(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.word_embed_proj_dim,
-            linear_method=linear_method,
         )
         # Positional embeddings are replicated (not sharded).
         self.embed_positions = OPTLearnedPositionalEmbedding(
@@ -236,22 +200,18 @@ class OPTDecoder(nn.Module):
 
         # Project out & in will be replicated if they exist.
         if config.word_embed_proj_dim != config.hidden_size:
-            self.project_out = ReplicatedLinear(
-                config.hidden_size,
-                config.word_embed_proj_dim,
-                bias=False,
-                linear_method=linear_method,
-            )
+            self.project_out = ReplicatedLinear(config.hidden_size,
+                                                config.word_embed_proj_dim,
+                                                bias=False,
+                                                linear_method=linear_method)
         else:
             self.project_out = None
 
         if config.word_embed_proj_dim != config.hidden_size:
-            self.project_in = ReplicatedLinear(
-                config.word_embed_proj_dim,
-                config.hidden_size,
-                bias=False,
-                linear_method=linear_method,
-            )
+            self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
+                                               config.hidden_size,
+                                               bias=False,
+                                               linear_method=linear_method)
         else:
             self.project_in = None
 
@@ -262,8 +222,7 @@ class OPTDecoder(nn.Module):
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm(
                 config.hidden_size,
-                elementwise_affine=config.layer_norm_elementwise_affine,
-            )
+                elementwise_affine=config.layer_norm_elementwise_affine)
         else:
             self.final_layer_norm = None
 
@@ -327,11 +286,8 @@ class OPTForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = OPTModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head_weight = self.model.decoder.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -347,7 +303,7 @@ class OPTForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -359,43 +315,21 @@ class OPTForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
-            if "lm_head" in name and name not in params_dict:
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
                 continue
-            if "embed_tokens" in name:
-                # Copy word embedding to lm_head
-                if name.startswith("decoder."):
-                    name = "model." + name
-                head_name = name.replace("model.decoder.embed_tokens",
-                                         "lm_head")
-                if head_name in params_dict:
-                    lm_head_param = params_dict[head_name]
-                    weight_loader = getattr(lm_head_param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(lm_head_param, loaded_weight)
             if name.startswith("decoder."):
                 name = "model." + name
 
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)

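A minimal sketch of the stacked_params_mapping pattern used in the OPT load_weights above: separate q_proj/k_proj/v_proj checkpoint tensors are renamed onto the fused qkv_proj parameter and routed into it by shard id. The fused row layout and the qkv_weight_loader helper are toy assumptions, not the QKVParallelLinear loader itself.

import torch

embed_dim = 4
fused_qkv = torch.zeros(3 * embed_dim, embed_dim)  # rows: [q; k; v]

def qkv_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                      shard_id: str) -> None:
    # Copy one shard into its slice of the fused QKV weight.
    offset = {"q": 0, "k": 1, "v": 2}[shard_id] * embed_dim
    param[offset:offset + embed_dim].copy_(loaded_weight)

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
checkpoint = {
    f"decoder.layers.0.self_attn.{p}.weight":
    torch.full((embed_dim, embed_dim), i)
    for i, p in enumerate(("q_proj", "k_proj", "v_proj"), start=1)
}
params_dict = {"model.decoder.layers.0.self_attn.qkv_proj.weight": fused_qkv}

for name, loaded_weight in checkpoint.items():
    if name.startswith("decoder."):
        name = "model." + name
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        name = name.replace(weight_name, param_name)
        qkv_weight_loader(params_dict[name], loaded_weight, shard_id)
        break

assert fused_qkv[embed_dim:2 * embed_dim].eq(2).all()  # k slice came from k_proj
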
+ 318 - 0
aphrodite/modeling/models/orion.py

@@ -0,0 +1,318 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
+# Copyright (c) OrionStar Inc.
+# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
+"""Inference-only Orion-14B model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+
+
+class OrionMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class OrionAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class OrionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.self_attn = OrionAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
+        )
+        self.mlp = OrionMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            linear_method=linear_method,
+        )
+
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states, None
+
+
+class OrionModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList([
+            OrionDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                attn_metadata,
+                residual,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class OrionForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = OrionModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

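A short illustration of why the Orion loader above skips rotary_emb.inv_freq and the *_cached tensors: they are recomputable buffers rather than parameters, so they never appear in named_parameters() and indexing params_dict with them would raise KeyError. The toy module here is hypothetical.

import torch
from torch import nn

class ToyRotary(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)
        # A buffer, not a Parameter: rebuilt at init, never loaded.
        self.register_buffer("inv_freq", torch.ones(2))

toy = ToyRotary()
params_dict = dict(toy.named_parameters())
assert "proj.weight" in params_dict
assert "inv_freq" not in params_dict  # hence the explicit skip while loading
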
+ 45 - 112
aphrodite/modeling/models/phi.py

@@ -35,46 +35,35 @@
 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Inference-only Phi model compatible with HuggingFace weights."""
-
-from typing import List, Optional
+"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.linear import (
-    ColumnParallelLinear,
-    LinearMethodBase,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class PhiAttention(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -87,36 +76,13 @@ class PhiAttention(nn.Module):
                           tensor_model_parallel_world_size)
 
         # pylint: disable=C0103
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                bias=True,
-                linear_method=linear_method,
-            )
-            self.k_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                bias=True,
-                linear_method=linear_method,
-            )
-            self.v_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                bias=True,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                self.hidden_size,
-                self.head_size,
-                self.total_num_heads,
-                bias=True,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_size,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
@@ -133,15 +99,11 @@ class PhiAttention(nn.Module):
         # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
         rope_theta = 10000
         max_position_embeddings = getattr(config, "n_positions", 2048)
-        is_neox_style = (True if linear_method is None
-                         or linear_method.quant_config.rope_style() is None
-                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_size,
             rotary_dim=rotary_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
-            is_neox_style=is_neox_style,
         )
         self.attn = Attention(self.num_heads, self.head_size, scaling)
 
@@ -152,13 +114,8 @@ class PhiAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.chunk(chunks=3, dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.dense(attn_output)
@@ -167,11 +124,9 @@ class PhiAttention(nn.Module):
 
 class PhiMLP(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
 
         n_inner = getattr(config, "n_inner", None)
@@ -199,11 +154,9 @@ class PhiMLP(nn.Module):
 
 class PhiLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
@@ -232,17 +185,14 @@ class PhiLayer(nn.Module):
 
 class PhiModel(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size,
-                                                   linear_method=linear_method)
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             PhiLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -274,25 +224,19 @@ class PhiModel(nn.Module):
 
 class PhiForCausalLM(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
 
         self.model = PhiModel(config, linear_method)
 
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            bias=True,
-            linear_method=linear_method,
-        )
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      bias=True)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -309,7 +253,7 @@ class PhiForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata, self.lm_head.bias)
         return logits
 
@@ -321,31 +265,20 @@ class PhiForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
+            ("qkv_proj", "v_proj", "v")
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
 
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
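
Across these model files the `load_weights` contract changes in the same way: instead of calling `hf_model_weights_iterator` itself, the model now consumes an iterator of `(name, tensor)` pairs handed in by the new `model_loader` package, and stacked q/k/v checkpoint shards are routed into the fused `qkv_proj` parameter through its `weight_loader`. A minimal standalone sketch of that dispatch loop follows; the plain dict of tensors, the integer shard offsets, and the toy checkpoint are illustrative simplifications, not Aphrodite's actual `QKVParallelLinear` loader (which uses string shard ids and handles tensor-parallel partitioning).

    # Standalone sketch of the new load_weights contract (not Aphrodite code):
    # the loader hands the model (name, tensor) pairs and the model folds the
    # per-shard q/k/v checkpoint weights into one fused qkv_proj tensor.
    from typing import Dict, Iterable, Tuple
    import torch

    STACKED = [("qkv_proj", "q_proj", 0), ("qkv_proj", "k_proj", 1),
               ("qkv_proj", "v_proj", 2)]

    def load_weights(params: Dict[str, torch.Tensor],
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        for name, tensor in weights:
            if "rotary_emb.inv_freq" in name:
                continue  # runtime buffer, rebuilt at init; not a checkpoint weight
            for fused, shard, idx in STACKED:
                if shard not in name:
                    continue
                dst = params[name.replace(shard, fused)]
                rows = tensor.shape[0]
                dst[idx * rows:(idx + 1) * rows].copy_(tensor)  # place the shard
                break
            else:
                params[name].copy_(tensor)  # 1:1 copy, like default_weight_loader

    # toy usage: three 4x4 projections folded into one 12x4 fused qkv weight
    params = {"layer.qkv_proj.weight": torch.empty(12, 4)}
    ckpt = {f"layer.{n}.weight": torch.full((4, 4), float(i))
            for i, n in enumerate(("q_proj", "k_proj", "v_proj"))}
    load_weights(params, ckpt.items())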

+ 43 - 101
aphrodite/modeling/models/qwen.py

@@ -4,38 +4,28 @@
 # Copyright (c) Alibaba Cloud.
 # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
 """Inference-only QWen model compatible with HuggingFace weights."""
-
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
+from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
-from aphrodite.transformers_utils.configs.qwen import QWenConfig
 
 
 class QWenMLP(nn.Module):
@@ -48,47 +38,21 @@ class QWenMLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.w2 = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.w1 = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size,
-                [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
-        self.c_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
-        )
+            linear_method=linear_method)
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.w1(x)
-            gate, _ = self.w2(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.c_proj(x)
         return x
@@ -107,8 +71,8 @@ class QWenAttention(nn.Module):
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
         self.total_num_heads = num_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = (self.total_num_heads //
@@ -129,16 +93,12 @@ class QWenAttention(nn.Module):
         )
         self.scaling = self.head_dim**-0.5
 
-        is_neox_style = (True if linear_method is None
-                         or linear_method.quant_config.rope_style() is None
-                         else linear_method.quant_config.rope_style())
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
         )
         self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
 
@@ -153,7 +113,6 @@ class QWenAttention(nn.Module):
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
-
         output, _ = self.c_proj(attn_output)
         return output
 
@@ -162,7 +121,7 @@ class QWenBlock(nn.Module):
 
     def __init__(
         self,
-        config: QWenConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
@@ -170,22 +129,18 @@ class QWenBlock(nn.Module):
 
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        self.attn = QWenAttention(
-            config.hidden_size,
-            config.num_attention_heads,
-            config.max_position_embeddings,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            linear_method=linear_method,
-        )
+        self.attn = QWenAttention(config.hidden_size,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling,
+                                  linear_method=linear_method)
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.mlp = QWenMLP(
-            config.hidden_size,
-            config.intermediate_size // 2,
-            linear_method=linear_method,
-        )
+        self.mlp = QWenMLP(config.hidden_size,
+                           config.intermediate_size // 2,
+                           linear_method=linear_method)
 
     def forward(
         self,
@@ -218,16 +173,17 @@ class QWenModel(nn.Module):
 
     def __init__(
         self,
-        config: QWenConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
 
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          config.hidden_size,
-                                          linear_method=linear_method)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.h = nn.ModuleList([
             QWenBlock(config, linear_method)
             for _ in range(config.num_hidden_layers)
@@ -260,18 +216,15 @@ class QWenLMHeadModel(nn.Module):
 
     def __init__(
         self,
-        config: QWenConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.config = config
         self.linear_method = linear_method
         self.transformer = QWenModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -287,7 +240,7 @@ class QWenLMHeadModel(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -299,28 +252,17 @@ class QWenLMHeadModel(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
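
The QWen MLP merge maps the checkpoint's `w2` to shard 0 and `w1` to shard 1 of `gate_up_proj`, matching the removed unmerged path where `gate = w2(x)` and `up = w1(x)` were concatenated as `[gate, up]`. The fused output then goes through `SiluAndMul`. A small plain-PyTorch sketch of that activation as it is used here (the `silu_and_mul` function below is a stand-in, ignoring the tensor-parallel sharding of the real layer):

    import torch
    import torch.nn.functional as F

    def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
        # first half of the fused projection is the gate, second half is up
        gate, up = gate_up.chunk(2, dim=-1)
        return F.silu(gate) * up

    x = torch.randn(2, 8)          # [num_tokens, 2 * intermediate_size]
    print(silu_and_mul(x).shape)   # torch.Size([2, 4])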

+ 83 - 136
aphrodite/modeling/models/qwen2.py

@@ -1,8 +1,8 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2024 The Qwen team.
+# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -23,38 +23,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import Qwen2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase,
-    ColumnParallelLinear,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class Qwen2MLP(nn.Module):
@@ -67,47 +58,21 @@ class Qwen2MLP(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.up_proj = ColumnParallelLinear(
-                hidden_size,
-                intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                hidden_size,
-                [intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
-        )
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -115,17 +80,15 @@ class Qwen2MLP(nn.Module):
 
 class Qwen2Attention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        max_position: int = 4096 * 32,
-        rope_theta: float = 10000,
-        use_sliding_window: bool = False,
-        linear_method: Optional[LinearMethodBase] = None,
-        sliding_window: Optional[int] = None,
-    ) -> None:
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 max_position: int = 4096 * 32,
+                 rope_theta: float = 10000,
+                 use_sliding_window: bool = False,
+                 linear_method: Optional[LinearMethodBase] = None,
+                 sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -149,37 +112,14 @@ class Qwen2Attention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window if use_sliding_window else None
 
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_heads * self.head_dim,
-                bias=True,
-                linear_method=linear_method,
-            )
-            self.k_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=True,
-                linear_method=linear_method,
-            )
-            self.v_proj = ColumnParallelLinear(
-                hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=True,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_kv_heads,
-                bias=True,
-                linear_method=linear_method,
-            )
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=True,
+            linear_method=linear_method,
+        )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -193,13 +133,11 @@ class Qwen2Attention(nn.Module):
             max_position=max_position,
             base=self.rope_theta,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window)
 
     def forward(
         self,
@@ -208,14 +146,8 @@ class Qwen2Attention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -244,8 +176,7 @@ class Qwen2DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             use_sliding_window=use_sliding_window,
             linear_method=linear_method,
-            sliding_window=config.sliding_window,
-        )
+            sliding_window=config.sliding_window)
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -301,7 +232,6 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            linear_method=linear_method,
         )
         self.layers = nn.ModuleList([
             Qwen2DecoderLayer(config, layer_idx, linear_method)
@@ -332,23 +262,48 @@ class Qwen2Model(nn.Module):
 
 
 class Qwen2ForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config: Qwen2Config,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config
         super().__init__()
         self.config = config
         self.linear_method = linear_method
         self.model = Qwen2Model(config, linear_method)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            linear_method=linear_method,
-        )
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+
+        if config.tie_word_embeddings:
+            self.lm_head_weight = self.model.embed_tokens.weight
+        else:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size)
+            self.lm_head_weight = self.lm_head.weight
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -364,7 +319,7 @@ class Qwen2ForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -376,13 +331,7 @@ class Qwen2ForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -391,15 +340,13 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
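
New in qwen2.py is the tied-embedding path: when `config.tie_word_embeddings` is set, `lm_head_weight` aliases `model.embed_tokens.weight` and any `lm_head.weight` entry in the checkpoint is skipped during loading so nothing overwrites the shared tensor. A toy sketch of what that sharing means for logit computation (shapes and the bare matmul are illustrative; the real `LogitsProcessor` also handles tensor-parallel gathering):

    import torch

    vocab_size, hidden_size = 16, 8
    embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
    lm_head_weight = embed_tokens.weight           # alias; no separate ParallelLMHead

    hidden_states = torch.randn(3, hidden_size)    # [num_tokens, hidden_size]
    logits = hidden_states @ lm_head_weight.t()    # [num_tokens, vocab_size]
    print(logits.shape,
          lm_head_weight.data_ptr() == embed_tokens.weight.data_ptr())  # ..., True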

+ 12 - 21
aphrodite/modeling/models/qwen2_moe.py

@@ -23,7 +23,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -31,6 +31,10 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -44,13 +48,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from aphrodite.distributed import (tensor_model_parallel_all_reduce)
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -367,6 +366,8 @@ class Qwen2MoeModel(nn.Module):
 
 class Qwen2MoeForCausalLM(nn.Module):
 
+    fall_back_to_pt_during_load = False
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -377,8 +378,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.model = Qwen2MoeModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -394,7 +394,7 @@ class Qwen2MoeForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -406,11 +406,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -421,12 +417,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path,
-                cache_dir,
-                load_format,
-                revision,
-                fall_back_to_pt=False):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
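
The old `fall_back_to_pt=False` argument to `hf_model_weights_iterator` becomes the class attribute `fall_back_to_pt_during_load = False`, presumably read by the new model loader when deciding whether to fall back to PyTorch `.pt`/`.bin` checkpoints. The loader side is not part of this diff, so the helper below is a hypothetical sketch of how such an attribute could be consumed:

    def should_fall_back_to_pt(model_cls) -> bool:
        # defaulting to True keeps the old behaviour for models without the flag
        return getattr(model_cls, "fall_back_to_pt_during_load", True)

    class FakeQwen2MoeForCausalLM:   # stand-in for the class defined in the diff
        fall_back_to_pt_during_load = False

    print(should_fall_back_to_pt(FakeQwen2MoeForCausalLM))  # False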

+ 63 - 150
aphrodite/modeling/models/stablelm.py

@@ -17,87 +17,51 @@
 # This code is based off the following work:
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
-"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model
-compatible with HuggingFace weights."""
-
-from typing import List, Optional, Tuple
+"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
+model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.linear import (
-    LinearMethodBase,
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-    ColumnParallelLinear,
-)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.distributed import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
-from aphrodite.common.sequence import SamplerOutput
 
 
 class StablelmMLP(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.gate_proj = ColumnParallelLinear(
-                config.hidden_size,
-                config.intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-            self.up_proj = ColumnParallelLinear(
-                config.hidden_size,
-                config.intermediate_size,
-                bias=False,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.gate_up_proj = MergedColumnParallelLinear(
-                config.hidden_size,
-                [config.intermediate_size] * 2,
-                bias=False,
-                linear_method=linear_method,
-            )
+        self.gate_up_proj = MergedColumnParallelLinear(
+            config.hidden_size, [config.intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(config.intermediate_size,
                                            config.hidden_size,
                                            bias=False)
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.merge_weight:
-            gate_up, _ = self.gate_up_proj(x)
-        else:
-            up, _ = self.up_proj(x)
-            gate, _ = self.gate_proj(x)
-            gate_up = torch.cat([gate, up], dim=-1)
+        gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -105,11 +69,9 @@ class StablelmMLP(nn.Module):
 
 class StablelmAttention(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -130,68 +92,38 @@ class StablelmAttention(nn.Module):
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        rope_pct = self.config.partial_rotary_factor
+        rope_pct = getattr(config, "rope_pct",
+                           getattr(config, "partial_rotary_factor", 1))
         self.rotary_ndims = int(self.head_dim * rope_pct)
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
         self.qkv_bias = getattr(config, "use_qkv_bias", False)
         if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
-            raise ValueError("hidden_size must be divisible by num_heads (got "
-                             f"`hidden_size`: {self.hidden_size}"
+            raise ValueError(f"hidden_size must be divisible by num_heads "
+                             f"(got `hidden_size`: {self.hidden_size}"
                              f" and `num_heads`: {self.num_heads}).")
 
-        if (linear_method is not None
-                and not linear_method.quant_config.merge_weight()):
-            self.merge_weight = False
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.total_num_heads * self.head_dim,
-                bias=self.qkv_bias,
-                linear_method=linear_method,
-            )
-            self.k_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=self.qkv_bias,
-                linear_method=linear_method,
-            )
-            self.v_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.total_num_kv_heads * self.head_dim,
-                bias=self.qkv_bias,
-                linear_method=linear_method,
-            )
-        else:
-            self.merge_weight = True
-            self.qkv_proj = QKVParallelLinear(
-                self.hidden_size,
-                self.head_dim,
-                self.total_num_heads,
-                self.total_num_key_value_heads,
-                self.qkv_bias,
-                linear_method=linear_method,
-            )
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            self.hidden_size,
-            bias=False,
-            linear_method=linear_method,
-        )
-        self.rotary_ndims = int(self.head_dim *
-                                self.config.partial_rotary_factor)
+        self.qkv_proj = QKVParallelLinear(self.hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_key_value_heads,
+                                          self.qkv_bias,
+                                          linear_method=linear_method)
+        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
             max_position=self.config.max_position_embeddings,
             base=self.config.rope_theta,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_key_value_heads,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_key_value_heads)
 
     def forward(
         self,
@@ -200,14 +132,8 @@ class StablelmAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.merge_weight:
-            qkv, _ = self.qkv_proj(hidden_states)
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        else:
-            q, _ = self.q_proj(hidden_states)
-            k, _ = self.k_proj(hidden_states)
-            v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -224,10 +150,11 @@ class StablelmDecoderLayer(nn.Module):
         super().__init__()
         self.self_attn = StablelmAttention(config)
         self.mlp = StablelmMLP(config, linear_method)
-        self.input_layernorm = nn.LayerNorm(config.hidden_size,
-                                            eps=config.layer_norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
-                                                     eps=config.layer_norm_eps)
+                                                     eps=norm_eps)
 
     def forward(
         self,
@@ -258,20 +185,21 @@ class StablelmDecoderLayer(nn.Module):
 
 class StableLMEpochModel(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 linear_method: Optional[LinearMethodBase] = None) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size,
-                                                   linear_method=linear_method)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             StablelmDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        norm_eps = getattr(config, "norm_eps",
+                           getattr(config, "layer_norm_eps", 1e-05))
+        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
 
     def forward(
         self,
@@ -283,7 +211,6 @@ class StableLMEpochModel(nn.Module):
         hidden_states = self.embed_tokens(input_ids)
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            # pylint: disable=unused-variable
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
@@ -305,11 +232,8 @@ class StablelmForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = StableLMEpochModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      linear_method=linear_method)
-        self.logits_processor = LogitsProcessor(config.vocab_size,
-                                                config.tokenizer_vocab_size)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
@@ -325,7 +249,7 @@ class StablelmForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -337,13 +261,7 @@ class StablelmForCausalLM(nn.Module):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -352,13 +270,8 @@ class StablelmForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision,
-                self.config):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
@@ -366,7 +279,7 @@ class StablelmForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
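
The StableLM changes replace hard `config.partial_rotary_factor` and `config.layer_norm_eps` accesses with `getattr` fallbacks so one implementation covers both the older StableLM-Epoch attribute names and the upstream `StableLmConfig` ones. A tiny sketch of the fallback with toy config objects (`SimpleNamespace` stand-ins, not real HF configs):

    from types import SimpleNamespace

    def partial_rotary(config) -> float:
        # prefer the StableLM-Epoch name, then the transformers StableLmConfig name
        return getattr(config, "rope_pct",
                       getattr(config, "partial_rotary_factor", 1))

    epoch_cfg = SimpleNamespace(rope_pct=0.25)            # older config style
    hf_cfg = SimpleNamespace(partial_rotary_factor=0.25)  # newer config style
    bare_cfg = SimpleNamespace()                          # neither set -> 1 (full rotary)
    print(partial_rotary(epoch_cfg), partial_rotary(hf_cfg), partial_rotary(bare_cfg))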

+ 303 - 0
aphrodite/modeling/models/starcoder2.py

@@ -0,0 +1,303 @@
+# coding=utf-8
+# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Starcoder2 model."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Starcoder2Config
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              LinearMethodBase,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+
+
+class Starcoder2Attention(nn.Module):
+
+    def __init__(self,
+                 config: Starcoder2Config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = config.rope_theta
+        self.max_position_embeddings = config.max_position_embeddings
+        self.use_bias = config.use_bias
+        self.sliding_window = config.sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=self.use_bias,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=self.use_bias,
+            linear_method=linear_method,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            sliding_window=self.sliding_window,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Starcoder2MLP(nn.Module):
+
+    def __init__(self,
+                 config: Starcoder2Config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.c_fc = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            bias=config.use_bias,
+            linear_method=linear_method,
+        )
+        self.c_proj = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            bias=config.use_bias,
+            linear_method=linear_method,
+        )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              config.intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class Starcoder2DecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: Starcoder2Config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Starcoder2Attention(config,
+                                             linear_method=linear_method)
+        self.mlp = Starcoder2MLP(config, linear_method=linear_method)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.norm_epsilon)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Starcoder2Model(nn.Module):
+
+    def __init__(self,
+                 config: Starcoder2Config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        # TODO: consider padding_idx (currently removed)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
+            Starcoder2DecoderLayer(config, linear_method=linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states = layer(positions, hidden_states, kv_caches[i],
+                                  attn_metadata)
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class Starcoder2ForCausalLM(nn.Module):
+
+    def __init__(self,
+                 config: Starcoder2Config,
+                 linear_method: Optional[LinearMethodBase] = None):
+        super().__init__()
+        self.config = config
+        self.model = Starcoder2Model(config, linear_method=linear_method)
+        self.vocab_size = config.vocab_size
+        self.unpadded_vocab_size = config.vocab_size
+        if config.tie_word_embeddings:
+            self.lm_head_weight = self.model.embed_tokens.weight
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            )
+            self.lm_head_weight = self.lm_head.weight
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
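
To make the weight-loading loop above easier to follow: stacked_params_mapping folds the checkpoint's separate q_proj/k_proj/v_proj tensors onto the fused qkv_proj parameter by rewriting the weight name and handing a shard id to that parameter's weight_loader. A minimal, self-contained sketch of just the name-rewrite step (the checkpoint name below is illustrative, not taken from a real file):

# Illustrative sketch of the name folding done in load_weights above.
# Only the rewrite step is shown; the real loader then calls the
# parameter's weight_loader with the loaded tensor and the shard id.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def resolve(name: str):
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # unfused weight: loaded with the default loader

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'k')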

+ 365 - 0
aphrodite/modeling/models/xverse.py

@@ -0,0 +1,365 @@
+# coding=utf-8
+# Adapted from
+# https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Xverse model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+
+
+class XverseMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class XverseAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        linear_method: Optional[LinearMethodBase] = None,
+        bias: bool = False,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        # partition the KV heads across multiple tensor parallel GPUs.
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            linear_method=linear_method,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias,
+            linear_method=linear_method,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=sliding_window)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class XverseDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        sliding_window = getattr(config, "sliding_window", None)
+        self.self_attn = XverseAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=getattr(config, "num_key_value_heads",
+                                 config.num_attention_heads),
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            linear_method=linear_method,
+            bias=getattr(config, "bias", False),
+            sliding_window=sliding_window,
+        )
+        self.mlp = XverseMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            linear_method=linear_method,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class XverseModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.layers = nn.ModuleList([
+            XverseDecoderLayer(config, linear_method)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                attn_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class XverseForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+        lora_config=None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = XverseModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if ("rotary_emb.inv_freq" in name
+                    or "rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
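
As a sanity check on the tensor-parallel head arithmetic in XverseAttention above, here is a hedged sketch of how the per-rank head counts and the q/k/v split widths fall out; the hidden size, head counts, and tp_size are made-up example values, not taken from any real config:

# Illustrative numbers only; mirrors the arithmetic in XverseAttention.__init__.
hidden_size, total_num_heads, total_num_kv_heads, tp_size = 4096, 32, 8, 4

assert total_num_heads % tp_size == 0
assert total_num_kv_heads % tp_size == 0

num_heads = total_num_heads // tp_size                 # 8 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)   # 2 KV heads per rank
head_dim = hidden_size // total_num_heads              # 128

q_size = num_heads * head_dim      # 1024: width of the per-rank q slice
kv_size = num_kv_heads * head_dim  # 256: width of each per-rank k and v slice
# forward() then does qkv.split([q_size, kv_size, kv_size], dim=-1).
print(num_heads, num_kv_heads, q_size, kv_size)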

+ 3 - 4
aphrodite/modeling/sampling_metadata.py

@@ -1,14 +1,13 @@
-from dataclasses import dataclass
-from typing import Dict, List, Tuple, Optional, TypeVar, Callable
 import random
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Tuple, TypeVar
 
 import torch
 
-from aphrodite.modeling.layers.ops.sample import (get_num_triton_sampler_splits
-                                                  )
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SequenceData
 from aphrodite.common.utils import is_pin_memory_available
+from aphrodite.modeling.layers.ops.sample import get_num_triton_sampler_splits
 
 _SEED_0_REPLACEMENT = 3403598558  # chosen by fair roll of a die
 

+ 26 - 5
aphrodite/processing/block/block_table.py

@@ -1,18 +1,17 @@
 from typing import List, Optional
 
-from aphrodite.processing.block.interfaces import (
-    Block,
-    DeviceAwareBlockAllocator,
-)
+from aphrodite.processing.block.interfaces import Block, DeviceAwareBlockAllocator
 from aphrodite.common.utils import Device, cdiv, chunk_list
 
 
 class BlockTable:
     """A class to manage blocks for a specific sequence.
+
     The BlockTable maps a sequence of tokens to a list of blocks, where each
-    block represents a contiguous memory allocation for a portion of the
+    block represents a contiguous memory allocation for a portion of the 
     sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
     responsible for allocating and freeing memory for the blocks.
+
     Args:
         block_size (int): The maximum number of tokens that can be stored in a
             single block.
@@ -21,6 +20,7 @@ class BlockTable:
         _blocks (Optional[List[Block]], optional): An optional list of existing
             blocks to initialize the BlockTable with. If not provided, an empty
             BlockTable is created.
+
     Attributes:
         _block_size (int): The maximum number of tokens that can be stored in a
             single block.
@@ -50,12 +50,15 @@ class BlockTable:
     def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
         """Calculates the minimum number of blocks required to store a given
         sequence of token IDs.
+
         This assumes worst-case scenario, where every block requires a new
         allocation (e.g. ignoring prefix caching).
+
         Args:
             token_ids (List[int]): The sequence of token IDs to be stored.
             block_size (int): The maximum number of tokens that can be stored in
                 a single block.
+
         Returns:
             int: The minimum number of blocks required to store the given
                 sequence of token IDs.
@@ -66,8 +69,10 @@ class BlockTable:
                  token_ids: List[int],
                  device: Device = Device.GPU) -> None:
         """Allocates memory blocks for storing the given sequence of token IDs.
+
         This method allocates the required number of blocks to store the given
         sequence of token IDs.
+
         Args:
             token_ids (List[int]): The sequence of token IDs to be stored.
             device (Device, optional): The device on which the blocks should be
@@ -85,17 +90,21 @@ class BlockTable:
                          num_lookahead_slots: int = 0) -> None:
         """Appends a sequence of token IDs to the existing blocks in the
         BlockTable.
+
         This method appends the given sequence of token IDs to the existing
         blocks in the BlockTable. If there is not enough space in the existing
         blocks, new blocks are allocated using the `ensure_num_empty_slots`
         method to accommodate the additional tokens.
+
         The token IDs are divided into chunks of size `block_size` (except for
         the first chunk, which may be smaller), and each chunk is appended to a
         separate block.
+
         Args:
             token_ids (List[int]): The sequence of token IDs to be appended.
         """
         assert self._is_allocated
+        assert self._blocks is not None
 
         self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                     num_lookahead_slots)
@@ -111,10 +120,12 @@ class BlockTable:
     def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
         """Ensures that the BlockTable has at least the specified number of
         empty slots available.
+
         This method checks if the BlockTable has enough empty slots (i.e.,
         available space) to accommodate the requested number of tokens. If not,
         it allocates additional blocks on the GPU to ensure that the required
         number of empty slots is available.
+
         Args:
             num_empty_slots (int): The minimum number of empty slots required.
         """
@@ -137,10 +148,12 @@ class BlockTable:
     def fork(self) -> "BlockTable":
         """Creates a new BlockTable instance with a copy of the blocks from the
         current instance.
+
         This method creates a new BlockTable instance with the same block size,
         block allocator, and a copy of the blocks from the current instance. The
         new BlockTable has its own independent set of blocks, but shares the
         same underlying memory allocation with the original BlockTable.
+
         Returns:
             BlockTable: A new BlockTable instance with a copy of the blocks from
                 the current instance.
@@ -155,6 +168,7 @@ class BlockTable:
 
     def free(self) -> None:
         """Frees the memory occupied by the blocks in the BlockTable.
+
         This method iterates over all the blocks in the `_blocks` list and calls
         the `free` method of the `_allocator` object to release the memory
         occupied by each block. After freeing all the blocks, the `_blocks` list
@@ -169,10 +183,12 @@ class BlockTable:
     def physical_block_ids(self) -> List[int]:
         """Returns a list of physical block indices for the blocks in the
         BlockTable.
+
         This property returns a list of integers, where each integer represents
         the physical block index of a corresponding block in the `_blocks` list.
         The physical block index is a unique identifier for the memory location
         occupied by the block.
+
         Returns:
             List[int]: A list of physical block indices for the blocks in the
                 BlockTable.
@@ -182,11 +198,14 @@ class BlockTable:
 
     def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
         """Get the number of "unseen" tokens in the sequence.
+
         Unseen tokens are tokens in the sequence corresponding to this block
         table, but are not yet appended to this block table.
+
         Args:
             sequence_token_ids (List[int]): The list of token ids in the
                 sequence.
+
         Returns:
             List[int]: The postfix of sequence_token_ids that has not yet been
                 appended to the block table.
@@ -239,6 +258,7 @@ class BlockTable:
     def num_full_slots(self) -> int:
         """Returns the total number of tokens currently stored in the
         BlockTable.
+
         Returns:
             int: The total number of tokens currently stored in the BlockTable.
         """
@@ -248,6 +268,7 @@ class BlockTable:
             self, token_ids: List[int], num_lookahead_slots: int) -> int:
         """Determine how many blocks will be "touched" by appending the token
         ids.
+
         This is required for the scheduler to determine whether a sequence can
         continue generation, or if it must be preempted.
         """

+ 18 - 2
aphrodite/processing/block/common.py

@@ -9,9 +9,11 @@ RefCount = int
 
 class RefCounter:
     """A class for managing reference counts for a set of block indices.
+
     The RefCounter class maintains a dictionary that maps block indices to their
     corresponding reference counts. It provides methods to increment, decrement,
     and retrieve the reference count for a given block index.
+
     Args:
         all_block_indices (Iterable[BlockId]): An iterable of block indices
             to initialize the reference counter with.
@@ -54,9 +56,11 @@ class RefCounter:
 
 class ReadOnlyRefCounter:
     """A read-only view of the RefCounter class.
+
     The ReadOnlyRefCounter class provides a read-only interface to access the
     reference counts maintained by a RefCounter instance. It does not allow
     modifications to the reference counts.
+
     Args:
         refcounter (RefCounter): The RefCounter instance to create a read-only
             view for.
@@ -77,10 +81,12 @@ class ReadOnlyRefCounter:
 
 class CopyOnWriteTracker:
     """A class for tracking and managing copy-on-write operations for blocks.
+
     The CopyOnWriteTracker class maintains a mapping of source block indices to
         their corresponding copy-on-write destination block indices. It works in
         conjunction with a RefCounter and a BlockAllocator to handle reference
         counting and block allocation.
+
     Args:
         refcounter (RefCounter): The reference counter used to track block
             reference counts.
@@ -93,20 +99,23 @@ class CopyOnWriteTracker:
         refcounter: RefCounter,
         allocator: BlockAllocator,
     ):
-        self._copy_on_writes = defaultdict(list)
+        self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
         self._refcounter = refcounter
         self._allocator = allocator
 
     def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
         """Performs a copy-on-write operation on the given block if it is not
         appendable.
+
         This method checks the reference count of the given block. If the
         reference count is greater than 1, indicating that the block is shared,
         a copy-on-write operation is performed. The original block is freed,
         and a new block is allocated with the same content. The new block index
         is returned.
+
         Args:
             block (Block): The block to check for copy-on-write.
+
         Returns:
             Optional[BlockId]: The block index of the new block if a copy-on
                 -write operation was performed, or the original block index if
@@ -129,6 +138,8 @@ class CopyOnWriteTracker:
                 prev_block=block.prev_block).block_id
 
             # Track src/dst copy.
+            assert src_block_id is not None
+            assert block_id is not None
             self._copy_on_writes[src_block_id].append(block_id)
 
         return block_id
@@ -136,9 +147,11 @@ class CopyOnWriteTracker:
     def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
         """Clears the copy-on-write tracking information and returns the current
         state.
+
         This method returns a dictionary mapping source block indices to lists
         of destination block indices for the current copy-on-write operations.
         It then clears the internal tracking information.
+
         Returns:
             Dict[BlockId, List[BlockId]]: A dictionary mapping source
                 block indices to lists of destination block indices for the
@@ -151,11 +164,14 @@ class CopyOnWriteTracker:
 
 def get_all_blocks_recursively(last_block: Block) -> List[Block]:
     """Retrieves all the blocks in a sequence starting from the last block.
+
     This function recursively traverses the sequence of blocks in reverse order,
     starting from the given last block, and returns a list of all the blocks in
     the sequence.
+
     Args:
         last_block (Block): The last block in the sequence.
+
     Returns:
         List[Block]: A list of all the blocks in the sequence, in the order they
             appear.
@@ -166,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
             recurse(block.prev_block, lst)
         lst.append(block)
 
-    all_blocks = []
+    all_blocks: List[Block] = []
     recurse(last_block, all_blocks)
     return all_blocks
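
The CopyOnWriteTracker docstrings above boil down to one rule: a block whose refcount is greater than one is shared, so an append must first copy it and record the src -> dst pair. A toy, self-contained sketch of that decision follows; the dict-based refcounter and id counter are stand-ins, not the real RefCounter/BlockAllocator interfaces:

# Toy copy-on-write decision mirroring cow_block_if_not_appendable above.
from collections import defaultdict
from typing import Dict, List

refcounts: Dict[int, int] = {0: 1, 1: 3}                  # block_id -> refcount
copy_on_writes: Dict[int, List[int]] = defaultdict(list)  # src -> [dst, ...]
next_free_id = 2


def cow_if_not_appendable(block_id: int) -> int:
    """Return a block id that is safe to append to."""
    global next_free_id
    if refcounts[block_id] <= 1:
        return block_id                          # exclusive owner: append in place
    refcounts[block_id] -= 1                     # drop our reference to the shared block
    new_id, next_free_id = next_free_id, next_free_id + 1
    refcounts[new_id] = 1
    copy_on_writes[block_id].append(new_id)      # schedule the src -> dst copy
    return new_id


print(cow_if_not_appendable(0))   # 0  (refcount 1: reused)
print(cow_if_not_appendable(1))   # 2  (refcount 3: copied to a fresh block)
print(dict(copy_on_writes))       # {1: [2]}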

+ 22 - 11
aphrodite/processing/block/cpu_gpu_block_allocator.py

@@ -1,24 +1,21 @@
 from typing import Dict, List, Optional
 
-from aphrodite.processing.block.interfaces import (
-    Block,
-    BlockAllocator,
-    DeviceAwareBlockAllocator,
-)
-from aphrodite.processing.block.naive_block import (
-    NaiveBlock,
-    NaiveBlockAllocator,
-)
-from aphrodite.processing.block.prefix_caching_block import (
-    PrefixCachingBlockAllocator, )
 from aphrodite.common.utils import Device
+from aphrodite.processing.block.interfaces import (Block, BlockAllocator,
+                                                   DeviceAwareBlockAllocator)
+from aphrodite.processing.block.naive_block import (NaiveBlock,
+                                                    NaiveBlockAllocator)
+from aphrodite.processing.block.prefix_caching_block import \
+    PrefixCachingBlockAllocator
 
 
 class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     """A block allocator that can allocate blocks on both CPU and GPU memory.
+
     This class implements the `DeviceAwareBlockAllocator` interface and provides
     functionality for allocating and managing blocks of memory on both CPU and
     GPU devices.
+
     The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
     blocks, and allows for allocation, deallocation, forking, and swapping of
     blocks across these memory pools.
@@ -33,10 +30,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     ) -> DeviceAwareBlockAllocator:
         """Creates a CpuGpuBlockAllocator instance with the specified
         configuration.
+
         This static method creates and returns a CpuGpuBlockAllocator instance
         based on the provided parameters. It initializes the CPU and GPU block
         allocators with the specified number of blocks, block size, and
         allocator type.
+
         Args:
             allocator_type (str): The type of block allocator to use for CPU
                 and GPU blocks. Currently supported values are "naive" and
@@ -46,9 +45,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
             num_cpu_blocks (int): The number of blocks to allocate for CPU
                 memory.
             block_size (int): The size of each block in number of tokens.
+
         Returns:
             DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
                 specified configuration.
+
         Notes:
             - The block IDs are assigned contiguously, with GPU block IDs coming
                 before CPU block IDs.
@@ -114,10 +115,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     def allocate_mutable(self, prev_block: Optional[Block],
                          device: Device) -> Block:
         """Allocates a new mutable block on the specified device.
+
         Args:
             prev_block (Optional[Block]): The previous block in the sequence.
                 Used for prefix hashing.
             device (Device): The device on which to allocate the new block.
+
         Returns:
             Block: The newly allocated mutable block.
         """
@@ -127,12 +130,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
                            token_ids: List[int], device: Device) -> Block:
         """Allocates a new immutable block with the provided token IDs on the
         specified device.
+
         Args:
             prev_block (Optional[Block]): The previous block in the sequence.
                 Used for prefix hashing.
             token_ids (List[int]): The list of token IDs to be stored in the new
                 block.
             device (Device): The device on which to allocate the new block.
+
         Returns:
             Block: The newly allocated immutable block containing the provided
                 token IDs.
@@ -142,6 +147,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
 
     def free(self, block: Block) -> None:
         """Frees the memory occupied by the given block.
+
         Args:
             block (Block): The block to be freed.
         """
@@ -151,8 +157,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     def fork(self, last_block: Block) -> List[Block]:
         """Creates a new sequence of blocks that shares the same underlying
             memory as the original sequence.
+
         Args:
             last_block (Block): The last block in the original sequence.
+
         Returns:
             List[Block]: A new list of blocks that shares the same memory as the
                 original sequence.
@@ -162,9 +170,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
 
     def get_num_free_blocks(self, device: Device) -> int:
         """Returns the number of free blocks available on the specified device.
+
         Args:
             device (Device): The device for which to query the number of free
                 blocks.
+
         Returns:
             int: The number of free blocks available on the specified device.
         """
@@ -173,6 +183,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
         """Clears the copy-on-write (CoW) state and returns the mapping of
             source to destination block IDs.
+
         Returns:
             Dict[int, List[int]]: A dictionary mapping source block IDs to lists
                 of destination block IDs.

Some files were not shown because too many files were changed