
[v0.5.3] Release Candidate (#388)

* add new logits processor

* clean up sampler

* stop workflows on dev

* pipe in logitproc in lora

* compute logits in model_runner

* simplify sampler in llama

* support command-r+ model

Co-authored-by: o_0 <pkircher82@gmail.com>

* logitproc for cohere

* add some imports

* add logit scale for command-r

* fix gptq for cohere

* conflicts with _is_neuron()

* fix query shape in logits processor

* fix pydantic serializer warning

* add warning for mismatch in vocab size

* fix quants for gemma

* do not remove duplicate params for qwen2

* refactor: neuron support

* fix and re-enable custom all-reduce

* add scheduler delay factor

* add reorder scheduler policy

* fix logprobs serializer warnings

* improve detokenization performance; improve logprobs

* LockFile -> SoftLockFile

* fix tied embeddings in falcon

* fix query shape in moe models

* rope: get_device() -> device

* fix formatting

* feat: attention refactor part 2

* min_tokens

* add logprob ranks

* don't output two stop strings in api

* vision model support

* yapf

* optimize logprob ranks

* add stop_reason

* ipv6 fix

* KVCache type in llava

* add python nccl wrapper, remove cupy

* ruff

* support gemma 1.1 models with approximate gelu

* fix nccl path

* add dbrx support

* formatting

* add v2 block manager

* ruff

* add qwen2moe support (needs transformers git)

* Chunked Prefill Part 2: data update

* optional vision language config for neuron

* enable multi-node inference

* add exllamav2 tensor parallelism, fused MoE for GPTQ/AWQ

* formatting and typing

* this is kinda dumb if you ask me

* logprobs fixes

* do not use codespell on kernels

* fix codespell path

* rccl path for ROCm

* fix case when API request to top_k is 0

* cache tokenizer len

* small fixes (#393)

* Pin torch to 2.2.0

* Improve nccl errors.

* directly use  in forward pass

* CMake build system (#395)

* add cmake

* update setup.py

* clean up + pin versions

* add hadamard kernels for compilation

* formatting

* fix requirement

* restore nvcc

* dev libraries for cuda

* Fix cohere for command-r+ (#394)

* Fix cohere for command-r+

* fix smoothquant+

---------

Co-authored-by: AlpinDale <alpindale@gmail.com>

* proper typing

* fix build for 7.5

* feat: add optimized layernorm kernels (#398)

* simplify tokenizer.py

* Speculative Decoding Part 4: Lookahead scheduling (#402)

Co-authored-by: Cade Daniel <edacih@gmail.com>

* feat: Intel CPU support (#403)

* add CPU types

Co-authored-by: bigPYJ1151 <jiang1.li@intel.com>

* port all relevant kernels to CPU

Co-authored-by: bigPYJ1151 <jiang1.li@intel.com>

* make cpu build work

* add SDPA backend

* working CPU backend

* make this work on the API

* remove unnecessary file

---------

Co-authored-by: bigPYJ1151 <jiang1.li@intel.com>

* fix minor cuda version mismatch with runtime

* fix spec_decode and block imports

* add speculative config and arg for later

* better recognize cpu build

* add dict merging util

* refactor scheduler for chunked prefill, remove reorder policy for now

* feat: FP8 E4M3 KV Cache (#405)

* Add FP8 E4M3 kernels for AMD GPUs

Courtesy of the AMD team.
Co-authored-by: Adrian Abeyta <adabeyta@amd.com>

* refactor kv cache support in attention kernel

* ops and pybind

* fix compilation errors

* update cmake

* support fp8 e4m3 in the engine

* fix AMD build

* do not compile exl2 for amd (for now)

* amd compatibility fixes for align block size

* maybe fix layernorm compiles

* fix moe softmax kernel

* fix multi-gpu ray tokenizer for trust_remote_code

* enable hf_transfer if installed

* fix CPU build

* make detokenization optional

* update torch to 2.2.1

* make nccl wrapper more robust

* split requirements

* add sampling param for left-truncating prompt tokens

* fix: Improve cohere model. (#404)

* feat: add chunked prefill scheduler (#406)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>

* allow out-of-tree model registry

* use the get_len() method instead of manual len calculation

* fix TP for llava

* enable attention bias support in llama

* disable new layernorm kernels for CUDA < 12.0

* separate init_distributed_environment from worker

* feat: Triton flash attention backend for ROCm (#407)

* refactor executor classes and workers

* KeyError for GPT-NeoX

* add triton flash-attention kernel for ROCm

* fix types in merge_dict

* skip rows in logits added for the prompt tokens

* head_size 256 for gemma in triton FA

* no key sorting for outlines

* fix case where hf_config==None

* move megatron to a top-level directory

* fix micromamba url

* fix outlines requirements

* why was that not committed?

* fully working chunked prefill

* fix docstrings

* roll back chunked prefill changes to SDPA, isolate cpu worker

* make init_distributed_environment compatible with init_process_group

* fix echo

* fix formatting for previous commit

* fix stop strings not being excluded from outputs

* move merge_async_iterators to common utils

* fix type hint

* enable custom_all_reduce by default in llm.py

* triton compile error for flash_attn

* debug logging for distributed_init_method

* incorrect use of monotonic time in metrics logger

* fix neuron

* cache the p2p access check for memory saving

* feat: EETQ quantization (#408)

* Support arbitrary model in GGUF. (#381)

* Support arbitrary model in GGUF.

* Update gguf to support cohere and dbrx.

* yapf

* formatting

* fix: Allow setting config-path when converting ggufs. (#410)

* incorrect comparison for hadamard and punica checks

* fix: max_num_batched_tokens for chunked_prefill (#412)

* Fix max_num_batched_tokens for chunked_prefill.

* Remove accelerate as dependency.

* add args to benchmark script

* change chunk size to 768 default

---------

Co-authored-by: AlpinDale <alpindale@gmail.com>

* feat: support twe lm_head for quantized weights (#409)

* Support twe lm_head for quantized weights.

* Remove unused class.

* separate api server args into another file

* add CLI app

* fix: split the exl2 weight loading and SQ+ init (#423)

* Split the exl2 weight ASAP.

* Shard exl2 weights on cpu. reducing vram bubbles.

* feat: support sharded ggufs (#420)

* Support sharded ggufs.

* Fix .gguf ext detection.

* log tokenizer conversion message once

---------

Co-authored-by: AlpinDale <alpindale@gmail.com>

* chore: port sampler+metadata changes from main to dev (#427)

* Port sampler/metadata changes over from main

* Update old logitprocs to support serial processing

* fix fine-grained seeding. Again.

* formatting

---------

Co-authored-by: AlpinDale <alpindale@gmail.com>

* suppress import error for eetq

* fix CPU blocks logger for CPU backend

* simplify model_executor logic

* fix the nsight profiling with ray

* fix engine_use_ray=True

* feat: LM Format Enforcer support (#428)

* feat: add lm-format-enforcer support for guided decoding

Co-authored-by: Noam Gat <noamgat@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>

* clean up

---------

Co-authored-by: Noam Gat <noamgat@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>

* fix: abort requests when the connection to /v1/completions is interrupted (#431)

* fix: linear bias of qkv layers in models (#430)

* feat: Speculative Decoding using a draft model (#432)

* feat: support speculative decoding using draft model

* benchmark script update; add attribution

Co-authored-by: Cade Daniel <edacih@gmail.com>

---------

Co-authored-by: Cade Daniel <edacih@gmail.com>

* feat: add ngram prompt lookup decoding for speculative decoding (#438)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>

* fix: rope scaling for cohere and qwen (#436)

* Fix incorrect auto RoPE of cohere.

* Remove vocab padding from cohere and qwen.

---------

Co-authored-by: AlpinDale <alpindale@gmail.com>

* fix: shard exl2 weights more evenly between ranks (#437)

Co-authored-by: AlpinDale <alpindale@gmail.com>

* chore: update Kobold Lite Embed (#433)

* Add files via upload

* Update api_server.py

Update Kobold Version

* fix: logging in the API server

* fix: options requests in the api (#439)

* fix: cpu executor

* fix serial impl of bias logproc

* tiny fixes (#441)

* Temporary fix of chat api.

* Fix indentation of bnb to be within _set_default_torch_dtype.

* Revert the is_neox_style change.

* Better handle tokenizer vocab size.

* vllm #4270

* vllm #4280

* fix: use BiasLogitsProcessor for OpenAI endpoint (#443)

Sanitize the bias dict token range according to the active model.

* fix: llama3 generations with EOS

* fix: rope_style in quants

* fix: rope style in eetq

* fix: incorrect module in quip

* fix: restore backwards compatibility with sm_60 (P100 and GP100) (#444)

* fix: compatibility with sm_60 arch

* fix min requirement for gguf

* fix: revision arg not being correctly passed (#456)

* chore: Fix minor bugs in outlines and lmfe. (#449)

* fix: lora errors (#462)

* fix: possibly unbreak loras

* fix lora logit processor

* increase support vocab size for lora to 128k

* 15k dim

* 43k dim

* compile less dtypes for punica

* fix compile error

* fix lm_head in lora

* Refactor: Quantization (#454)

* split quant ops

* isolate the quantization modeling code

* properly set the _quant_C module

* import error fixes

* raise importerror only when the specific quant is called

* missed one

* roll back cc for hadamard

* address comments

* raise error in __init__ class

* missed it in gptq

* ruff

* Bump `torch` to 2.3.0 (#467)

* bump to torch 2.3.0

* fix triton

* fix: Navi support (#466)

Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>

* Dockerfile: permission update, configurable build jobs, torch 2.3.0 (#465)

* topk as linear write

* better variable naming

* yapf considers this space to be CRITICAL

* Overhauled SamplingTensors construction.
Fixed multiple bugs in sampler flags.
Restored the functioning logitproc format.
Fixed sudden NaNs in quadratic smoothing.
Rewrote mirostat to work with seeds and other samplers.
Removed branches from some samplers.

* Fix logitproc for logit_bias in OAI endpoints.

* merge main

* fix memory pinning conditional

* Missed .items() and assert

* fix: kobold api /tokencount (#424)

* /tokencount uses mismatched value types

* reduce output to only 'value'

---------

Co-authored-by: Adrian Wells <neoturi@hotmail.com>

* fixes #458

* #453, #458

* revert irrelevant changes

* properly this time

---------

Co-authored-by: 50h100a <all_awful@proton.me>
Co-authored-by: 50h100a <136940546+50h100a@users.noreply.github.com>
Co-authored-by: Krovius <neoturi@gmail.com>
Co-authored-by: Adrian Wells <neoturi@hotmail.com>
Co-authored-by: AlpinDale <alpindale@gmail.com>

* ruff

* ruff again

* codespell

* yapf

* add sm_60 to build ci

* restore punica dtypes

* compile extra dtypes for punica

* bump torch in environment.yaml

* update ROCm dockerfile

* add TODO for build script

* bump version to 0.5.3

---------

Co-authored-by: o_0 <pkircher82@gmail.com>
Co-authored-by: sgsdxzy <sgsdxzy@gmail.com>
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: bigPYJ1151 <jiang1.li@intel.com>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: 50h100a <136940546+50h100a@users.noreply.github.com>
Co-authored-by: Noam Gat <noamgat@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Pyroserenus <142424797+Pyroserenus@users.noreply.github.com>
Co-authored-by: 50h100a <all_awful@proton.me>
Co-authored-by: Houman <houmie@gmail.com>
Co-authored-by: Naomiusearch <121130001+Naomiusearch@users.noreply.github.com>
Co-authored-by: The Objective Dad <63609026+theobjectivedad@users.noreply.github.com>
Co-authored-by: Krovius <neoturi@gmail.com>
Co-authored-by: Adrian Wells <neoturi@hotmail.com>
AlpinDale authored 10 months ago
commit 9d81716bfd
100 files changed, 8899 insertions(+), 3355 deletions(-)
  1. +1 -1      .github/workflows/publish.yml
  2. +0 -2      .github/workflows/ruff.yml
  3. +3 -2      .github/workflows/scripts/build.sh
  4. +0 -2      .github/workflows/yapf.yml
  5. +395 -0    CMakeLists.txt
  6. +2 -1      MANIFEST.in
  7. +3 -1      aphrodite/__init__.py
  8. +15 -0     aphrodite/attention/__init__.py
  9. +0 -0      aphrodite/attention/backends/__init__.py
 10. +122 -0    aphrodite/attention/backends/abstract.py
 11. +266 -0    aphrodite/attention/backends/flash_attn.py
 12. +373 -0    aphrodite/attention/backends/rocm_flash_attn.py
 13. +254 -0    aphrodite/attention/backends/sdpa.py
 14. +396 -0    aphrodite/attention/backends/xformers.py
 15. +47 -0     aphrodite/attention/layer.py
 16. +0 -0      aphrodite/attention/ops/__init__.py
 17. +216 -0    aphrodite/attention/ops/paged_attn.py
 18. +0 -0      aphrodite/attention/ops/prefix_prefill.py
 19. +805 -0    aphrodite/attention/ops/triton_flash_attn.py
 20. +74 -0     aphrodite/attention/selector.py
 21. +429 -60   aphrodite/common/config.py
 22. +14 -0     aphrodite/common/logger.py
 23. +31 -20    aphrodite/common/logits_processor.py
 24. +17 -11    aphrodite/common/outputs.py
 25. +59 -2     aphrodite/common/sampling_params.py
 26. +150 -3    aphrodite/common/sequence.py
 27. +14 -9     aphrodite/common/test_utils.py
 28. +171 -29   aphrodite/common/utils.py
 29. +3 -0      aphrodite/distributed/__init__.py
 30. +17 -9     aphrodite/distributed/communication_op.py
 31. +0 -0      aphrodite/distributed/device_communicators/__init__.py
 32. +36 -11    aphrodite/distributed/device_communicators/custom_all_reduce.py
 33. +269 -0    aphrodite/distributed/device_communicators/pynccl.py
 34. +68 -0     aphrodite/distributed/device_communicators/pynccl_utils.py
 35. +91 -22    aphrodite/distributed/parallel_state.py
 36. +131 -0    aphrodite/distributed/utils.py
 37. +31 -0     aphrodite/endpoints/cli.py
 38. +100 -33   aphrodite/endpoints/kobold/klite.embd
 39. +21 -5     aphrodite/endpoints/llm.py
 40. +165 -257  aphrodite/endpoints/openai/api_server.py
 41. +122 -0    aphrodite/endpoints/openai/args.py
 42. +44 -6     aphrodite/endpoints/openai/protocol.py
 43. +40 -30    aphrodite/endpoints/openai/serving_chat.py
 44. +62 -91    aphrodite/endpoints/openai/serving_completions.py
 45. +71 -42    aphrodite/endpoints/openai/serving_engine.py
 46. +196 -412  aphrodite/engine/aphrodite_engine.py
 47. +188 -41   aphrodite/engine/args_tools.py
 48. +36 -20    aphrodite/engine/async_aphrodite.py
 49. +27 -12    aphrodite/engine/metrics.py
 50. +0 -0      aphrodite/engine/output_processor/__init__.py
 51. +68 -0     aphrodite/engine/output_processor/interfaces.py
 52. +122 -0    aphrodite/engine/output_processor/multi_step.py
 53. +272 -0    aphrodite/engine/output_processor/single_step.py
 54. +99 -0     aphrodite/engine/output_processor/stop_checker.py
 55. +16 -0     aphrodite/engine/output_processor/util.py
 56. +162 -0    aphrodite/executor/cpu_executor.py
 57. +47 -15    aphrodite/executor/executor_base.py
 58. +74 -91    aphrodite/executor/gpu_executor.py
 59. +27 -33    aphrodite/executor/neuron_executor.py
 60. +70 -104   aphrodite/executor/ray_gpu_executor.py
 61. +0 -18     aphrodite/executor/utils.py
 62. +180 -54   aphrodite/lora/layers.py
 63. +3 -2      aphrodite/lora/lora.py
 64. +79 -58    aphrodite/lora/models.py
 65. +4 -0      aphrodite/lora/punica.py
 66. +3 -4      aphrodite/lora/utils.py
 67. +16 -11    aphrodite/lora/worker_manager.py
 68. +0 -4      aphrodite/modeling/__init__.py
 69. +25 -0     aphrodite/modeling/guided_decoding/__init__.py
 70. +70 -0     aphrodite/modeling/guided_decoding/lm_format_enforcer_decoding.py
 71. +70 -0     aphrodite/modeling/guided_decoding/lm_format_enforcer_logits_processors.py
 72. +3 -3      aphrodite/modeling/guided_decoding/outlines_decoding.py
 73. +58 -45    aphrodite/modeling/guided_decoding/outlines_logits_processors.py
 74. +153 -92   aphrodite/modeling/hf_downloader.py
 75. +6 -4      aphrodite/modeling/layers/activation.py
 76. +0 -93     aphrodite/modeling/layers/attention/__init__.py
 77. +0 -138    aphrodite/modeling/layers/attention/backends/flash_attn.py
 78. +0 -315    aphrodite/modeling/layers/attention/backends/xformers.py
 79. +0 -139    aphrodite/modeling/layers/attention/ops/paged_attn.py
 80. +3 -5      aphrodite/modeling/layers/fused_moe/__init__.py
 81. +49 -33    aphrodite/modeling/layers/fused_moe/fused_moe.py
 82. +190 -27   aphrodite/modeling/layers/linear.py
 83. +113 -0    aphrodite/modeling/layers/logits_processor.py
 84. +0 -36     aphrodite/modeling/layers/quantization/__init__.py
 85. +19 -22    aphrodite/modeling/layers/rejection.py
 86. +2 -2      aphrodite/modeling/layers/rotary_embedding.py
 87. +318 -247  aphrodite/modeling/layers/sampler.py
 88. +41 -11    aphrodite/modeling/layers/vocab_parallel_embedding.py
 89. +16 -14    aphrodite/modeling/loader.py
 90. +0 -1      aphrodite/modeling/megatron/README.md
 91. +0 -127    aphrodite/modeling/megatron/cupy_utils.py
 92. +0 -49     aphrodite/modeling/megatron/utils.py
 93. +0 -97     aphrodite/modeling/metadata.py
 94. +22 -15    aphrodite/modeling/models/__init__.py
 95. +27 -29    aphrodite/modeling/models/baichuan.py
 96. +30 -32    aphrodite/modeling/models/bloom.py
 97. +47 -31    aphrodite/modeling/models/chatglm.py
 98. +181 -71   aphrodite/modeling/models/cohere.py
 99. +423 -0    aphrodite/modeling/models/dbrx.py
100. +216 -149  aphrodite/modeling/models/deepseek.py

+ 1 - 1
.github/workflows/publish.yml

@@ -49,7 +49,7 @@ jobs:
       matrix:
           os: ['ubuntu-20.04']
           python-version: ['3.8', '3.9', '3.10', '3.11']
-          pytorch-version: ['2.2.0']  # Must be the most recent version that meets requirements.txt.
+          pytorch-version: ['2.3.0']  # Must be the most recent version that meets requirements-cuda.txt.
           cuda-version: ['11.8', '12.1']
 
     steps:

+ 0 - 2
.github/workflows/ruff.yml

@@ -6,11 +6,9 @@ on:
   push:
     branches:
       - main
-      - dev
   pull_request:
     branches:
       - main
-      - dev
 
 jobs:
   ruff:

+ 3 - 2
.github/workflows/scripts/build.sh

@@ -8,12 +8,13 @@ PATH=${cuda_home}/bin:$PATH
 LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
 
 # Limit Ninja's number of processes to 1 to prevent OOM
+# TODO: Can we do 2 jobs?
 export MAX_JOBS=1
 
-export TORCH_CUDA_ARCH_LIST="6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
+export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
 # Install requirements
 $python_executable -m pip install wheel packaging
-$python_executable -m pip install -r requirements.txt
+$python_executable -m pip install -r requirements-cuda.txt
 
 # Build
 $python_executable setup.py bdist_wheel --dist-dir=dist

+ 0 - 2
.github/workflows/yapf.yml

@@ -6,11 +6,9 @@ on:
   push:
     branches:
       - main
-      - dev
   pull_request:
     branches:
       - main
-      - dev
 jobs:
   yapf:
     runs-on: ubuntu-latest

+ 395 - 0
CMakeLists.txt

@@ -0,0 +1,395 @@
+cmake_minimum_required(VERSION 3.21)
+
+project(aphrodite_extensions LANGUAGES CXX)
+
+option(APHRODITE_TARGET_DEVICE "Target device backend for Aphrodite" "cuda")
+
+message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
+message(STATUS "Target device: ${APHRODITE_TARGET_DEVICE}")
+
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
+
+#
+# Supported python versions.  These versions will be searched in order, the
+# first match will be selected.  These should be kept in sync with setup.py.
+#
+set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
+
+# Supported NVIDIA architectures.
+set(CUDA_SUPPORTED_ARCHS "6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0")
+
+# Supported AMD GPU architectures.
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
+
+#
+# Supported/expected torch versions for CUDA/ROCm.
+#
+# Currently, having an incorrect pytorch version results in a warning
+# rather than an error.
+#
+# Note: the CUDA torch version is derived from pyproject.toml and various
+# requirements.txt files and should be kept consistent.  The ROCm torch
+# versions are derived from Dockerfile.rocm
+#
+set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
+set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
+set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
+
+#
+# Try to find python package with an executable that exactly matches
+# `APHRODITE_PYTHON_EXECUTABLE` and is one of the supported versions.
+#
+if (APHRODITE_PYTHON_EXECUTABLE)
+  find_python_from_executable(${APHRODITE_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}")
+else()
+  message(FATAL_ERROR
+    "Please set APHRODITE_PYTHON_EXECUTABLE to the path of the desired python version"
+    " before running cmake configure.")
+endif()
+
+#
+# Update cmake's `CMAKE_PREFIX_PATH` with torch location.
+#
+append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
+
+# Ensure the 'nvcc' command is in the PATH
+find_program(NVCC_EXECUTABLE nvcc)
+if (CUDA_FOUND AND NOT NVCC_EXECUTABLE)
+    message(FATAL_ERROR "nvcc not found")
+endif()
+
+#
+# Import torch cmake configuration.
+# Torch also imports CUDA (and partially HIP) languages with some customizations,
+# so there is no need to do this explicitly with check_language/enable_language,
+# etc.
+#
+find_package(Torch REQUIRED)
+
+#
+# Normally `torch.utils.cpp_extension.CUDAExtension` would add
+# `libtorch_python.so` for linking against an extension. Torch's cmake
+# configuration does not include this library (presumably since the cmake
+# config is used for standalone C++ binaries that link against torch).
+# The `libtorch_python.so` library defines some of the glue code between
+# torch/python via pybind and is required by APHRODITE extensions for this
+# reason. So, add it by manually using `append_torchlib_if_found` from
+# torch's cmake setup.
+#
+find_library(torch_python_LIBRARY torch_python PATHS
+  "${TORCH_INSTALL_PREFIX}/lib")
+
+#
+# Forward the non-CUDA device extensions to external CMake scripts.
+#
+if (NOT APHRODITE_TARGET_DEVICE STREQUAL "cuda" AND
+    NOT APHRODITE_TARGET_DEVICE STREQUAL "rocm")
+    if (APHRODITE_TARGET_DEVICE STREQUAL "cpu")
+        include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
+    else()
+        message(FATAL_ERROR "Unsupported Aphrodite target device: ${APHRODITE_TARGET_DEVICE}")
+    endif()
+    return()
+endif()
+
+#
+# Set up GPU language and check the torch version and warn if it isn't
+# what is expected.
+#
+if (NOT HIP_FOUND AND CUDA_FOUND)
+  set(APHRODITE_GPU_LANG "CUDA")
+
+  if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA})
+    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
+      "expected for CUDA build, saw ${Torch_VERSION} instead.")
+  endif()
+elseif(HIP_FOUND)
+  set(APHRODITE_GPU_LANG "HIP")
+
+  # Importing torch recognizes and sets up some HIP/ROCm configuration but does
+  # not let cmake recognize .hip files. In order to get cmake to understand the
+  # .hip extension automatically, HIP must be enabled explicitly.
+  enable_language(HIP)
+
+  # ROCm 5.x
+  if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
+      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
+    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
+      "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
+  endif()
+
+  # ROCm 6.x
+  if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
+      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
+    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
+      "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
+  endif()
+else()
+  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
+endif()
+
+#
+# Override the GPU architectures detected by cmake/torch and filter them by
+# the supported versions for the current language.
+# The final set of arches is stored in `APHRODITE_GPU_ARCHES`.
+#
+override_gpu_arches(APHRODITE_GPU_ARCHES
+  ${APHRODITE_GPU_LANG}
+  "${${APHRODITE_GPU_LANG}_SUPPORTED_ARCHS}")
+
+#
+# Query torch for additional GPU compilation flags for the given
+# `APHRODITE_GPU_LANG`.
+# The final set of arches is stored in `APHRODITE_GPU_FLAGS`.
+#
+get_torch_gpu_compiler_flags(APHRODITE_GPU_FLAGS ${APHRODITE_GPU_LANG})
+
+#
+# Set nvcc parallelism.
+#
+if(NVCC_THREADS AND APHRODITE_GPU_LANG STREQUAL "CUDA")
+  list(APPEND APHRODITE_GPU_FLAGS "--threads=${NVCC_THREADS}")
+endif()
+
+#
+# Define extension targets
+#
+
+#
+# _C extension
+#
+
+set(APHRODITE_EXT_SRC
+  "kernels/cache_kernels.cu"
+  "kernels/attention/attention_kernels.cu"
+  "kernels/pos_encoding_kernels.cu"
+  "kernels/activation_kernels.cu"
+  "kernels/layernorm_kernels.cu"
+  "kernels/cuda_utils_kernels.cu"
+  "kernels/moe/align_block_size_kernel.cu"
+  "kernels/pybind.cpp")
+
+if(APHRODITE_GPU_LANG STREQUAL "CUDA")
+  list(APPEND APHRODITE_EXT_SRC
+    "kernels/all_reduce/custom_all_reduce.cu")
+endif()
+
+define_gpu_extension_target(
+  _C
+  DESTINATION aphrodite
+  LANGUAGE ${APHRODITE_GPU_LANG}
+  SOURCES ${APHRODITE_EXT_SRC}
+  COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
+  ARCHITECTURES ${APHRODITE_GPU_ARCHES}
+  WITH_SOABI)
+
+
+#
+# _quant_C extension
+#
+
+set (APHRODITE_QUANT_EXT_SRC
+  "kernels/quantization/gptq/q_gemm.cu"
+  "kernels/quantization/squeezellm/quant_cuda_kernel.cu"
+  "kernels/quantization/exl2/q_matrix.cu"
+  "kernels/quantization/exl2/q_gemm_exl2.cu"
+  "kernels/quantization/quant_ops.cpp")
+
+if(APHRODITE_GPU_LANG STREQUAL "CUDA")
+  list(APPEND APHRODITE_QUANT_EXT_SRC
+    "kernels/quantization/aqlm/aqlm_cuda_entry.cpp"
+    "kernels/quantization/aqlm/aqlm_cuda_kernel.cu"
+    "kernels/quantization/awq/gemm_kernels.cu"
+    "kernels/quantization/bitsandbytes/int4_fp16_gemm_kernels.cu"
+    "kernels/quantization/bitsandbytes/format.cu"
+    "kernels/quantization/bitsandbytes/gemm_s4_f16.cu"
+    "kernels/quantization/gguf/gguf_kernel.cu"
+    "kernels/quantization/marlin/marlin_cuda_kernel.cu"
+    "kernels/quantization/quip/origin_order.cu")
+endif()
+
+define_gpu_extension_target(
+  _quant_C
+  DESTINATION aphrodite
+  LANGUAGE ${APHRODITE_GPU_LANG}
+  SOURCES ${APHRODITE_QUANT_EXT_SRC}
+  COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
+  ARCHITECTURES ${APHRODITE_GPU_ARCHES}
+  WITH_SOABI)
+
+#
+# _moe_C extension
+#
+
+set(APHRODITE_MOE_EXT_SRC
+  "kernels/moe/moe_ops.cpp"
+  "kernels/moe/softmax.cu")
+
+define_gpu_extension_target(
+  _moe_C
+  DESTINATION aphrodite
+  LANGUAGE ${APHRODITE_GPU_LANG}
+  SOURCES ${APHRODITE_MOE_EXT_SRC}
+  COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
+  ARCHITECTURES ${APHRODITE_GPU_ARCHES}
+  WITH_SOABI)
+
+#
+# _punica_C extension
+#
+
+set(APHRODITE_PUNICA_EXT_SRC
+  "kernels/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
+  "kernels/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
+  "kernels/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
+  "kernels/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
+  "kernels/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
+  "kernels/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
+  "kernels/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
+  "kernels/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
+  "kernels/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
+  "kernels/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
+  "kernels/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
+  "kernels/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
+  "kernels/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
+  "kernels/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
+  "kernels/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
+  "kernels/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
+  "kernels/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
+  "kernels/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
+  "kernels/punica/punica_ops.cc")
+
+#
+# Copy GPU compilation flags+update for punica
+#
+set(APHRODITE_PUNICA_GPU_FLAGS ${APHRODITE_GPU_FLAGS})
+list(REMOVE_ITEM APHRODITE_PUNICA_GPU_FLAGS
+  "-D__CUDA_NO_HALF_OPERATORS__"
+  "-D__CUDA_NO_HALF_CONVERSIONS__"
+  "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
+  "-D__CUDA_NO_HALF2_OPERATORS__")
+
+#
+# Filter out CUDA architectures < 8.0 for punica.
+#
+if (${APHRODITE_GPU_LANG} STREQUAL "CUDA")
+  set(APHRODITE_PUNICA_GPU_ARCHES)
+  foreach(ARCH ${APHRODITE_GPU_ARCHES})
+    string_to_ver(CODE_VER ${ARCH})
+    if (CODE_VER GREATER_EQUAL 8.0)
+      list(APPEND APHRODITE_PUNICA_GPU_ARCHES ${ARCH})
+    endif()
+  endforeach()
+  message(STATUS "Punica target arches: ${APHRODITE_PUNICA_GPU_ARCHES}")
+endif()
+
+if (APHRODITE_PUNICA_GPU_ARCHES)
+  define_gpu_extension_target(
+    _punica_C
+    DESTINATION aphrodite
+    LANGUAGE ${APHRODITE_GPU_LANG}
+    SOURCES ${APHRODITE_PUNICA_EXT_SRC}
+    COMPILE_FLAGS ${APHRODITE_PUNICA_GPU_FLAGS}
+    ARCHITECTURES ${APHRODITE_PUNICA_GPU_ARCHES}
+    WITH_SOABI)
+else()
+  message(WARNING "Unable to create _punica_C target because none of the "
+    "requested architectures (${APHRODITE_GPU_ARCHES}) are supported, i.e. >= 8.0")
+endif()
+
+#
+# _hadamard_C extension
+#
+
+set(APHRODITE_HADAMARD_EXT_SRC
+  "kernels/hadamard/fast_hadamard_transform.cpp"
+  "kernels/hadamard/fast_hadamard_transform_cuda.cu")
+
+#
+# Copy GPU compilation flags+update for hadamard
+#
+set(APHRODITE_HADAMARD_GPU_FLAGS ${APHRODITE_GPU_FLAGS})
+list(APPEND APHRODITE_HADAMARD_GPU_FLAGS
+  "-U__CUDA_NO_HALF_OPERATORS__"
+  "-U__CUDA_NO_HALF_CONVERSIONS__"
+  "-U__CUDA_NO_BFLOAT16_OPERATORS__"
+  "-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
+  "-U__CUDA_NO_BFLOAT162_OPERATORS__"
+  "-U__CUDA_NO_BFLOAT162_CONVERSIONS__"
+  "--expt-relaxed-constexpr"
+  "--expt-extended-lambda"
+  "--use_fast_math"
+  "--ptxas-options=-v"
+  "-lineinfo")
+
+#
+# Filter out CUDA architectures < 6.0 for hadamard.
+#
+if (${APHRODITE_GPU_LANG} STREQUAL "CUDA")
+  set(APHRODITE_HADAMARD_GPU_ARCHES)
+  foreach(ARCH ${APHRODITE_GPU_ARCHES})
+    string_to_ver(CODE_VER ${ARCH})
+    if (CODE_VER GREATER_EQUAL 6.0)
+      list(APPEND APHRODITE_HADAMARD_GPU_ARCHES ${ARCH})
+    endif()
+  endforeach()
+  message(STATUS "Hadamard target arches: ${APHRODITE_HADAMARD_GPU_ARCHES}")
+endif()
+
+if (APHRODITE_HADAMARD_GPU_ARCHES)
+  define_gpu_extension_target(
+    _hadamard_C
+    DESTINATION aphrodite
+    LANGUAGE ${APHRODITE_GPU_LANG}
+    SOURCES ${APHRODITE_HADAMARD_EXT_SRC}
+    COMPILE_FLAGS ${APHRODITE_HADAMARD_GPU_FLAGS}
+    ARCHITECTURES ${APHRODITE_HADAMARD_GPU_ARCHES}
+    WITH_SOABI)
+else()
+  message(WARNING "Unable to create _hadamard_C target because none of the "
+    "requested architectures (${APHRODITE_GPU_ARCHES}) are supported, i.e. >= 6.0")
+endif()
+
+#
+# Add the `default` target which detects which extensions should be
+# built based on platform/architecture.  This is the same logic that
+# setup.py uses to select which extensions should be built and should
+# be kept in sync.
+#
+# The `default` target makes direct use of cmake easier since knowledge
+# of which extensions are supported has been factored in, e.g.
+#
+# mkdir build && cd build
+# cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
+# cmake --build . --target default
+#
+add_custom_target(default)
+
+if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
+  message(STATUS "Enabling C extension.")
+  add_dependencies(default _C)
+endif()
+
+if(APHRODITE_GPU_LANG STREQUAL "CUDA")
+  message(STATUS "Enabling moe extension.")
+  add_dependencies(default _moe_C)
+
+  # Enable the quant, punica, and hadamard extensions if their respective
+  # -DAPHRODITE_INSTALL_*_KERNELS=ON option (or environment variable) is set
+  # and, for punica/hadamard, there are supported target arches.
+  if (APHRODITE_QUANT_EXT_SRC AND
+      (ENV{APHRODITE_INSTALL_QUANT_KERNELS} OR APHRODITE_INSTALL_QUANT_KERNELS))
+    message(STATUS "Enabling quant extension.")
+    add_dependencies(default _quant_C)
+  endif()
+  if (APHRODITE_PUNICA_GPU_ARCHES AND
+      (ENV{APHRODITE_INSTALL_PUNICA_KERNELS} OR APHRODITE_INSTALL_PUNICA_KERNELS))
+    message(STATUS "Enabling punica extension.")
+    add_dependencies(default _punica_C)
+  endif()
+  if (APHRODITE_HADAMARD_GPU_ARCHES AND
+      (ENV{APHRODITE_INSTALL_HADAMARD_KERNELS} OR APHRODITE_INSTALL_HADAMARD_KERNELS))
+    message(STATUS "Enabling hadamard extension.")
+    add_dependencies(default _hadamard_C)
+  endif()
+endif()

+ 2 - 1
MANIFEST.in

@@ -1,5 +1,6 @@
 include LICENSE
-include requirements.txt
+include requirements-common.txt
+include requirements-cuda.txt
 include aphrodite/endpoints/kobold/klite.embd
 
 recursive-include kernels *

+ 3 - 1
aphrodite/__init__.py

@@ -3,13 +3,15 @@ from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.ray_tools import initialize_ray_cluster
 from aphrodite.endpoints.llm import LLM
+from aphrodite.modeling.models import ModelRegistry
 from aphrodite.common.outputs import CompletionOutput, RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 
-__version__ = "0.5.2"
+__version__ = "0.5.3"
 
 __all__ = [
     "LLM",
+    "ModelRegistry",
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",

+ 15 - 0
aphrodite/attention/__init__.py

@@ -0,0 +1,15 @@
+from aphrodite.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionMetadata,
+    AttentionMetadataPerStage,
+)
+from aphrodite.attention.layer import Attention
+from aphrodite.attention.selector import get_attn_backend
+
+__all__ = [
+    "AttentionBackend",
+    "AttentionMetadata",
+    "Attention",
+    "get_attn_backend",
+    "AttentionMetadataPerStage",
+]

+ 0 - 0
aphrodite/modeling/layers/attention/backends/__init__.py → aphrodite/attention/backends/__init__.py


+ 122 - 0
aphrodite/attention/backends/abstract.py

@@ -0,0 +1,122 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, fields
+from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
+
+import torch
+
+
+class AttentionBackend(ABC):
+    """Abstract class for attention backends."""
+
+    @staticmethod
+    @abstractmethod
+    def get_impl_cls() -> Type["AttentionImpl"]:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        raise NotImplementedError
+
+
+@dataclass
+class AttentionMetadataPerStage:
+    """Attention metadata for a specific stage. I.e., prefill or decode."""
+
+    def asdict_zerocopy(self) -> Dict[str, Any]:
+        """Similar to dataclasses.asdict, but avoids deepcopying."""
+        # Note that if we add dataclasses as fields, they will need
+        # similar handling.
+        return {
+            field.name: getattr(self, field.name)
+            for field in fields(self)
+        }
+
+
+T = TypeVar("T", bound=AttentionMetadataPerStage)
+
+
+@dataclass
+class AttentionMetadata(Generic[T]):
+    """Attention metadata for prefill and decode batched together."""
+    # Total number of prefill requests.
+    num_prefills: int
+    # Number of prefill tokens.
+    num_prefill_tokens: int
+    # Number of decode tokens. Note that it is equivalent to the number of
+    # decode requests.
+    num_decode_tokens: int
+    # The attention metadata for prefill requests in a batch.
+    # None if there's no prefill requests in a batch.
+    prefill_metadata: Optional[T]
+    # The attention metadata for decode requests in a batch.
+    # None if there's no decode requests in a batch.
+    decode_metadata: Optional[T]
+    # (num_tokens,). The indices of the token slots that input tokens will be
+    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
+    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
+    # in block 0, and 1st slot in block 1, respectively.
+    slot_mapping: torch.Tensor
+    # The kv cache's data type.
+    kv_cache_dtype: str
+
+    def __post_init__(self):
+        if self.num_prefill_tokens > 0:
+            assert self.num_prefills > 0
+            assert self.prefill_metadata is not None
+        if self.num_decode_tokens > 0:
+            assert self.decode_metadata is not None
+
+
+class AttentionImpl(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+        kv_scale: float,
+    ) -> torch.Tensor:
+        raise NotImplementedError
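
To make the new contract concrete, here is a toy backend wired to the abstract interface above. It is illustrative only and not part of this commit: it ignores the paged KV cache and runs plain scaled-dot-product attention (assuming num_kv_heads == num_heads), just to show which static methods and impl hooks a concrete backend such as the FlashAttention backend below has to provide.

from dataclasses import dataclass
from typing import Dict, List, Tuple, Type

import torch

from aphrodite.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadataPerStage,
)


@dataclass
class ToyAttentionMetadata(AttentionMetadataPerStage):
    """The toy backend needs no extra per-stage state."""


class ToyAttentionImpl(AttentionImpl):

    def __init__(self, num_heads, head_size, scale,
                 num_kv_heads=None, alibi_slopes=None, sliding_window=None):
        # Toy path assumes plain MHA (num_kv_heads == num_heads) and
        # ignores alibi slopes and sliding windows.
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)

    def forward(self, query, key, value, kv_cache, attn_metadata, kv_scale):
        # [num_tokens, num_heads * head_size] -> [num_heads, num_tokens, head_size]
        q = query.view(-1, self.num_heads, self.head_size).transpose(0, 1)
        k = key.view(-1, self.num_heads, self.head_size).transpose(0, 1)
        v = value.view(-1, self.num_heads, self.head_size).transpose(0, 1)
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, scale=self.scale, is_causal=True)
        return out.transpose(0, 1).reshape(query.shape)


class ToyAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["ToyAttentionImpl"]:
        return ToyAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "ToyAttentionMetadata":
        return ToyAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int, head_size: int) -> Tuple[int, ...]:
        # Same flat layout the paged backends use: one plane each for K and V.
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def swap_blocks(src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
                    src_to_dst: Dict[int, int]) -> None:
        for src, dst in src_to_dst.items():
            dst_kv_cache[:, dst].copy_(src_kv_cache[:, src])

    @staticmethod
    def copy_blocks(kv_caches: List[torch.Tensor],
                    src_to_dists: Dict[int, List[int]]) -> None:
        for cache in kv_caches:
            for src, dsts in src_to_dists.items():
                for dst in dsts:
                    cache[:, dst].copy_(cache[:, src])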

+ 266 - 0
aphrodite/attention/backends/flash_attn.py

@@ -0,0 +1,266 @@
+"""Attention layer with Flash and PagedAttention.
+NOTE: At the moment, this file includes a lot of duplicated code from
+XFormers backend. The duplicated code will be removed once we use flash-attn or
+flashinfer for all the attention operations.
+"""
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from flash_attn import flash_attn_varlen_func
+
+from aphrodite.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionImpl,
+    AttentionMetadata,
+    AttentionMetadataPerStage,
+)
+from aphrodite.attention.ops.paged_attn import (
+    PagedAttention,
+    PagedAttentionMetadata,
+)
+
+
+class FlashAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["FlashAttentionImpl"]:
+        return FlashAttentionImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
+        return FlashAttentionMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class FlashAttentionMetadata(AttentionMetadataPerStage,
+                             PagedAttentionMetadata):
+    """Metadata for FlashAttentionBackend.
+    NOTE: Any python object stored here is not updated when it is
+    cuda-graph replayed. If you have values that need to be changed
+    dynamically, it should be stored in tensor. The tensor has to be
+    updated from `CUDAGraphRunner.forward` API.
+    """
+    # Currently, input sequences can only contain all prompts
+    # or all decoding. True if all sequences are prompts.
+    is_prompt: bool
+    # (batch_size,). The prompt length per sequence. None if it is a decoding.
+    prompt_lens: Optional[List[int]]
+    # prompt_lens stored as a tensor.
+    prompt_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE: Definition of context_len, subquery_len, and seqlen.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seqlen ----------------------|
+    #                                   |- subquery_len -|
+
+    # WARNING: context_len has different definition depending on if it is
+    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
+    # When it is for decoding, it includes a new token.
+
+    # Maximum subquery length in the batch.
+    max_subquery_len: Optional[int]
+    # Maximum prompt length in the batch.
+    max_prompt_len: Optional[int]
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    subquery_start_loc: Optional[torch.Tensor]
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+
+class FlashAttentionImpl(AttentionImpl):
+    """
+    If the input tensors contain prompt tokens, the layout is as follows:
+    |<--------------- num_prefill_tokens ----------------->|	
+    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
+    Otherwise, the layout is as follows:	
+    |<----------------- num_decode_tokens ------------------>|	
+    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
+    Generation tokens can contain padding when cuda-graph is used.
+    Currently, prompt tokens don't contain any padding.
+    The prompts might have different lengths, while the generation tokens
+    always have length 1.
+    If chunked prefill is enabled, prefill tokens and decode tokens can be
+    batched together in a flattened 1D query.
+    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
+    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
+    Currently, cuda graph is disabled for chunked prefill, meaning there's no
+    padding between prefill and decode tokens.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = ((sliding_window, sliding_window)
+                               if sliding_window is not None else (-1, -1))
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in suppored_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {suppored_head_sizes}.")
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
+        kv_scale: float,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache,
+                                                attn_metadata.slot_mapping,
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)
+
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+        output = torch.empty_like(query)
+        # Query for decode. KV is not needed because it is already cached.
+        decode_query = query[num_prefill_tokens:]
+        # QKV for prefill.
+        query = query[:num_prefill_tokens]
+        key = key[:num_prefill_tokens]
+        value = value[:num_prefill_tokens]
+
+        assert query.shape[0] == num_prefill_tokens
+        assert decode_query.shape[0] == num_decode_tokens
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            # Prompt run.
+            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                # normal attention
+                # When block_tables are not filled, it means q and k are the
+                # prompt, and they have the same length.
+                out = flash_attn_varlen_func(
+                    q=query,
+                    k=key,
+                    v=value,
+                    cu_seqlens_q=prefill_meta.seq_start_loc,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_q=prefill_meta.max_prompt_len,
+                    max_seqlen_k=prefill_meta.max_prompt_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
+                )
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
+            else:
+                # prefix-enabled attention
+                # TODO: this triton kernel has regression issue (broke) to
+                # deal with different data types between KV and FP8 KV cache,
+                # to be addressed separately.
+                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    prefill_meta.block_tables,
+                    prefill_meta.subquery_start_loc,
+                    prefill_meta.prompt_lens_tensor,
+                    prefill_meta.context_lens,
+                    prefill_meta.max_subquery_len,
+                    self.alibi_slopes,
+                )
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                decode_query,
+                key_cache,
+                value_cache,
+                decode_meta.block_tables,
+                decode_meta.context_lens,
+                decode_meta.max_context_len,
+                attn_metadata.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                kv_scale,
+            )
+
+        # Reshape the output tensor.
+        return output.view(num_tokens, hidden_size)
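
For reference, the flat indices in attn_metadata.slot_mapping consumed by write_to_paged_cache above decompose into (physical block, in-block offset) pairs exactly as described in the AttentionMetadata docstring. A small sketch of that arithmetic follows; the helper name is ours, not part of the commit.

import torch

def slot_to_block_and_offset(slot_mapping: torch.Tensor, block_size: int):
    """Decompose flat KV-cache slot indices into block numbers and offsets."""
    block_numbers = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size
    return block_numbers, block_offsets

# Matches the example from the AttentionMetadata docstring: slots [35, 2, 17]
# with block_size 16 land in block 2 (offset 3), block 0 (offset 2),
# and block 1 (offset 1).
slots = torch.tensor([35, 2, 17])
blocks, offsets = slot_to_block_and_offset(slots, block_size=16)
assert blocks.tolist() == [2, 0, 1] and offsets.tolist() == [3, 2, 1]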

+ 373 - 0
aphrodite/attention/backends/rocm_flash_attn.py

@@ -0,0 +1,373 @@
+"""Attention layer ROCm GPUs."""
+import os
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from loguru import logger
+
+from aphrodite.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionImpl,
+    AttentionMetadata,
+    AttentionMetadataPerStage,
+)
+from aphrodite.attention.ops.paged_attn import (
+    PagedAttention,
+    PagedAttentionMetadata,
+)
+
+
+class ROCmFlashAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
+        return ROCmFlashAttentionImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
+        return ROCmFlashAttentionMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
+                                 PagedAttentionMetadata):
+    """Metadata for FlashAttentionBackend.
+    NOTE: Any python object stored here is not updated when it is
+    cuda-graph replayed. If you have values that need to be changed
+    dynamically, it should be stored in tensor. The tensor has to be
+    updated from `CUDAGraphRunner.forward` API.
+    """
+    # Currently, input sequences can only contain all prompts
+    # or all decoding. True if all sequences are prompts.
+    is_prompt: bool
+    # (batch_size,). The prompt length per sequence. None if it is a decoding.
+    prompt_lens: Optional[List[int]]
+    # prompt_lens stored as a tensor.
+    prompt_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE: Definition of context_len, subquery_len, and seqlen.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seqlen ----------------------|
+    #                                   |- subquery_len -|
+
+    # WARNING: context_len has different definition depending on if it is
+    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
+    # When it is for decoding, it includes a new token.
+
+    # Maximum subquery length in the batch.
+    max_subquery_len: Optional[int]
+    # Maximum prompt length in the batch.
+    max_prompt_len: Optional[int]
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    subquery_start_loc: Optional[torch.Tensor]
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+
+class ROCmFlashAttentionImpl(AttentionImpl):
+    """
+    If the input tensors contain prompt tokens, the layout is as follows:
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+    Otherwise, the layout is as follows:
+    |<------------------ num_generation_tokens (M) ----------------->|
+    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+    Generation tokens can contain padding when cuda-graph is used.
+    Currently, prompt tokens don't contain any padding.
+    The prompts might have different lengths, while the generation tokens
+    always have length 1.
+
+    If chunked prefill is enabled, prefill tokens and decode tokens can be
+    batched together in a flattened 1D query.
+    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|	
+    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
+    Currently, cuda graph is disabled for chunked prefill, meaning there's no
+    padding between prefill and decode tokens.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = ((sliding_window, sliding_window)
+                               if sliding_window is not None else (-1, -1))
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in suppored_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {suppored_head_sizes}.")
+
+        self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
+        self.use_triton_flash_attn = (os.environ.get(
+            "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
+                                      in ("true", "1"))
+        if self.use_naive_attn:
+            # AMD Radeon 7900 series (gfx1100) currently does not support
+            # xFormers nor FlashAttention. As a temporary workaround, we use
+            # naive PyTorch implementation of attention.
+            self.attn_fuc = _naive_attention
+            logger.debug("Using naive attention in ROCmBackend")
+        elif self.use_triton_flash_attn:
+            from aphrodite.attention.ops.triton_flash_attn import (  # noqa: F401
+                triton_attention, )
+            self.attn_func = triton_attention
+            logger.debug("Using Triton FA in ROCmBackend")
+        else:
+            from flash_attn import flash_attn_varlen_func  # noqa: F401
+            self.attn_func = flash_attn_varlen_func
+            logger.debug("Using CK FA in ROCmBackend")
+
+    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
+        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
+        tokens, n_kv_heads, head_dim = x.shape
+        return (x[:, :,
+                  None, :].expand(tokens, n_kv_heads, n_rep,
+                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
+                                                    head_dim))
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
+        kv_scale: float = 1.0,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                attn_metadata.kv_cache_dtype,
+                kv_scale,
+            )
+
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+        output = torch.empty_like(query)
+        # Query for decode. KV is not needed because it is already cached.
+        decode_query = query[num_prefill_tokens:]
+        # QKV for prefill.
+        query = query[:num_prefill_tokens]
+        key = key[:num_prefill_tokens]
+        value = value[:num_prefill_tokens]
+
+        assert query.shape[0] == num_prefill_tokens
+        assert decode_query.shape[0] == num_decode_tokens
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            # Prompt run.
+            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                # triton attention
+                # When block_tables are not filled, it means q and k are the
+                # prompt, and they have the same length.
+                if self.use_naive_attn or self.use_triton_flash_attn:
+                    if self.num_kv_heads != self.num_heads:
+                        # Interleave for MQA workaround.
+                        key = self.repeat_kv(key, self.num_queries_per_kv)
+                        value = self.repeat_kv(value, self.num_queries_per_kv)
+                    if self.use_naive_attn:
+                        out = self.attn_func(
+                            query,
+                            key,
+                            value,
+                            prefill_meta.prompt_lens,
+                            self.scale,
+                        )
+                        assert output[:num_prefill_tokens].shape == out.shape
+                        output[:num_prefill_tokens] = out
+                    else:
+                        out, _ = self.attn_func(
+                            query,
+                            key,
+                            value,
+                            None,
+                            prefill_meta.seq_start_loc,
+                            prefill_meta.seq_start_loc,
+                            prefill_meta.max_prompt_len,
+                            prefill_meta.max_prompt_len,
+                            True,
+                            self.scale,
+                        )
+                        assert output[:num_prefill_tokens].shape == out.shape
+                        output[:num_prefill_tokens] = out
+                else:
+                    out = self.attn_func(
+                        q=query,
+                        k=key,
+                        v=value,
+                        cu_seqlens_q=prefill_meta.seq_start_loc,
+                        cu_seqlens_k=prefill_meta.seq_start_loc,
+                        max_seqlen_q=prefill_meta.max_prompt_len,
+                        max_seqlen_k=prefill_meta.max_prompt_len,
+                        softmax_scale=self.scale,
+                        causal=True,
+                    )
+                    assert output[:num_prefill_tokens].shape == out.shape
+                    output[:num_prefill_tokens] = out
+            else:
+                # prefix-enabled attention
+                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    prefill_meta.block_tables,
+                    prefill_meta.subquery_start_loc,
+                    prefill_meta.prompt_lens_tensor,
+                    prefill_meta.context_lens,
+                    prefill_meta.max_subquery_len,
+                    self.alibi_slopes,
+                )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                decode_query,
+                key_cache,
+                value_cache,
+                decode_meta.block_tables,
+                decode_meta.context_lens,
+                decode_meta.max_context_len,
+                attn_metadata.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                kv_scale,
+            )
+        # Reshape the output tensor.
+        return output.view(num_tokens, hidden_size)
+
+
+def _naive_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    prompt_lens: List[int],
+    scale: float,
+) -> torch.Tensor:
+    num_tokens = query.shape[0]  # noqa: F841
+    output = torch.empty_like(query)
+    start = 0
+    for _, prompt_len in enumerate(prompt_lens):
+        end = start + prompt_len
+        out = _naive_masked_attention(
+            query[start:end],
+            key[start:end],
+            value[start:end],
+            scale,
+        )
+        # TODO: Unnecessary copy. Optimize.
+        output[start:end].copy_(out)
+        start += prompt_len
+
+    # Using view got RuntimeError: view size is not compatible
+    # with input tensor's size and stride (at least one
+    # dimension spans across two contiguous subspaces).
+    # Use reshape instead.
+    return output
+
+
+def _naive_masked_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: float,
+) -> torch.Tensor:
+    seq_len, num_heads, head_size = query.shape
+    attn_mask = torch.triu(torch.ones(seq_len,
+                                      seq_len,
+                                      dtype=query.dtype,
+                                      device=query.device),
+                           diagonal=1)
+    attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
+    attn_weights = attn_weights + attn_mask.float()
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+    return out
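For reference, the naive path above is plain masked softmax attention applied per prompt. A minimal standalone sketch (illustrative only, not part of the patch; assumes float32 CPU tensors) that cross-checks the same einsum formulation against torch's scaled_dot_product_attention:

import torch

# Shapes follow the [seq_len, num_heads, head_size] layout used by
# _naive_masked_attention above.
seq_len, num_heads, head_size = 8, 4, 64
scale = head_size**-0.5
q = torch.randn(seq_len, num_heads, head_size)
k = torch.randn(seq_len, num_heads, head_size)
v = torch.randn(seq_len, num_heads, head_size)

# Additive causal mask, same construction as above.
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask * torch.finfo(q.dtype).min

weights = torch.softmax(scale * torch.einsum("qhd,khd->hqk", q, k) + mask,
                        dim=-1)
out_naive = torch.einsum("hqk,khd->qhd", weights, v)

# Reference: SDPA expects [num_heads, seq_len, head_size] inputs.
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.movedim(0, 1), k.movedim(0, 1), v.movedim(0, 1),
    is_causal=True, scale=scale).movedim(1, 0)
torch.testing.assert_close(out_naive, out_ref, rtol=1e-4, atol=1e-4)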

+ 254 - 0
aphrodite/attention/backends/sdpa.py

@@ -0,0 +1,254 @@
+""" Attention layer with torch scaled_dot_product_attention
+    and PagedAttention."""
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from aphrodite.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionImpl,
+    AttentionMetadata,
+    AttentionMetadataPerStage,
+)
+from aphrodite.attention.ops.paged_attn import (
+    PagedAttention,
+    PagedAttentionMetadata,
+)
+
+
+class TorchSDPABackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
+        return TorchSDPABackendImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
+        return TorchSDPAMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
+                        AttentionMetadataPerStage):
+    """Metadata for TorchSDPABackend.
+    """
+    # Currently, input sequences can only contain all prompts
+    # or all decoding. True if all sequences are prompts.
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    prompt_lens: Optional[List[int]]
+
+    def __post_init__(self):
+        # Set during the execution of the first attention op.
+        # It is a list because it needs to be set per prompt when
+        # alibi slopes are used.
+        # It will not appear in __repr__ or __init__.
+        self.attn_bias: Optional[List[torch.Tensor]] = None
+
+
+class TorchSDPABackendImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            assert len(alibi_slopes) == num_heads
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+        self.need_mask = (self.alibi_slopes is not None
+                          or self.sliding_window is not None)
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: TorchSDPAMetadata,
+        kv_scale: float,
+    ) -> torch.Tensor:
+        """Forward pass with torch SDPA and PagedAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache,
+                                                attn_metadata.slot_mapping,
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)
+
+        if attn_metadata.is_prompt:
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+                if self.num_kv_heads != self.num_heads:
+                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+                    value = value.repeat_interleave(self.num_queries_per_kv,
+                                                    dim=1)
+
+                if attn_metadata.attn_bias is None:
+                    if self.alibi_slopes is not None:
+                        att_masks = _make_alibi_bias(
+                            self.alibi_slopes, query.dtype,
+                            attn_metadata.prompt_lens)  # type: ignore
+                    elif self.sliding_window is not None:
+                        att_masks = _make_sliding_window_bias(
+                            attn_metadata.prompt_lens, self.sliding_window,
+                            query.dtype)  # type: ignore
+                    else:
+                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                    attn_metadata.attn_bias = att_masks
+
+                query = query.movedim(0, query.dim() - 2)
+                key = key.movedim(0, key.dim() - 2)
+                value = value.movedim(0, value.dim() - 2)
+
+                start = 0
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype)
+                for prompt_len, mask in zip(attn_metadata.prompt_lens,
+                                            attn_metadata.attn_bias):
+                    end = start + prompt_len
+                    sub_out = scaled_dot_product_attention(
+                        query[:, start:end, :],
+                        key[:, start:end, :],
+                        value[:, start:end, :],
+                        attn_mask=mask,
+                        dropout_p=0.0,
+                        is_causal=not self.need_mask,
+                        scale=self.scale).movedim(query.dim() - 2, 0)
+                    output[start:end, :, :] = sub_out
+                    start = end
+            else:
+                # prefix-enabled attention
+                raise RuntimeError(
+                    "Torch SDPA backend doesn't support prefix decoding.")
+
+        else:
+            # Decoding run.
+            output = PagedAttention.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
+                attn_metadata.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                kv_scale,
+            )
+
+        # Reshape the output tensor.
+        return output.view(-1, self.num_heads * self.head_size)
+
+
+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    dtype: torch.dtype,
+    prompt_lens: List[int],
+) -> List[torch.Tensor]:
+    attn_biases = []
+    for prompt_len in prompt_lens:
+        bias = torch.arange(prompt_len, dtype=dtype)
+        # NOTE: HF uses
+        #     `bias = bias[None, :].repeat(prompt_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        bias = bias[None, :] - bias[:, None]
+
+        num_heads = alibi_slopes.shape[0]
+        bias = bias[None, :].repeat((num_heads, 1, 1))
+        bias.mul_(alibi_slopes[:, None, None])
+        inf_mask = torch.empty(
+            (1, prompt_len, prompt_len),
+            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
+        attn_biases.append((bias + inf_mask).to(dtype))
+
+    return attn_biases
+
+
+def _make_sliding_window_bias(
+    prompt_lens: List[int],
+    window_size: Optional[int],
+    dtype: torch.dtype,
+) -> List[torch.Tensor]:
+    attn_biases = []
+    for prompt_len in prompt_lens:
+        tensor = torch.full(
+            (1, prompt_len, prompt_len),
+            dtype=dtype,
+            fill_value=1,
+        )
+        shift = 0
+        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
+        if window_size is not None:
+            mask = torch.triu(mask, diagonal=shift - window_size + 1)
+        mask = torch.log(mask)
+        attn_biases.append(mask.to(dtype))
+
+    return attn_biases
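For reference, a small standalone sketch (illustrative only, not part of the patch) of what _make_sliding_window_bias above produces for one prompt: positions inside the causal sliding window get an additive bias of 0.0 and everything else gets -inf, which is the convention scaled_dot_product_attention expects for attn_mask:

import torch

prompt_len, window_size = 6, 3
band = torch.tril(torch.full((1, prompt_len, prompt_len), fill_value=1.0))
band = torch.triu(band, diagonal=-window_size + 1)
bias = torch.log(band)  # log(1) = 0 inside the window, log(0) = -inf outside
print(bias[0])
# Row i can attend only to columns max(0, i - window_size + 1) .. i.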

+ 396 - 0
aphrodite/attention/backends/xformers.py

@@ -0,0 +1,396 @@
+"""Attention layer with xFormers and PagedAttention."""
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (
+    AttentionBias,
+    BlockDiagonalCausalMask,
+    LowerTriangularMaskWithTensorBias,
+)
+
+from aphrodite.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionImpl,
+    AttentionMetadata,
+    AttentionMetadataPerStage,
+)
+from aphrodite.attention.ops.paged_attn import (
+    PagedAttention,
+    PagedAttentionMetadata,
+)
+
+
+class XFormersBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["XFormersImpl"]:
+        return XFormersImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
+        return XFormersMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+    """Metadata for XFormersbackend.
+
+    NOTE: Any python object stored here is not updated when it is
+    cuda-graph replayed. If you have values that need to be changed
+    dynamically, it should be stored in tensor. The tensor has to be
+    updated from `CUDAGraphRunner.forward` API.
+    """
+    # Currently, input sequences can only contain all prompts
+    # or all decoding. True if all sequences are prompts.
+    is_prompt: bool
+    # (batch_size,). The prompt length per sequence. None if it is a decoding.
+    prompt_lens: Optional[List[int]]
+    # prompt_lens stored as a tensor.
+    prompt_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE: Definition of context_len, subquery_len, and seqlen.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seqlen ----------------------|
+    #                                   |- subquery_len -|
+
+    # WARNING(sang): context_len is defined differently for prefill and
+    # decoding. For prefill, it does not include the new tokens; for decoding,
+    # it includes the new token.
+
+    # Maximum subquery length in the batch.
+    max_subquery_len: Optional[int]
+    # FIXME: It is for flash attn.
+    # Maximum prompt length in the batch.
+    max_prompt_len: Optional[int]
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    subquery_start_loc: Optional[torch.Tensor]
+    # FIXME: It is for flash attn.
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    def __post_init__(self):
+        # Set during the execution of the first attention op.
+        # It is a list because it needs to be set per prompt when
+        # alibi slopes are used, due to a limitation of the xFormers API.
+        # It will not appear in __repr__ or __init__.
+        self.attn_bias: Optional[List[AttentionBias]] = None
+
+
+class XFormersImpl(AttentionImpl):
+    """
+    If the input tensors contain prompt tokens, the layout is as follows:
+    |<--------------- num_prefill_tokens ----------------->|
+    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
+
+    Otherwise, the layout is as follows:
+    |<----------------- num_decode_tokens ------------------>|
+    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
+
+    Generation tokens can contain padding when cuda-graph is used.
+    Currently, prompt tokens don't contain any padding.
+
+    The prompts might have different lengths, while the generation tokens
+    always have length 1.
+
+    If chunked prefill is enabled, prefill tokens and decode tokens can be
+    batched together in a flattened 1D query.
+
+    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
+    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
+
+    Currently, cuda graph is disabled for chunked prefill, meaning there's no
+    padding between prefill and decode tokens.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata[XFormersMetadata],
+        kv_scale: float,
+    ) -> torch.Tensor:
+        """Forward pass with xFormers and PagedAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        num_tokens, hidden_size = query.shape
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache,
+                                                attn_metadata.slot_mapping,
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)
+
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+        output = torch.empty_like(query)
+        # Query for decode. KV is not needed because it is already cached.
+        decode_query = query[num_prefill_tokens:]
+        # QKV for prefill.
+        query = query[:num_prefill_tokens]
+        key = key[:num_prefill_tokens]
+        value = value[:num_prefill_tokens]
+
+        assert query.shape[0] == num_prefill_tokens
+        assert decode_query.shape[0] == num_decode_tokens
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            # Prompt run.
+            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                # normal attention.
+                # block tables are empty if the prompt does not have a cached
+                # prefix.
+                out = self._run_memory_efficient_xformers_forward(
+                    query, key, value, prefill_meta)
+                assert out.shape == output[:num_prefill_tokens].shape
+                output[:num_prefill_tokens] = out
+            else:
+                # prefix-enabled attention
+                # TODO: this triton kernel has a regression (currently broken)
+                # when handling different data types between KV and FP8 KV
+                # cache; to be addressed separately.
+                out = PagedAttention.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    prefill_meta.block_tables,
+                    prefill_meta.subquery_start_loc,
+                    prefill_meta.prompt_lens_tensor,
+                    prefill_meta.context_lens,
+                    prefill_meta.max_subquery_len,
+                    self.alibi_slopes,
+                )
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
+
+        if decode_meta := attn_metadata.decode_metadata:
+            output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                decode_query,
+                key_cache,
+                value_cache,
+                decode_meta.block_tables,
+                decode_meta.context_lens,
+                decode_meta.max_context_len,
+                attn_metadata.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                kv_scale,
+            )
+
+        # Reshape the output tensor.
+        return output.view(-1, self.num_heads * self.head_size)
+
+    def _run_memory_efficient_xformers_forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: XFormersMetadata,
+    ) -> torch.Tensor:
+        """Attention for 1D query of multiple prompts. Multiple prompt
+        tokens are flattened in to `query` input.
+
+        See https://facebookresearch.github.io/xformers/components/ops.html
+        for API spec.
+
+        Args:
+            query: shape = [num_prefill_tokens, num_heads, head_size]
+            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
+            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
+            attn_metadata: Metadata for attention.
+        """
+        original_query = query
+        if self.num_kv_heads != self.num_heads:
+            # GQA/MQA requires the shape [B, M, G, H, K].
+            # Note that the output also has the same shape (which differs
+            # from the spec in the doc).
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
+        # Set attention bias if not provided. This typically happens at
+        # the very first attention layer of every iteration.
+        # FIXME: This is a hack.
+        if attn_metadata.attn_bias is None:
+            if self.alibi_slopes is None:
+                attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                    attn_metadata.prompt_lens)
+                if self.sliding_window is not None:
+                    attn_bias = attn_bias.make_local_attention(
+                        self.sliding_window)
+                attn_metadata.attn_bias = [attn_bias]
+            else:
+                attn_metadata.attn_bias = _make_alibi_bias(
+                    self.alibi_slopes, self.num_kv_heads, query.dtype,
+                    attn_metadata.prompt_lens)
+
+        # No alibi slopes.
+        # TODO: Too many view operations. Let's try to reduce
+        # them in the future for code readability.
+        if self.alibi_slopes is None:
+            # Add the batch dimension.
+            query = query.unsqueeze(0)
+            key = key.unsqueeze(0)
+            value = value.unsqueeze(0)
+            out = xops.memory_efficient_attention_forward(
+                query,
+                key,
+                value,
+                attn_bias=attn_metadata.attn_bias[0],
+                p=0.0,
+                scale=self.scale)
+            return out.view_as(original_query)
+
+        # Attention with alibi slopes.
+        # FIXME: Because xformers does not support dynamic sequence
+        # lengths with custom attention bias, we process each prompt one by
+        # one. This is inefficient, especially when we have many short prompts.
+        output = torch.empty_like(original_query)
+        start = 0
+        for i, prompt_len in enumerate(attn_metadata.prompt_lens):
+            end = start + prompt_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=attn_metadata.attn_bias[i],
+                p=0.0,
+                scale=self.scale)
+            # TODO: Unnecessary copy. Optimize.
+            output[start:end].copy_(out.view_as(original_query[start:end]))
+            start += prompt_len
+        return output
+
+
+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
+    dtype: torch.dtype,
+    prompt_lens: List[int],
+) -> List[LowerTriangularMaskWithTensorBias]:
+    attn_biases = []
+    for prompt_len in prompt_lens:
+        bias = torch.arange(prompt_len, dtype=dtype)
+        # NOTE: HF uses
+        #     `bias = bias[None, :].repeat(prompt_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        # Calculate a relative-position matrix where entry (i, j) equals
+        # j - i.
+        bias = bias[None, :] - bias[:, None]
+
+        padded_len = (prompt_len + 7) // 8 * 8
+        num_heads = alibi_slopes.shape[0]
+        bias = torch.empty(
+            1,  # batch size
+            num_heads,
+            prompt_len,
+            padded_len,
+            device=alibi_slopes.device,
+            dtype=dtype,
+        )[:, :, :, :prompt_len].copy_(bias)
+        bias.mul_(alibi_slopes[:, None, None])
+        if num_heads != num_kv_heads:
+            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
+        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
+
+    return attn_biases
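For reference, a standalone sketch (illustrative only, not part of the patch) of the GQA/MQA reshape performed in _run_memory_efficient_xformers_forward above: the flattened query is viewed into [tokens, num_kv_heads, queries_per_kv, head_size] groups while key/value are expanded without copying, matching the [B, M, G, H, K] layout mentioned in the comment:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 10, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
key = key[:, :, None, :].expand(num_tokens, num_kv_heads, num_queries_per_kv,
                                head_size)
value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)
print(query.shape, key.shape, value.shape)  # all (10, 2, 4, 64)
# A batch dimension is prepended (unsqueeze(0)) before calling
# xops.memory_efficient_attention_forward.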

+ 47 - 0
aphrodite/attention/layer.py

@@ -0,0 +1,47 @@
+"""Attention layer."""
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from aphrodite.attention.backends.abstract import (AttentionMetadata,
+                                                   AttentionMetadataPerStage)
+from aphrodite.attention.selector import get_attn_backend
+
+
+class Attention(nn.Module):
+    """Attention layer.
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens.
+    The class does the following:
+    1. Store the input key and value tensors in the KV cache.
+    2. Perform (multi-head/multi-query/grouped-query) attention.
+    3. Return the output tensor.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.backend = get_attn_backend(torch.get_default_dtype())
+        impl_cls = self.backend.get_impl_cls()
+        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
+                             alibi_slopes, sliding_window)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+        kv_scale: float = 1.0,
+    ) -> torch.Tensor:
+        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
+                                 kv_scale)
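For reference, a sketch (illustrative only, not part of the patch) of how a model's attention module would construct this wrapper. It assumes the aphrodite package is importable and a supported attention backend is available in the environment; the head counts are hypothetical:

from aphrodite.attention.layer import Attention

num_heads, num_kv_heads, head_size = 32, 8, 128
attn = Attention(
    num_heads=num_heads,
    head_size=head_size,
    scale=head_size**-0.5,
    num_kv_heads=num_kv_heads,
)
# In the model's forward pass, query/key/value of shape
# [num_tokens, heads * head_size], the engine-provided kv_cache, and
# attn_metadata are then passed as attn(query, key, value, kv_cache,
# attn_metadata).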

+ 0 - 0
aphrodite/modeling/layers/attention/ops/__init__.py → aphrodite/attention/ops/__init__.py


+ 216 - 0
aphrodite/attention/ops/paged_attn.py

@@ -0,0 +1,216 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from aphrodite._C import cache_ops
+from aphrodite._C import ops
+from aphrodite.attention.ops.prefix_prefill import context_attention_fwd
+
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512
+
+
+@dataclass
+class PagedAttentionMetadata:
+    """Metadata for PagedAttention."""
+    # (batch_size,). The length of context (tokens stored in KV cache) per
+    # sequence. WARNING: When it is a prefill request, it doesn't include new
+    # tokens. When it is for decoding, it includes a new token.
+    context_lens: Optional[torch.Tensor]
+    # Maximum context length in the batch.
+    max_context_len: Optional[int]
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence (seq id -> list of physical blocks).
+    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # The 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
+
+
+class PagedAttention:
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [64, 80, 96, 112, 128, 256]
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (2, num_blocks, block_size * num_kv_heads * head_size)
+
+    @staticmethod
+    def split_kv_cache(
+        kv_cache: torch.Tensor,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = 16 // kv_cache.element_size()
+        num_blocks = kv_cache.shape[1]
+
+        key_cache = kv_cache[0]
+        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
+                                   -1, x)
+        value_cache = kv_cache[1]
+        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
+        return key_cache, value_cache
+
+    @staticmethod
+    def write_to_paged_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        kv_scale: float,
+    ) -> None:
+        cache_ops.reshape_and_cache(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping.flatten(),
+            kv_cache_dtype,
+            kv_scale,
+        )
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        context_lens: torch.Tensor,
+        max_context_len: int,
+        kv_cache_dtype: str,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+        kv_scale: float,
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+
+        block_size = value_cache.shape[3]
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
+                              _PARTITION_SIZE)
+        # NOTE: We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO: Tune this heuristic.
+        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
+        use_v1 = (max_context_len <= 8192
+                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))
+        if use_v1:
+            # Run PagedAttention V1.
+            ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                context_lens,
+                block_size,
+                max_context_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                kv_scale,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                context_lens,
+                block_size,
+                max_context_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                kv_scale,
+            )
+        return output
+
+    @staticmethod
+    def forward_prefix(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        subquery_start_loc: torch.Tensor,
+        prompt_lens_tensor: torch.Tensor,
+        context_lens: torch.Tensor,
+        max_subquery_len: int,
+        alibi_slopes: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+        context_attention_fwd(
+            query,
+            key,
+            value,
+            output,
+            key_cache,
+            value_cache,
+            block_tables,
+            # subquery_start_loc is (batch_size + 1,)
+            subquery_start_loc[:-1],
+            prompt_lens_tensor,
+            context_lens,
+            max_subquery_len,
+            alibi_slopes,
+        )
+        return output
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+
+        src_value_cache = src_kv_cache[1]
+        dst_value_cache = dst_kv_cache[1]
+        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
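For reference, a standalone sketch (illustrative only, not part of the patch) of the cache layout that get_kv_cache_shape and split_kv_cache above agree on. The x = 16 // element_size factor groups 16 bytes of the head dimension together (x = 8 for fp16), matching the layout the paged-attention kernels index:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 128
kv_cache = torch.zeros(2, num_blocks,
                       block_size * num_kv_heads * head_size,
                       dtype=torch.float16)

x = 16 // kv_cache.element_size()  # 8 for fp16
key_cache = kv_cache[0].view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1].view(num_blocks, num_kv_heads, head_size, -1)
print(key_cache.shape)    # torch.Size([4, 2, 16, 16, 8])
print(value_cache.shape)  # torch.Size([4, 2, 128, 16])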

+ 0 - 0
aphrodite/modeling/layers/attention/ops/prefix_prefill.py → aphrodite/attention/ops/prefix_prefill.py


+ 805 - 0
aphrodite/attention/ops/triton_flash_attn.py

@@ -0,0 +1,805 @@
+#!/usr/bin/env python
+"""
+Fused Attention
+===============
+This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
+(https://tridao.me/publications/flash2/flash2.pdf)
+Credits: OpenAI kernel team, AMD ML Frameworks Triton team
+Features supported:
+1) Fwd with causal masking
+2) Any sequence lengths without padding (currently fwd kernel only)
+3) Support for different sequence lengths for q and k
+4) Nested tensor API currently does not support dropout or bias.
+Not currently supported:
+1) Non power of two head dims
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+torch_dtype: tl.constexpr = torch.float16
+
+
+@triton.jit
+def cdiv_fn(x, y):
+    return (x + y - 1) // y
+
+
+@triton.jit
+def max_fn(x, y):
+    return tl.math.max(x, y)
+
+
+@triton.jit
+def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
+    ms = tl.arange(0, m)
+    ns = tl.arange(0, n)
+    return philox_offset + ms[:, None] * stride + ns[None, :]
+
+
+@triton.jit
+def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
+    rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
+                                  stride).to(tl.uint32)
+    # TODO: use tl.randint for better performance
+    return tl.rand(philox_seed, rng_offsets)
+
+
+@triton.jit
+def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
+    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
+                             stride)
+    rng_keep = rng_output > dropout_p
+    return rng_keep
+
+
+@triton.jit
+def load_fn(block_ptr, first, second, pad):
+    if first and second:
+        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
+    elif first:
+        tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
+    elif second:
+        tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
+    else:
+        tensor = tl.load(block_ptr)
+    return tensor
+
+
+@triton.jit
+def _attn_fwd_inner(
+    acc,
+    l_i,
+    m_i,
+    q,
+    K_block_ptr,
+    V_block_ptr,
+    start_m,
+    actual_seqlen_k,
+    dropout_p,
+    philox_seed,
+    batch_philox_offset,
+    encoded_softmax_block_ptr,
+    block_min,
+    block_max,
+    offs_n_causal,
+    masked_blocks,
+    n_extra_tokens,
+    bias_ptr,
+    IS_CAUSAL: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    OFFS_M: tl.constexpr,
+    OFFS_N: tl.constexpr,
+    PRE_LOAD_V: tl.constexpr,
+    MASK_STEPS: tl.constexpr,
+    ENABLE_DROPOUT: tl.constexpr,
+    RETURN_ENCODED_SOFTMAX: tl.constexpr,
+    PADDED_HEAD: tl.constexpr,
+):
+    # loop over k, v, and update accumulator
+    for start_n in range(block_min, block_max, BLOCK_N):
+        # For padded blocks, we will overrun the tensor size if
+        # we load all BLOCK_N. For others, the blocks are all within range.
+        k = load_fn(
+            K_block_ptr,
+            PADDED_HEAD,
+            MASK_STEPS and (n_extra_tokens != 0),
+            "zero",
+        )
+        if PRE_LOAD_V:
+            v = load_fn(
+                V_block_ptr,
+                MASK_STEPS and (n_extra_tokens != 0),
+                PADDED_HEAD,
+                "zero",
+            )
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        # We start from end of seqlen_k so only the first iteration would need
+        # to be checked for padding if it is not a multiple of block_n
+        # TODO: This can be optimized to only be true for the padded block.
+        if MASK_STEPS:  # noqa: SIM102
+            # If this is the last block / iteration, we want to
+            # mask if the sequence length is not a multiple of block size
+            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
+            # if not is_modulo_mn. last step might get wasted but that is okay.
+            # check if this masking works for that case.
+            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
+                boundary_m = tl.full([BLOCK_M],
+                                     actual_seqlen_k,
+                                     dtype=tl.int32)
+                size_n = start_n + OFFS_N[None, :]
+                mask = size_n < boundary_m[:, None]
+                qk = tl.where(mask, qk, float("-inf"))
+        if IS_CAUSAL:
+            causal_boundary = start_n + offs_n_causal
+            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
+            qk = tl.where(causal_mask, qk, float("-inf"))
+        # -- compute qk ----
+        qk += tl.dot(q, k)
+        if bias_ptr is not None:
+            bias = load_fn(bias_ptr, False, MASK_STEPS
+                           and (n_extra_tokens != 0), "zero")
+            # While bias is added after multiplying qk with sm_scale, our
+            # optimization to use 2^x instead of e^x results in an additional
+            # scale factor of log2(e) which we must also multiply the bias with.
+            qk += bias * 1.44269504089
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk = qk - m_ij[:, None]
+        p = tl.math.exp2(qk)
+
+        # CAVEAT: Must update l_ij before applying dropout
+        l_ij = tl.sum(p, 1)
+        if ENABLE_DROPOUT:
+            philox_offset = (batch_philox_offset +
+                             start_m * BLOCK_M * actual_seqlen_k + start_n -
+                             BLOCK_N)
+            keep = dropout_mask(
+                philox_seed,
+                philox_offset,
+                dropout_p,
+                BLOCK_M,
+                BLOCK_N,
+                actual_seqlen_k,
+            )
+            if RETURN_ENCODED_SOFTMAX:
+                tl.store(
+                    encoded_softmax_block_ptr,
+                    tl.where(keep, p,
+                             -p).to(encoded_softmax_block_ptr.type.element_ty),
+                )
+            p = tl.where(keep, p, 0.0)
+        elif RETURN_ENCODED_SOFTMAX:
+            tl.store(
+                encoded_softmax_block_ptr,
+                p.to(encoded_softmax_block_ptr.type.element_ty),
+            )
+        # -- update output accumulator --
+        alpha = tl.math.exp2(m_i - m_ij)
+        acc = acc * alpha[:, None]
+        if not PRE_LOAD_V:
+            v = load_fn(
+                V_block_ptr,
+                MASK_STEPS and (n_extra_tokens != 0),
+                PADDED_HEAD,
+                "zero",
+            )
+        # -- update m_i and l_i
+        l_i = l_i * alpha + l_ij
+        # update m_i and l_i
+        m_i = m_ij
+        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
+        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+        if bias_ptr is not None:
+            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
+        if RETURN_ENCODED_SOFTMAX:
+            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
+                                                   (0, BLOCK_N))
+    return acc, l_i, m_i
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_M": 256,
+                "BLOCK_N": 64,
+                "waves_per_eu": 2,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 128,
+                "waves_per_eu": 2,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 256,
+                "BLOCK_N": 128,
+                "waves_per_eu": 2,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 64,
+                "waves_per_eu": 3,
+                "PRE_LOAD_V": True,
+            },
+            num_stages=1,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 64,
+                "waves_per_eu": 3,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 64,
+                "BLOCK_N": 64,
+                "waves_per_eu": 4,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 32,
+                "BLOCK_N": 32,
+                "waves_per_eu": 4,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=8,
+        ),
+        # TODO: This config fails with head_size not pow2 with data mismatches.
+        #    triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
+        #                   'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
+        triton.Config(
+            {
+                "BLOCK_M": 16,
+                "BLOCK_N": 16,
+                "waves_per_eu": 1,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=4,
+        ),
+    ],
+    key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+)
+@triton.jit
+def attn_fwd(
+    Q,
+    K,
+    V,
+    bias,
+    sm_scale,
+    L,
+    Out,
+    stride_qz,
+    stride_qh,
+    stride_qm,
+    stride_qk,
+    stride_kz,
+    stride_kh,
+    stride_kn,
+    stride_kk,
+    stride_vz,
+    stride_vh,
+    stride_vk,
+    stride_vn,
+    stride_oz,
+    stride_oh,
+    stride_om,
+    stride_on,
+    stride_bz,
+    stride_bh,
+    stride_bm,
+    stride_bn,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    dropout_p,
+    philox_seed,
+    philox_offset_base,
+    encoded_softmax,
+    hq,
+    hk,
+    ACTUAL_BLOCK_DMODEL: tl.constexpr,
+    MAX_SEQLENS_Q: tl.constexpr,
+    MAX_SEQLENS_K: tl.constexpr,
+    VARLEN: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    PRE_LOAD_V: tl.constexpr,
+    BIAS_TYPE: tl.constexpr,
+    ENABLE_DROPOUT: tl.constexpr,
+    RETURN_ENCODED_SOFTMAX: tl.constexpr,
+):
+    start_m = tl.program_id(0)
+    off_h_q = tl.program_id(1)
+    off_z = tl.program_id(2)
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    if VARLEN:
+        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
+        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
+        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
+        # We have a one-size-fits-all grid in id(0). Some seqlens might be too
+        # small for all start_m so for those we return early.
+        if start_m * BLOCK_M > seqlen_q:
+            return
+        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
+        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
+        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
+    else:
+        cu_seqlens_q_start = 0
+        cu_seqlens_k_start = 0
+        seqlen_q = MAX_SEQLENS_Q
+        seqlen_k = MAX_SEQLENS_K
+
+    # Now we compute whether we need to exit early due to causal masking.
+    # This is because for seqlen_q > seqlen_k, M rows of the attn scores
+    # are completely masked, resulting in 0s written to the output, and
+    # inf written to LSE. We don't need to do any GEMMs in this case.
+    # This block of code determines what N is, and if this WG is operating
+    # on those M rows.
+    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
+    if IS_CAUSAL:
+        # If seqlen_q == seqlen_k, the attn scores are a square matrix.
+        # If seqlen_q != seqlen_k, attn scores are rectangular which means
+        # the causal mask boundary is bottom right aligned, and ends at either
+        # the top edge (seqlen_q < seqlen_k) or left edge.
+        # This captures the decrease in n_blocks if we have a rectangular attn
+        # matrix
+        n_blocks_seqlen = cdiv_fn(
+            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
+        # This is what adjusts the block_max for the current WG, only
+        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
+        n_blocks = min(n_blocks, n_blocks_seqlen)
+        # If we have no blocks after adjusting for seqlen deltas, this WG is
+        # part of the blocks that are all 0. We exit early.
+        if n_blocks <= 0:
+            o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
+                        off_h_q * stride_oh)
+            O_block_ptr = tl.make_block_ptr(
+                base=Out + o_offset,
+                shape=(seqlen_q, BLOCK_DMODEL),
+                strides=(stride_om, stride_on),
+                offsets=(start_m * BLOCK_M, 0),
+                block_shape=(BLOCK_M, BLOCK_DMODEL),
+                order=(1, 0),
+            )
+            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
+            # We still need to write 0s to the result
+            # tl.store(O_block_ptr,
+            # acc.to(Out.type.element_ty), boundary_check=(0,1))
+            # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
+            #          + offs_m
+            # We store inf to LSE, not -inf because in the bwd pass,
+            # we subtract this
+            # from qk which makes it -inf, such that exp(qk - inf) = 0
+            # for these masked blocks.
+            # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
+            # tl.store(l_ptrs, l)
+            # TODO: Should dropout and return encoded softmax be handled here?
+            return
+
+    is_mqa = hq != hk
+    if is_mqa:  # noqa: SIM108
+        off_h_k = off_h_q % hk
+    else:
+        off_h_k = off_h_q
+    n_extra_tokens = 0
+    if seqlen_k < BLOCK_N:
+        n_extra_tokens = BLOCK_N - seqlen_k
+    elif seqlen_k % BLOCK_N:
+        n_extra_tokens = seqlen_k % BLOCK_N
+    padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
+
+    # Compute pointers for all the tensors used in this kernel.
+    q_offset = (off_z * stride_qz + off_h_q * stride_qh +
+                cu_seqlens_q_start * stride_qm)
+    Q_block_ptr = tl.make_block_ptr(
+        base=Q + q_offset,
+        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
+        strides=(stride_qm, stride_qk),
+        offsets=(start_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    k_offset = (off_z * stride_kz + off_h_k * stride_kh +
+                cu_seqlens_k_start * stride_kn)
+    K_block_ptr = tl.make_block_ptr(
+        base=K + k_offset,
+        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
+        strides=(stride_kk, stride_kn),
+        offsets=(0, 0),
+        block_shape=(BLOCK_DMODEL, BLOCK_N),
+        order=(0, 1),
+    )
+    v_offset = (off_z * stride_vz + off_h_k * stride_vh +
+                cu_seqlens_k_start * stride_vk)
+    V_block_ptr = tl.make_block_ptr(
+        base=V + v_offset,
+        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
+        strides=(stride_vk, stride_vn),
+        offsets=(0, 0),
+        block_shape=(BLOCK_N, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    if BIAS_TYPE != 0:
+        bias_ptr = tl.make_block_ptr(
+            base=bias + off_h_q * stride_bh,
+            shape=(seqlen_q, seqlen_k),
+            strides=(stride_bm, stride_bn),
+            offsets=(start_m * BLOCK_M, 0),
+            block_shape=(BLOCK_M, BLOCK_N),
+            order=(1, 0),
+        )
+    else:
+        bias_ptr = None
+    if ENABLE_DROPOUT:
+        batch_philox_offset = philox_offset_base \
+                              + (off_z * hq + off_h_q) \
+                              * seqlen_q * seqlen_k
+    else:
+        batch_philox_offset = 0
+    # We can ask to return the dropout mask without actually doing any dropout.
+    # In this case, we return an invalid pointer to indicate the mask is not
+    # valid.
+    # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
+    if RETURN_ENCODED_SOFTMAX:
+        encoded_softmax_block_ptr = tl.make_block_ptr(
+            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
+            shape=(seqlen_q, seqlen_k),
+            strides=(seqlen_k, 1),
+            offsets=(start_m * BLOCK_M, 0),
+            block_shape=(BLOCK_M, BLOCK_N),
+            order=(1, 0),
+        )
+    else:
+        encoded_softmax_block_ptr = 0
+    # initialize pointer to m and l
+    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
+    # have native e^x support in HW.
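+    # (exp(x) == exp2(x * log2(e)), so folding log2(e) into the scale lets
+    # the inner loop use exp2 directly.)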
+    qk_scale = sm_scale * 1.44269504089
+    # Q is loaded once at the beginning and shared by all N blocks.
+    q = load_fn(Q_block_ptr, True, padded_head, "zero")
+    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
+
+    # Here we compute how many full and masked blocks we have.
+    padded_block_k = n_extra_tokens != 0
+    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
+    if IS_CAUSAL:
+        # There are always at least BLOCK_M // BLOCK_N masked blocks.
+        # Additionally there might be one more due to dissimilar seqlens.
+        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
+    else:
+        # Padding on Q does not need to be masked in the FA loop.
+        masked_blocks = padded_block_k
+    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
+    # block. In this case we might exceed n_blocks so pick the min.
+    masked_blocks = min(masked_blocks, n_blocks)
+    n_full_blocks = n_blocks - masked_blocks
+    block_min = 0
+    block_max = n_blocks * BLOCK_N
+    # Compute for full blocks. Here we set causal to false regardless of its
+    # value because there is no masking. Similarly we do not need padding.
+    if n_full_blocks > 0:
+        block_max = (n_blocks - masked_blocks) * BLOCK_N
+        acc, l_i, m_i = _attn_fwd_inner(
+            acc,
+            l_i,
+            m_i,
+            q,
+            K_block_ptr,
+            V_block_ptr,
+            start_m,
+            seqlen_k,
+            dropout_p,
+            philox_seed,
+            batch_philox_offset,
+            encoded_softmax_block_ptr,
+            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
+            block_min,
+            block_max,
+            0,
+            0,
+            0,
+            bias_ptr,
+            # IS_CAUSAL, ....
+            False,
+            BLOCK_M,
+            BLOCK_DMODEL,
+            BLOCK_N,
+            offs_m,
+            offs_n,
+            # _, MASK_STEPS, ...
+            PRE_LOAD_V,
+            False,
+            ENABLE_DROPOUT,
+            RETURN_ENCODED_SOFTMAX,
+            padded_head,
+        )
+        block_min = block_max
+        block_max = n_blocks * BLOCK_N
+
+    tl.debug_barrier()
+    # Remaining blocks, if any, are masked (causal boundary and/or K padding).
+    if masked_blocks > 0:
+        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
+        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
+        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
+        if bias_ptr is not None:
+            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
+        if RETURN_ENCODED_SOFTMAX:
+            encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
+                                                   (0, n_full_blocks))
+        acc, l_i, m_i = _attn_fwd_inner(
+            acc,
+            l_i,
+            m_i,
+            q,
+            K_block_ptr,
+            V_block_ptr,
+            start_m,
+            seqlen_k,
+            dropout_p,
+            philox_seed,
+            batch_philox_offset,
+            encoded_softmax_block_ptr,
+            block_min,
+            block_max,
+            offs_n_causal,
+            masked_blocks,
+            n_extra_tokens,
+            bias_ptr,
+            IS_CAUSAL,
+            BLOCK_M,
+            BLOCK_DMODEL,
+            BLOCK_N,
+            offs_m,
+            offs_n,
+            # _, MASK_STEPS, ...
+            PRE_LOAD_V,
+            True,
+            ENABLE_DROPOUT,
+            RETURN_ENCODED_SOFTMAX,
+            padded_head,
+        )
+    # epilogue
+    acc = acc / l_i[:, None]
+    if ENABLE_DROPOUT:
+        acc = acc / (1 - dropout_p)
+    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
+    # then we have one block with a row of all NaNs which come from computing
+    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
+    # and store 0s where there are NaNs as these rows should've been zeroed out.
+    end_m_idx = (start_m + 1) * BLOCK_M
+    start_m_idx = start_m * BLOCK_M
+    causal_start_idx = seqlen_q - seqlen_k
+    acc = acc.to(Out.type.element_ty)
+    if IS_CAUSAL:  # noqa: SIM102
+        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
+            out_mask_boundary = tl.full((BLOCK_DMODEL, ),
+                                        causal_start_idx,
+                                        dtype=tl.int32)
+            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
+            out_ptrs_mask = (mask_m_offsets[:, None] >=
+                             out_mask_boundary[None, :])
+            z = 0.0
+            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
+    # write back LSE
+    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
+    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
+    # few rows. This is only true for the last M block. For others,
+    # overflow_size will be -ve
+    # overflow_size = end_m_idx - seqlen_q
+    # if overflow_size > 0:
+    #    boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
+    #    # This is a > check because mask being 0 blocks the store.
+    #    l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
+    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
+    # else:
+    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
+
+    # write back O
+    o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
+                off_h_q * stride_oh)
+    O_block_ptr = tl.make_block_ptr(
+        base=Out + o_offset,
+        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
+        strides=(stride_om, stride_on),
+        offsets=(start_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    # Need boundary check on this to make sure the padding from the
+    # Q and KV tensors in both dims are not part of what we store back.
+    # TODO: Do the boundary check optionally.
+    tl.store(O_block_ptr, acc, boundary_check=(0, 1))
+
+
+def check_args(
+    q,
+    k,
+    v,
+    o,
+    varlen=True,
+    max_seqlens=None,
+    cu_seqlens_q=None,
+    cu_seqlens_k=None,
+):
+    assert q.dim() == k.dim() and q.dim() == v.dim()
+    if varlen:
+        assert q.dim() == 3
+        total_q, nheads_q, head_size = q.shape
+        total_k, nheads_k, _ = k.shape
+        assert cu_seqlens_q is not None
+        assert cu_seqlens_k is not None
+        assert len(cu_seqlens_q) == len(cu_seqlens_k)
+    else:
+        assert q.dim() == 4
+        batch, nheads_q, seqlen_q, head_size = q.shape
+        _, nheads_k, seqlen_k, _ = k.shape
+        assert max_seqlens > 0
+    assert k.shape == v.shape
+    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
+    # TODO: Change assert if we support qkl f8 and v f16
+    assert q.dtype == k.dtype and q.dtype == v.dtype
+    assert head_size <= 256
+    assert o.shape == q.shape
+    assert (nheads_q % nheads_k) == 0
+
+
+class _attention(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        k,
+        v,
+        o,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlens_q,
+        max_seqlens_k,
+        causal=False,
+        sm_scale=1.0,
+        bias=None,
+    ):
+        if o is None:
+            o = torch.empty_like(q, dtype=v.dtype)
+
+        check_args(
+            q,
+            k,
+            v,
+            o,
+            varlen=True,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+        )
+        if True:  # varlen
+            total_q, nheads_q, head_size = q.shape
+            total_k, nheads_k, _ = k.shape
+            batch = len(cu_seqlens_q) - 1
+            q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
+            k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
+            v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
+            o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
+        else:
+            batch, seqlen_q, nheads_q, head_size = q.shape
+            _, seqlen_k, nheads_k, _ = k.shape
+            q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
+            k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
+            v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
+            o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
+
+        # Pad head_size up to the next supported head dim, i.e. the smallest
+        # of {32, 64, 128, 256} that can hold it.
+        unpadded_head_dims = {32, 64, 128, 256}
+        if head_size not in unpadded_head_dims:
+            padded_d_model = None
+            for i in sorted(unpadded_head_dims):  # smallest -> largest
+                if i > head_size:
+                    padded_d_model = i
+                    break
+            assert padded_d_model is not None
+        else:
+            padded_d_model = head_size
+
+        grid = lambda META: (
+            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
+            nheads_q,
+            batch,
+        )
+
+        encoded_softmax = None
+
+        # Seed the RNG so we get reproducible results for testing.
+        philox_seed = 0x1BF52
+        philox_offset = 0x1D4B42
+
+        if bias is not None:
+            bias_strides = (
+                bias.stride(0),
+                bias.stride(1),
+                bias.stride(2),
+                bias.stride(3),
+            )
+        else:
+            bias_strides = (0, 0, 0, 0)
+
+        attn_fwd[grid](
+            q,
+            k,
+            v,
+            bias,
+            sm_scale,
+            None,
+            o,
+            *q_strides,
+            *k_strides,
+            *v_strides,
+            *o_strides,
+            *bias_strides,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            dropout_p=0.0,
+            philox_seed=philox_seed,
+            philox_offset_base=philox_offset,
+            encoded_softmax=encoded_softmax,
+            hq=nheads_q,
+            hk=nheads_k,
+            ACTUAL_BLOCK_DMODEL=head_size,
+            MAX_SEQLENS_Q=max_seqlens_q,
+            MAX_SEQLENS_K=max_seqlens_k,
+            IS_CAUSAL=causal,
+            VARLEN=True,
+            BLOCK_DMODEL=padded_d_model,
+            BIAS_TYPE=0 if bias is None else 1,
+            ENABLE_DROPOUT=False,
+            RETURN_ENCODED_SOFTMAX=False,
+        )
+
+        ctx.grid = grid
+        ctx.sm_scale = sm_scale
+        ctx.BLOCK_DMODEL = head_size
+        ctx.causal = causal
+        ctx.dropout_p = 0.0
+        ctx.philox_seed = philox_seed
+        ctx.philox_offset = philox_offset
+        ctx.encoded_softmax = encoded_softmax
+        ctx.return_encoded_softmax = False
+        return o, encoded_softmax
+
+
+triton_attention = _attention.apply
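For reference, a minimal sketch of how the exported `triton_attention` entry point could be invoked (shapes and values below are illustrative, not taken from the engine; the tensors follow the varlen layout that `check_args` enforces, with all sequences packed along dim 0 and `cu_seqlens` holding prefix sums):

    import torch

    # Two sequences of lengths 5 and 11 packed into a single varlen batch.
    cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
    total_tokens, num_heads, head_size = 16, 8, 128
    q = torch.randn(total_tokens, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    o = torch.empty_like(q)

    out, _ = triton_attention(
        q, k, v, o,
        cu_seqlens, cu_seqlens,  # cu_seqlens_q, cu_seqlens_k
        11, 11,                  # max_seqlens_q, max_seqlens_k
        True,                    # causal
        head_size**-0.5,         # sm_scale
    )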

+ 74 - 0
aphrodite/attention/selector.py

@@ -0,0 +1,74 @@
+import enum
+from functools import lru_cache
+from typing import Type
+
+import torch
+from loguru import logger
+
+from aphrodite.attention.backends.abstract import AttentionBackend
+from aphrodite.common.utils import is_cpu, is_hip
+
+
+class _Backend(enum.Enum):
+    FLASH_ATTN = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    TORCH_SDPA = enum.auto()
+
+
+@lru_cache(maxsize=None)
+def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
+    backend = _which_attn_to_use(dtype)
+    if backend == _Backend.FLASH_ATTN:
+        logger.info("Using FlashAttention backend.")
+        from aphrodite.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
+        return FlashAttentionBackend
+    elif backend == _Backend.XFORMERS:
+        logger.info("Using XFormers backend.")
+        from aphrodite.attention.backends.xformers import XFormersBackend  # noqa: E501
+        return XFormersBackend
+    elif backend == _Backend.ROCM_FLASH:
+        logger.info("Using ROCm FlashAttention backend.")
+        from aphrodite.attention.backends.rocm_flash_attn import (  # noqa: F401
+            ROCmFlashAttentionBackend)
+        return ROCmFlashAttentionBackend
+    elif backend == _Backend.TORCH_SDPA:
+        logger.info("Using Torch SDPA backend.")
+        from aphrodite.attention.backends.sdpa import TorchSDPABackend
+        return TorchSDPABackend
+    else:
+        raise ValueError("Invalid attention backend.")
+
+
+def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
+    """Returns which flash attention backend to use."""
+    if is_cpu():
+        return _Backend.TORCH_SDPA
+
+    if is_hip():
+        # AMD GPUs.
+        if torch.cuda.get_device_capability()[0] != 9:
+            # not Instinct series GPUs.
+            logger.info("flash_attn is not supported on NAVI GPUs.")
+        return _Backend.ROCM_FLASH
+
+    # NVIDIA GPUs.
+    if torch.cuda.get_device_capability()[0] < 8:
+        # Volta and Turing NVIDIA GPUs.
+        logger.info("Cannot use FlashAttention backend for Volta and Turing "
+                    "GPUs.")
+        return _Backend.XFORMERS
+
+    if dtype not in (torch.float16, torch.bfloat16):
+        logger.info("Cannot use FlashAttention backend for dtype other than "
+                    "torch.float16 or torch.bfloat16.")
+        return _Backend.XFORMERS
+
+    try:
+        import flash_attn  # noqa: F401
+    except ImportError:
+        logger.info(
+            "Cannot use FlashAttention backend because the flash_attn package "
+            "is not found. Please install it for better performance.")
+        return _Backend.XFORMERS
+    return _Backend.FLASH_ATTN
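A hedged sketch of how the selector is typically consumed (the dtype here is just an example; the backend that comes back depends on the platform and on whether flash_attn is installed):

    import torch
    from aphrodite.attention.selector import get_attn_backend

    # Resolved once per dtype thanks to lru_cache; repeated calls are cheap.
    backend_cls = get_attn_backend(torch.float16)
    print(backend_cls.__name__)  # e.g. FlashAttentionBackend on Ampere+ GPUs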

+ 429 - 60
aphrodite/common/config.py

@@ -1,5 +1,6 @@
+import enum
 from typing import TYPE_CHECKING, Optional, Union, ClassVar
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 import os
 from packaging.version import Version
 from loguru import logger
@@ -8,8 +9,8 @@ import json
 import torch
 from transformers import PretrainedConfig
 
-from aphrodite.transformers_utils.config import get_config
-from aphrodite.common.utils import (get_cpu_memory, is_hip, is_neuron,
+from aphrodite.transformers_utils.config import get_config, get_hf_text_config
+from aphrodite.common.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron,
                                     get_nvcc_cuda_version)
 
 if TYPE_CHECKING:
@@ -62,6 +63,11 @@ class ModelConfig:
         load_in_8bit: Whether to load the FP16 model in 8bit format. Slower
             than load_in_smooth in terms of throughput.
         load_in_smooth: Whether to load the FP16 model in smoothquant format.
+        quantization_param_path: Path to JSON file containing scaling factors.
+            Used to load KV cache scaling factors into the model when KV cache
+            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also 
+            be used to load activation and weight scaling factors when the 
+            model dtype is FP8_E4M3 on ROCm.
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
             If False, we will use CUDA graph and eager execution in hybrid.
@@ -89,6 +95,7 @@ class ModelConfig:
         load_in_4bit: bool = False,
         load_in_8bit: bool = False,
         load_in_smooth: bool = False,
+        quantization_param_path: Optional[str] = None,
         enforce_eager: bool = True,
         max_context_len_to_capture: Optional[int] = None,
         max_log_probs: int = 10,
@@ -107,6 +114,7 @@ class ModelConfig:
         self.load_in_4bit = load_in_4bit
         self.load_in_8bit = load_in_8bit
         self.load_in_smooth = load_in_smooth
+        self.quantization_param_path = quantization_param_path
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
         self.max_log_probs = max_log_probs
@@ -128,8 +136,9 @@ class ModelConfig:
 
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision)
-        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
-        self.max_model_len = _get_and_verify_max_len(self.hf_config,
+        self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                      max_model_len)
         self._verify_load_format()
         self._verify_tokenizer_mode()
@@ -158,7 +167,9 @@ class ModelConfig:
 
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures and load_format == "pt":
+        # architectures can be None instead of []
+        if architectures and "MixtralForCausalLM" in architectures \
+            and load_format == "pt":
             raise ValueError(
                 "Currently, the 'pt' format is not supported for Mixtral. "
                 "Please use the 'safetensors' format instead. ")
@@ -174,8 +185,8 @@ class ModelConfig:
 
     def _verify_quantization(self) -> None:
         supported_quantization = [
-            "aqlm", "awq", "bnb", "exl2", "gguf", "gptq", "quip", "squeezellm",
-            "marlin"
+            "aqlm", "awq", "bnb", "eetq", "exl2", "gguf", "gptq", "quip",
+            "squeezellm", "marlin"
         ]
         rocm_not_supported_quantization = ["aqlm", "awq", "bnb", "quip"]
         if self.quantization is not None:
@@ -291,7 +302,7 @@ class ModelConfig:
         self,
         parallel_config: "ParallelConfig",
     ) -> None:
-        total_num_attention_heads = self.hf_config.num_attention_heads
+        total_num_attention_heads = self.hf_text_config.num_attention_heads
         tensor_parallel_size = parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(
@@ -299,7 +310,7 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")
 
-        total_num_hidden_layers = self.hf_config.num_hidden_layers
+        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(
@@ -308,22 +319,23 @@ class ModelConfig:
                 f"({pipeline_parallel_size}).")
 
     def get_sliding_window(self) -> Optional[int]:
-        if (hasattr(self.hf_config, "use_sliding_window")
-                and not self.hf_config.use_sliding_window):
+        if (hasattr(self.hf_text_config, "use_sliding_window")
+                and not self.hf_text_config.use_sliding_window):
             return None
-        return getattr(self.hf_config, "sliding_window", None)
+        return getattr(self.hf_text_config, "sliding_window", None)
 
     def get_vocab_size(self) -> int:
-        return self.hf_config.vocab_size
+        return self.hf_text_config.vocab_size
 
     def get_hidden_size(self) -> int:
-        return self.hf_config.hidden_size
+        return self.hf_text_config.hidden_size
 
     def get_head_size(self) -> int:
         if hasattr(self.hf_config, "head_dim"):
             return self.hf_config.head_dim
         # FIXME: This may not be true for all models.
-        return self.hf_config.hidden_size // self.hf_config.num_attention_heads
+        return (self.hf_text_config.hidden_size //
+                self.hf_text_config.num_attention_heads)
 
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -335,12 +347,17 @@ class ModelConfig:
         new_decoder_arch_falcon = (
             self.hf_config.model_type in falcon_model_types
             and getattr(self.hf_config, "new_decoder_architecture", False))
-        if not new_decoder_arch_falcon and getattr(self.hf_config,
+        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                    "multi_query", False):
             # Multi-query attention, only one KV head.
             # Currently, tensor parallelism is not supported in this case.
             return 1
 
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["dbrx", "mpt"]:
+            return getattr(self.hf_config.attn_config, "kv_n_heads",
+                           self.hf_config.num_attention_heads)
+
         attributes = [
             # For Falcon:
             "n_head_kv",
@@ -351,13 +368,13 @@ class ModelConfig:
             "multi_query_group_num",
         ]
         for attr in attributes:
-            num_kv_heads = getattr(self.hf_config, attr, None)
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
             if num_kv_heads is not None:
                 return num_kv_heads
 
         # For non-grouped-query attention models, the number of KV heads is
         # equal to the number of attention heads.
-        return self.hf_config.num_attention_heads
+        return self.hf_text_config.num_attention_heads
 
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
@@ -370,7 +387,7 @@ class ModelConfig:
                    total_num_kv_heads // parallel_config.tensor_parallel_size)
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
-        total_num_hidden_layers = self.hf_config.num_hidden_layers
+        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
 
 
@@ -385,6 +402,8 @@ class CacheConfig:
         cache_dtype: Data Type for KV cache storage.
         cache_quant_params_path: Path to the scales and zero points
             of KV cache quantization when cache_dtype is int8.
+        num_gpu_blocks_override: Number of GPU blocks to use. This overrides
+            the profiled num_gpu_blocks if specified. Does nothing if None.
     """
 
     def __init__(
@@ -394,12 +413,14 @@ class CacheConfig:
         swap_space: int,
         cache_dtype: str,
         # cache_quant_params_path: Optional[str] = None,
+        num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
         context_shift: bool = False,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GB
+        self.num_gpu_blocks_override = num_gpu_blocks_override
         self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
         # self.cache_quant_params_path = cache_quant_params_path
@@ -426,21 +447,20 @@ class CacheConfig:
         if self.cache_dtype == "auto":
             # if self.cache_dtype in ["auto", "int8"]:
             pass
-        elif self.cache_dtype == "fp8_e5m2":
-            if is_hip():
-                raise NotImplementedError(
-                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
-            nvcc_cuda_version = get_nvcc_cuda_version()
-            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
-                raise ValueError(
-                    "FP8 is not supported when cuda version is lower than 11.8."
-                )
+        elif self.cache_dtype == "fp8":
+            if not is_hip():
+                nvcc_cuda_version = get_nvcc_cuda_version()
+                if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
+                    raise ValueError(
+                        "FP8 is not supported when cuda version is "
+                        "lower than 11.8.")
             logger.info(
-                "Using fp8_e5m2 data type to store kv cache. It reduces "
-                "the GPU memory footprint and boosts the performance. "
-                "But it may cause slight accuracy drop. "
-                "Currently we only support fp8 without scaling factors and "
-                "use e5m2 as a default format.")
+                "Using fp8 data type to store kv cache. It reduces the GPU "
+                "memory footprint and boosts the performance. "
+                "But it may cause slight accuracy drop without scaling "
+                "factors. FP8_E5M2 (without scaling) is only supported on "
+                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
+                "is instead supported for common inference criteria.")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
@@ -546,15 +566,7 @@ class ParallelConfig:
         placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        if is_neuron():
-            # For Neuron device support, here we assign TP=1 to avoid sharding
-            # within Aphrodite directly.
-            # Transformer-neuronx would take neuron_tp_degree attribute, and
-            # distribute the workload to multiple NeuronCores.
-            self.tensor_parallel_size = 1
-            self.neuron_tp_degree = tensor_parallel_size
-        else:
-            self.tensor_parallel_size = tensor_parallel_size
+        self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
@@ -563,8 +575,7 @@ class ParallelConfig:
         self.placement_group = placement_group
 
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
-        # Ray worker is not supported for Neuron backend.
-        if self.world_size > 1 and not is_neuron():
+        if self.world_size > 1:
             self.worker_use_ray = True
         self._verify_args()
 
@@ -587,15 +598,6 @@ class ParallelConfig:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
 
-        # FIXME: Fix the stability issues and re-enable the custom
-        # all-reduce kernel.
-        if not self.disable_custom_all_reduce and self.world_size > 1:
-            self.disable_custom_all_reduce = True
-            logger.info(
-                "Custom all-reduce kernels are temporarily disabled due to "
-                "stability issues. We will re-enable them once the issues are "
-                "resolved.")
-
 
 class SchedulerConfig:
     """Scheduler configuration.
@@ -607,6 +609,18 @@ class SchedulerConfig:
             iteration.
         max_model_len: Maximum length of a sequence (including prompt
             and generated text).
+        use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
+        num_lookahead_slots: The number of slots to allocate per sequence per
+            step, beyond the known token ids. This is used in speculative
+            decoding to store KV activations of tokens which may or may not be
+            accepted.
+        delay_factor: Apply a delay (of delay factor multiplied by previous
+            prompt latency) before scheduling the next prompt.
+        policy: Policy of sequence scheduling (`fcfs` or `reorder`).
+        reorder_window: Allowed reorder window size (in sec) for `reorder`
+            policy.
+        enable_chunked_prefill: If True, prefill requests can be chunked
+            based on the remaining max_num_batched_tokens.
     """
 
     def __init__(
@@ -614,19 +628,38 @@ class SchedulerConfig:
         max_num_batched_tokens: Optional[int],
         max_num_seqs: int,
         max_model_len: int,
+        use_v2_block_manager: bool = False,
+        num_lookahead_slots: int = 0,
+        delay_factor: float = 0.0,
+        policy: str = "fcfs",
+        reorder_window: float = 0.0,
+        enable_chunked_prefill: bool = False,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
-            # If max_model_len is too short, use 2048 as the default value for
-            # higher throughput.
-            self.max_num_batched_tokens = max(max_model_len, 2048)
+            if enable_chunked_prefill:
+                # For chunked prefill, choose the well-tuned batch size.
+                self.max_num_batched_tokens = 768
+            else:
+                # If max_model_len is too short, use 2048 as the default value
+                # for higher throughput.
+                self.max_num_batched_tokens = max(max_model_len, 2048)
+        if enable_chunked_prefill:
+            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self.use_v2_block_manager = use_v2_block_manager
+        self.num_lookahead_slots = num_lookahead_slots
+        self.delay_factor = delay_factor
+        self.policy = policy
+        self.reorder_window = reorder_window
+        self.chunked_prefill_enabled = enable_chunked_prefill
         self._verify_args()
 
     def _verify_args(self) -> None:
-        if self.max_num_batched_tokens < self.max_model_len:
+        if (self.max_num_batched_tokens < self.max_model_len
+                and not self.chunked_prefill_enabled):
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
@@ -639,6 +672,17 @@ class SchedulerConfig:
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                 "be greater than or equal to max_num_seqs "
                 f"({self.max_num_seqs}).")
+        if self.reorder_window < 0:
+            raise ValueError(f"reorder_window ({self.reorder_window}) must "
+                             "not be negative.")
+        if self.reorder_window != 0 and self.policy != 'reorder':
+            raise ValueError("reorder_window is only supported by the "
+                             f"'reorder' policy ({self.reorder_window}).")
+        if self.num_lookahead_slots < 0:
+            raise ValueError(
+                "num_lookahead_slots "
+                f"({self.num_lookahead_slots}) must be greater than or "
+                "equal to 0.")
 
 
 class DeviceConfig:
@@ -650,6 +694,8 @@ class DeviceConfig:
                 self.device_type = "cuda"
             elif is_neuron():
                 self.device_type = "neuron"
+            elif is_cpu():
+                self.device_type = "cpu"
             else:
                 raise RuntimeError("No supported device detected.")
         else:
@@ -663,9 +709,241 @@ class DeviceConfig:
             # Set device with device type
             self.device = torch.device(self.device_type)
 
+
+class SpeculativeConfig:
+    """Configuration for speculative decoding.
+    The configuration is currently specialized to draft-model speculative
+    decoding with top-1 proposals.
+    """
+
+    @staticmethod
+    def maybe_create_spec_config(
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        target_dtype: str,
+        speculative_model: Optional[str],
+        num_speculative_tokens: Optional[int],
+        speculative_max_model_len: Optional[int],
+        enable_chunked_prefill: bool,
+        use_v2_block_manager: bool,
+        ngram_prompt_lookup_max: Optional[int],
+        ngram_prompt_lookup_min: Optional[int],
+    ) -> Optional["SpeculativeConfig"]:
+        """Create a SpeculativeConfig if possible, else return None.
+        This function attempts to create a SpeculativeConfig object based on the
+        provided parameters. If the necessary conditions are met, it returns an
+        instance of SpeculativeConfig. Otherwise, it returns None.
+        Args:
+            target_model_config (ModelConfig): The configuration of the target
+                model.
+            target_parallel_config (ParallelConfig): The parallel configuration
+                for the target model.
+            target_dtype (str): The data type used for the target model.
+            speculative_model (Optional[str]): The name of the speculative
+                model, if provided.
+            num_speculative_tokens (Optional[int]): The number of speculative
+                tokens, if provided.
+            speculative_max_model_len (Optional[int]): The maximum model len of
+                the speculative model. Used when testing the ability to skip
+                speculation for some sequences.
+            enable_chunked_prefill (bool): Whether Aphrodite is configured to
+                use chunked prefill or not. Used for raising an error since it's
+                not yet compatible with spec decode.
+            use_v2_block_manager (bool): Whether Aphrodite is configured to
+                use the v2 block manager or not. Used for raising an error
+                since the v2 block manager is required with spec decode.
+            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
+                window, if provided.
+            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
+                window, if provided.
+        Returns:
+            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
+                the necessary conditions are met, else None.
+        """
+
+        if (speculative_model is None and num_speculative_tokens is None):
+            return None
+
+        if speculative_model is not None and num_speculative_tokens is None:
+            raise ValueError(
+                "Expected both speculative_model and "
+                "num_speculative_tokens to be provided, but found "
+                f"{speculative_model=} and {num_speculative_tokens=}.")
+        assert (speculative_model is not None
+                and num_speculative_tokens is not None)
+
+        if enable_chunked_prefill:
+            raise ValueError(
+                "Speculative decoding and chunked prefill are "
+                f"currently mutually exclusive ({enable_chunked_prefill=}).")
+
+        if not use_v2_block_manager:
+            raise ValueError(
+                "Speculative decoding requires usage of the V2 "
+                "block manager. Enable it with --use-v2-block-manager.")
+
+        # TODO: The user should be able to specify revision/quantization/max
+        # model len for the draft model. It is not currently supported.
+        draft_revision = None
+        draft_code_revision = None
+        draft_quantization = None
+
+        if speculative_model == "[ngram]":
+            assert (ngram_prompt_lookup_max is not None
+                    and ngram_prompt_lookup_max > 0)
+            if ngram_prompt_lookup_min is None:
+                ngram_prompt_lookup_min = 0
+            else:
+                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
+            draft_model_config = target_model_config
+            draft_parallel_config = target_parallel_config
+        else:
+            ngram_prompt_lookup_max = 0
+            ngram_prompt_lookup_min = 0
+            draft_model_config = ModelConfig(
+                model=speculative_model,
+                download_dir=target_model_config.download_dir,
+                load_format=target_model_config.load_format,
+                tokenizer=target_model_config.tokenizer,
+                tokenizer_mode=target_model_config.tokenizer_mode,
+                trust_remote_code=target_model_config.trust_remote_code,
+                dtype=target_model_config.dtype,
+                seed=target_model_config.seed,
+                revision=draft_revision,
+                code_revision=draft_code_revision,
+                tokenizer_revision=target_model_config.tokenizer_revision,
+                max_model_len=None,
+                quantization=draft_quantization,
+                enforce_eager=target_model_config.enforce_eager,
+                max_context_len_to_capture=target_model_config.
+                max_context_len_to_capture,
+                max_log_probs=target_model_config.max_log_probs,
+            )
+
+            draft_model_config.max_model_len = (
+                SpeculativeConfig._maybe_override_draft_max_model_len(
+                    speculative_max_model_len,
+                    draft_model_config.max_model_len,
+                    target_model_config.max_model_len,
+                ))
+
+            draft_parallel_config = (
+                SpeculativeConfig.create_draft_parallel_config(
+                    target_parallel_config))
+
+        return SpeculativeConfig(
+            draft_model_config,
+            draft_parallel_config,
+            num_speculative_tokens,
+            ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min,
+        )
+
+    @staticmethod
+    def _maybe_override_draft_max_model_len(
+        speculative_max_model_len: Optional[int],
+        draft_max_model_len: int,
+        target_max_model_len: int,
+    ) -> int:
+        """Determine the max sequence len for the draft model. This is usually
+        the draft_max_model_len, but may be the target_max_model_len if it is
+        less than the draft_max_model_len, or may be speculative_max_model_len
+        if it is specified.
+        This is necessary so that sequences do not exceed the capacity of the
+        draft model or the target model.
+        speculative_max_model_len is mainly used for testing that sequences can
+        skip speculation.
+        """
+
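+        # For example (values assumed): draft_max_model_len=4096,
+        # target_max_model_len=2048 and no override resolves to
+        # min(4096, 2048) = 2048 for the draft model.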
+        if speculative_max_model_len is not None:
+
+            if speculative_max_model_len > draft_max_model_len:
+                raise ValueError(f"{speculative_max_model_len=} cannot be "
+                                 f"larger than {draft_max_model_len=}")
+
+            if speculative_max_model_len > target_max_model_len:
+                raise ValueError(f"{speculative_max_model_len=} cannot be "
+                                 f"larger than {target_max_model_len=}")
+
+            return speculative_max_model_len
+
+        return min(
+            draft_max_model_len,
+            target_max_model_len,
+        )
+
+    @staticmethod
+    def create_draft_parallel_config(
+            target_parallel_config: ParallelConfig) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+        This is mostly a copy of the target parallel config. In the future the
+        draft worker can have a different parallel strategy, e.g. TP=1.
+        """
+        draft_parallel_config = ParallelConfig(
+            pipeline_parallel_size=target_parallel_config.
+            pipeline_parallel_size,
+            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
+            worker_use_ray=target_parallel_config.worker_use_ray,
+            max_parallel_loading_workers=target_parallel_config.
+            max_parallel_loading_workers,
+            disable_custom_all_reduce=target_parallel_config.
+            disable_custom_all_reduce,
+            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
+            ray_workers_use_nsight=target_parallel_config.
+            ray_workers_use_nsight,
+            placement_group=target_parallel_config.placement_group,
+        )
+
+        return draft_parallel_config
+
+    def __init__(
+        self,
+        draft_model_config: ModelConfig,
+        draft_parallel_config: ParallelConfig,
+        num_speculative_tokens: int,
+        ngram_prompt_lookup_max: int,
+        ngram_prompt_lookup_min: int,
+    ):
+        """Create a SpeculativeConfig object.
+        Args:
+            draft_model_config: ModelConfig for the draft model.
+            draft_parallel_config: ParallelConfig for the draft model.
+            num_speculative_tokens: The number of tokens to sample from the
+                draft model before scoring with the target model.
+        """
+        self.draft_model_config = draft_model_config
+        self.draft_parallel_config = draft_parallel_config
+        self.num_speculative_tokens = num_speculative_tokens
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.num_speculative_tokens <= 0:
+            raise ValueError("Expected num_speculative_tokens to be greater "
+                             f"than zero ({self.num_speculative_tokens}).")
+
+        if self.draft_model_config:
+            self.draft_model_config.verify_with_parallel_config(
+                self.draft_parallel_config)
+
     @property
-    def is_neuron(self):
-        return self.device_type == "neuron"
+    def num_lookahead_slots(self) -> int:
+        """The number of additional slots the scheduler should allocate per
+        step, in addition to the slots allocated for each known token.
+        This is equal to the number of speculative tokens, as each speculative
+        token must be scored.
+        """
+        return self.num_speculative_tokens
+
+    def __repr__(self) -> str:
+        if self.ngram_prompt_lookup_max > 0:
+            draft_model = "[ngram]"
+        else:
+            draft_model = self.draft_model_config.model
+        num_spec_tokens = self.num_speculative_tokens
+        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
 
 
 @dataclass
@@ -716,6 +994,45 @@ class LoRAConfig:
                 "LoRA is enabled.")
 
 
+@dataclass
+class VisionLanguageConfig:
+    """Configs the input data format and how models should run for
+    vision language models."""
+
+    class ImageInputType(enum.Enum):
+        """Image input type into the vision language model.
+        An image roughly goes through the following transformation:
+        Raw image --> pixel values --> image features --> image embeddings.
+        The difference between different image input types is where the
+        image encoder (pixel values --> image features) is run.
+        Different image input types also correspond to different tensor shapes.
+        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
+        IMAGE_FEATURES: (1, 576, 1024).
+        """
+        PIXEL_VALUES = enum.auto()
+        IMAGE_FEATURES = enum.auto()
+
+    image_input_type: ImageInputType
+    # The input id corresponding to image token.
+    image_token_id: int
+    # Used for running `run_prefill_max_token`.
+    # For models that support varying resolution, this corresponds to
+    # worst case scenario (biggest supported resolution).
+    image_input_shape: tuple
+    image_feature_size: int
+
+    @classmethod
+    def get_image_input_enum_type(
+            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
+        """Get the image input type from a string."""
+        try:
+            return cls.ImageInputType[value.upper()]
+        except KeyError as e:
+            raise ValueError(f"{value} is not a valid choice. "
+                             f"Expecting to choose from "
+                             f"{[x.name for x in cls.ImageInputType]}.") from e
+
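+# Illustrative instantiation for a Llava-style model (the token id below is
+# an assumption, not read from any checkpoint; the shapes follow the
+# docstring's PIXEL_VALUES example):
+#
+#     VisionLanguageConfig(
+#         image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+#         image_token_id=32000,
+#         image_input_shape=(1, 3, 336, 336),
+#         image_feature_size=576,
+#     )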
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
@@ -785,6 +1102,8 @@ def _get_and_verify_max_len(
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
     possible_keys = [
+        # Cohere: needs to prioritize this over "max_position_embeddings"
+        "model_max_length",
         # OPT
         "max_position_embeddings",
         # GPT-2
@@ -802,6 +1121,7 @@ def _get_and_verify_max_len(
         max_len_key = getattr(hf_config, key, None)
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
+            break
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
@@ -836,3 +1156,52 @@ def _get_and_verify_max_len(
             "Attempting to use RoPE scaling.")
         derived_max_model_len = max_model_len
     return int(max_model_len)
+
+
+@dataclass
+class DecodingConfig:
+    """Dataclass which contains the decoding strategy of the engine"""
+
+    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
+    guided_decoding_backend: str = 'outlines'
+
+    def __post_init__(self):
+        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        backend = self.guided_decoding_backend
+        if backend not in valid_guided_backends:
+            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
+                             f"must be one of {valid_guided_backends}")
+
+
+@dataclass(frozen=True)
+class EngineConfig:
+    """Dataclass which contains all engine-related configuration. This
+    simplifies passing around the distinct configurations in the codebase.
+    """
+
+    model_config: ModelConfig
+    cache_config: CacheConfig
+    parallel_config: ParallelConfig
+    scheduler_config: SchedulerConfig
+    device_config: DeviceConfig
+    lora_config: Optional[LoRAConfig]
+    vision_language_config: Optional[VisionLanguageConfig]
+    speculative_config: Optional[SpeculativeConfig]
+    decoding_config: Optional[DecodingConfig]
+
+    def __post_init__(self):
+        """Verify configs are valid & consistent with each other.
+        """
+        self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.lora_config:
+            self.lora_config.verify_with_model_config(self.model_config)
+            self.lora_config.verify_with_scheduler_config(
+                self.scheduler_config)
+
+    def to_dict(self):
+        """Return the configs as a dictionary, for use in **kwargs.
+        """
+        return dict(
+            (field.name, getattr(self, field.name)) for field in fields(self))

+ 14 - 0
aphrodite/common/logger.py

@@ -78,6 +78,15 @@ def _log_formatter(record: dict):
     return fmt
 
 
+_logged_messages = set()
+
+
+def log_once(level, message, *args, **kwargs):
+    if message not in _logged_messages:
+        _logged_messages.add(message)
+        logger.log(level, message, *args, **kwargs)
+
+
 # Uvicorn log handler
 # Uvicorn log portions inspired from https://github.com/encode/uvicorn/discussions/2027#discussioncomment-6432362
 class UvicornLoggingHandler(logging.Handler):
@@ -117,3 +126,8 @@ def setup_logger():
         format=_log_formatter,
         colorize=True,
     )
+
+    logger.log_once = log_once
+
+
+setup_logger()
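A quick illustrative use of the new `log_once` helper (the warning text below is made up); deduplication is keyed on the message string, so the line is emitted only once:

    import aphrodite.common.logger  # attaches log_once via setup_logger()
    from loguru import logger

    for _ in range(3):
        logger.log_once("WARNING",
                        "KV cache scaling factors not found; using defaults.")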

+ 31 - 20
aphrodite/common/logits_processor.py

@@ -6,38 +6,43 @@ from typing import Dict, List
 class LogitsProcessor(ABC):
 
     @abstractmethod
-    def __call__(self, logits: torch.Tensor,
-                 output_tokens: List[List[int]]) -> None:
+    def __call__(self, output_tokens: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        """Logits are edited in-place"""
+        pass
+
+    @abstractmethod
+    def batched(self, logits: torch.Tensor,
+                output_tokens: List[List[int]]) -> None:
         """Logits are edited in-place"""
         pass
 
 
 class BiasLogitsProcessor(LogitsProcessor):
-    """This is to enable logit_bias in the OpenAI server,
-    an additive bias on the original logit values.
+    """Apply an additive bias to specific token logits.
     Args:
       biases: Dict of bias values. Each key corresponds to the token id.
     """
 
     def __init__(self, biases: Dict[int, float]):
-        super().__init__()
+        assert biases
         self.biases = biases
-
-        if not biases:
-            return
-
         self.keys = torch.tensor(list(self.biases.keys()), dtype=torch.long)
         self.values = torch.tensor(list(self.biases.values()),
                                    dtype=torch.float)
 
-    def __call__(self, logits: torch.Tensor,
-                 output_tokens: List[List[int]]) -> None:
-        if not self.biases:
-            return
+    def __call__(self, output_tokens: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        values = self.values.to(logits.device)
+        keys = self.keys.to(logits.device)
+        logits[keys] += values
+        return logits
 
+    def batched(self, logits: torch.Tensor,
+                output_tokens: List[List[int]]) -> None:
         values = self.values.to(logits.device)
         keys = self.keys.to(logits.device)
-        logits[0, keys] += values
+        logits[:, keys] += values
 
 
 class BanEOSUntil(LogitsProcessor):
@@ -48,12 +53,18 @@ class BanEOSUntil(LogitsProcessor):
     parameters can be handled gracefully."""
 
     def __init__(self, min_tokens: int, eos_token_id: int):
-        super().__init__()
         self._min_tokens = min_tokens
         self._eos_token_id = eos_token_id
 
-    def __call__(self, logits: torch.Tensor,
-                 output_tokens: List[List[int]]) -> None:
-        for i in range(len(output_tokens)):
-            if len(output_tokens[i]) < self._min_tokens:
-                logits[i][self._eos_token_id] = -float("inf")
+    def __call__(self, output_tokens: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        if len(output_tokens) < self._min_tokens:
+            logits[self._eos_token_id] = -float("inf")
+        return logits
+
+    def batched(self, logits: torch.Tensor,
+                output_tokens: List[List[int]]) -> None:
+        terminate_mask = torch.tensor(
+            [len(toks) < self._min_tokens for toks in output_tokens],
+            device=logits.device)
+        logits[terminate_mask, self._eos_token_id] = -float("inf")
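A minimal sketch exercising both code paths of the reworked processors (toy vocab size and token ids, chosen only for illustration):

    import torch
    from aphrodite.common.logits_processor import BanEOSUntil, BiasLogitsProcessor

    eos_token_id, vocab_size = 2, 8
    bias = BiasLogitsProcessor({5: 2.0, 7: -1.0})
    ban_eos = BanEOSUntil(min_tokens=4, eos_token_id=eos_token_id)

    # Per-sequence path: 1-D logits plus that sequence's output token ids.
    logits = torch.zeros(vocab_size)
    logits = bias([], logits)
    logits = ban_eos([1, 3], logits)  # only 2 tokens so far -> EOS is banned
    assert logits[eos_token_id] == float("-inf")

    # Batched path: 2-D logits, one token list per row; edited in place.
    batched = torch.zeros(2, vocab_size)
    ban_eos.batched(batched, [[1, 3], [1, 3, 4, 6]])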

+ 17 - 11
aphrodite/common/outputs.py

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union
 import time
 
 from aphrodite.common.sequence import (
@@ -23,6 +23,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion
+            to stop, None if the completion finished for some other reason
+            including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """
 
@@ -34,6 +37,7 @@ class CompletionOutput:
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -42,6 +46,7 @@ class CompletionOutput:
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request
 
     def finished(self) -> bool:
@@ -53,7 +58,8 @@ class CompletionOutput:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")
 
 
 class RequestOutput:
@@ -110,16 +116,16 @@ class RequestOutput:
         # NOTE: We need omit logprobs here explicitly because the sequence
         # always has the logprobs of the sampled tokens even if the
         # logprobs are not requested.
-        include_logprobs = seq_group.sampling_params.logprobs
+        include_logprobs = seq_group.sampling_params.logprobs is not None
+        text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(
-                seqs.index(seq),
-                seq.output_text,
-                seq.get_output_token_ids(),
-                seq.get_cumulative_logprob(),
-                seq.output_logprobs if include_logprobs else None,
-                SequenceStatus.get_finished_reason(seq.status),
-            ) for seq in top_n_seqs
+            CompletionOutput(seqs.index(seq),
+                             seq.get_output_text_to_return(text_buffer_length),
+                             seq.get_output_token_ids(),
+                             seq.get_cumulative_logprob(),
+                             seq.output_logprobs if include_logprobs else None,
+                             SequenceStatus.get_finished_reason(seq.status),
+                             seq.stop_reason) for seq in top_n_seqs
         ]
 
         # Every sequence in the sequence group should have the same prompt.
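Downstream code can now tell why a completion stopped; a sketch of consuming the new field (obtaining `request_output` from the engine is omitted here):

    # request_output: a RequestOutput returned by the engine (setup not shown).
    for completion in request_output.outputs:
        if completion.finish_reason == "stop":
            # Matched stop string or stop token id; None if it was the EOS token.
            print("stopped on:", completion.stop_reason)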

+ 59 - 2
aphrodite/common/sampling_params.py

@@ -2,9 +2,10 @@
 import copy
 from enum import IntEnum
 from functools import cached_property
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
+from pydantic import conint
 
 _SAMPLING_EPS = 1e-5
 
@@ -106,6 +107,8 @@ class SamplingParams:
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
+        min_tokens: Minimum number of tokens to generate per output sequence
+            before EOS or stop tokens are generated.
         logprobs: Number of log probabilities to return per output token.
             Note that the implementation follows the OpenAI API: The return
             result includes the log probabilities on the `logprobs` most likely
@@ -113,6 +116,7 @@ class SamplingParams:
             log probability of the sampled token, so there  may be up to
             `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
+        detokenize: Whether to detokenize the output. Defaults to True.
         custom_token_bans: List of token IDs to ban from generating
         skip_special_tokens: Whether to skip special tokens in the output.
             defaults to true.
@@ -120,6 +124,9 @@ class SamplingParams:
             tokens in the output. Defaults to True.
         logits_processors: List of LogitsProcessors to change the probability
             of token prediction at runtime.
+        truncate_prompt_tokens: If set to an integer k, will use only the last
+            k tokens from the prompt (i.e. left-truncation). Defaults to None
+            (i.e. no truncation).
     """
 
     def __init__(
@@ -155,12 +162,15 @@ class SamplingParams:
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,
+        min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
+        detokenize: bool = True,
         custom_token_bans: Optional[List[int]] = None,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessorFunc]] = None,
+        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -197,13 +207,19 @@ class SamplingParams:
         self.stop_token_ids = stop_token_ids or []
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
+        self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
+        # NOTE: This parameter is only exposed at the engine level for now.
+        # It is not exposed in the OpenAI API server, as the OpenAI API does
+        # not support returning only a list of token IDs.
+        self.detokenize = detokenize
         self.custom_token_bans = custom_token_bans or []
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.logits_processors = logits_processors or []
         self.include_stop_str_in_output = include_stop_str_in_output
+        self.truncate_prompt_tokens = truncate_prompt_tokens
 
         self.default_values = {
             "n": 1,
@@ -236,14 +252,24 @@ class SamplingParams:
             "stop_token_ids": [],
             "ignore_eos": False,
             "max_tokens": 16,
+            "min_tokens": 0,
             "logprobs": None,
             "prompt_logprobs": None,
+            "detokenize": True,
             "custom_token_bans": [],
             "skip_special_tokens": True,
             "spaces_between_special_tokens": True,
-            "include_stop_str_in_output": False
+            "include_stop_str_in_output": False,
+            "truncate_prompt_tokens": None,
         }
 
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0
+
         self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
@@ -256,6 +282,8 @@ class SamplingParams:
                 self.min_p = 0.0
                 self.top_a = 0.0
                 self._verify_greedy_sampling()
+        # injected by the engine
+        self.eos_token_id = None
 
     def _verify_args(self) -> None:
         if self.n < 1:
@@ -325,12 +353,29 @@ class SamplingParams:
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
+        if self.min_tokens < 0:
+            raise ValueError(f"min_tokens must be greater than or equal to 0, "
+                             f"got {self.min_tokens}.")
+        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
+            raise ValueError(
+                f"min_tokens must be less than or equal to "
+                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError("prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
+        if (self.truncate_prompt_tokens is not None
+                and self.truncate_prompt_tokens < 1):
+            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
+                             f"got {self.truncate_prompt_tokens}")
+        if any(not stop_str for stop_str in self.stop):
+            raise ValueError("stop cannot contain an empty string.")
+        if self.stop and not self.detokenize:
+            raise ValueError(
+                "stop strings are only supported when detokenize is True. "
+                "Set detokenize=True to use stop.")
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
@@ -366,6 +411,18 @@ class SamplingParams:
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using greedy sampling.")
 
+    def update_from_generation_config(
+            self, generation_config: Dict[str, Any]) -> None:
+        """Update if there are non-default values from generation_config"""
+        # Update eos_token_id for generation
+        if eos_ids := generation_config.get("eos_token_id"):
+            # it can be either int or list of int
+            if isinstance(eos_ids, int):
+                eos_ids = [eos_ids]
+            original_stop_token_ids = set(self.stop_token_ids)
+            original_stop_token_ids.update(eos_ids)
+            self.stop_token_ids = list(original_stop_token_ids)
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.use_beam_search:

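A short sketch of the new sampling knobs introduced above (`min_tokens`, `truncate_prompt_tokens`, `update_from_generation_config`); the concrete values are illustrative only:

from aphrodite.common.sampling_params import SamplingParams

params = SamplingParams(
    max_tokens=128,
    min_tokens=16,                # no EOS/stop tokens for the first 16 tokens
    stop=["###"],                 # requires detokenize=True (the default)
    truncate_prompt_tokens=2048,  # keep only the last 2048 prompt tokens
)

# The engine can fold the model's generation_config EOS ids into the
# request's stop_token_ids before scheduling.
params.update_from_generation_config({"eos_token_id": [2, 32000]})
assert {2, 32000} <= set(params.stop_token_ids)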
+ 150 - 3
aphrodite/common/sequence.py

@@ -15,9 +15,14 @@ if TYPE_CHECKING:
 
 @dataclass
 class Logprob:
-    """Infos for supporting OpenAI compatible logprobs."""
-
+    """Infos for supporting OpenAI compatible logprobs and token ranks.
+    Attributes:
+        logprob: The logprob of chosen token
+        rank: The vocab rank of chosen token (>=1)
+        decoded_token: The decoded chosen token index
+    """
     logprob: float
+    rank: Optional[int] = None
     decoded_token: Optional[str] = None
 
 
@@ -63,6 +68,11 @@ class SequenceStatus(enum.Enum):
         return finish_reason
 
 
+class SequenceStage(enum.Enum):
+    PREFILL = enum.auto()
+    DECODE = enum.auto()
+
+
 @dataclass
 class RequestMetrics:
     """Metrics associated with a request.
@@ -108,6 +118,9 @@ class SequenceData:
         self.prompt_token_ids = prompt_token_ids
         self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
+        # The number of tokens that are computed (that run against the model).
+        self._num_computed_tokens = 0
+        self._stage: SequenceStage = SequenceStage.PREFILL
 
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
@@ -125,6 +138,34 @@ class SequenceData:
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids
 
+    def get_num_computed_tokens(self) -> int:
+        """Return the number of prefill tokens that are already computed."""
+        return self._num_computed_tokens
+
+    def update_num_computed_tokens(self, num_new_computed_tokens: int):
+        """Update number of tokens computed so far."""
+        self._num_computed_tokens += num_new_computed_tokens
+        assert self._num_computed_tokens <= self.get_len(), (
+            self._num_computed_tokens, self.get_len())
+        # If all tokens are computed, it means it is in decoding phase.
+        if self.get_num_uncomputed_tokens() == 0:
+            self._stage = SequenceStage.DECODE
+
+    def reset_state_for_recompute(self) -> None:
+        """Reset the number of computed tokens from this sequence. It is
+        supposed to be called when a sequence needs to be started from
+        the beginning again (e.g., sequence is preempted).
+        """
+        self._num_computed_tokens = 0
+        self._stage = SequenceStage.PREFILL
+
+    def get_num_uncomputed_tokens(self) -> int:
+        """Return the number of prefil tokens that are not computed."""
+        # we use `get_len()` which includes prompt_len + output_len instead
+        # of prompt_len here. This is because during recompute we need to
+        # prefill for both prompt and output.
+        return self.get_len() - self.get_num_computed_tokens()
+
     def get_last_token_id(self) -> int:
         if not self.output_token_ids:
             return self.prompt_token_ids[-1]
@@ -136,6 +177,10 @@ class SequenceData:
     def get_output_token_ids(self) -> int:
         return self.output_token_ids
 
+    @property
+    def stage(self) -> SequenceStage:
+        return self._stage
+
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt_token_ids={self.prompt_token_ids}, "
@@ -178,6 +223,7 @@ class Sequence:
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None
 
         # Used for incremental detokenization
         self.prefix_offset = 0
@@ -190,6 +236,12 @@ class Sequence:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    def get_output_text_to_return(self, buffer_length: int):
+        # We return the full output text if the sequence is finished.
+        truncate = buffer_length and not self.is_finished()
+        return self.output_text[:-buffer_length] if truncate else (
+            self.output_text)
+
     def hash_of_block(self, logical_idx: int) -> int:
         # Compute the number of tokens in the sequence
         # TODO: The current hashing function is O(L^2). We should optimize
@@ -201,6 +253,10 @@ class Sequence:
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
 
+    def reset_state_for_recompute(self):
+        """Reset the sequence states for recomputation."""
+        self.data.reset_state_for_recompute()
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
@@ -287,6 +343,22 @@ class Sequence:
         new_seq.seq_id = new_seq_id
         return new_seq
 
+    def get_num_new_tokens(self) -> int:
+        """Get the number of new tokens to be computed.
+        Args:
+            remainig_token_budget: The remaining token budgets.
+        Returns:
+            The new number of tokens to be computed. I.e., 1 for decode, prompt
+            size for prefill. If there's not enough remainig_token_budget, it
+            can return the chunked number of new tokens.
+        """
+        if self.data.stage == SequenceStage.DECODE:
+            return 1
+        return self.data.get_num_uncomputed_tokens()
+
+    def is_prefill(self) -> bool:
+        return self.data.stage == SequenceStage.PREFILL
+
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
@@ -301,6 +373,25 @@ class SequenceGroupState:
     generator: Optional = None
 
 
+class MultiModalData:
+    """Multi modal request.
+    
+    Args:
+        type: The data type.
+        data: The actual data.
+        The required shape and semantic meaning of it depends on the vision
+        language config of the hosted model. 
+        See `VisionLanguageConfig` in `config.py`.
+    """
+
+    class Type(enum.Enum):
+        IMAGE = enum.auto()
+
+    def __init__(self, type: Type, data: "torch.Tensor"):
+        self.type = type
+        self.data = data
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
 
@@ -310,6 +401,7 @@ class SequenceGroup:
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
         lora_request: LoRA request.
+        multi_modal_data: Multi modal data for the request.
     """
 
     def __init__(
@@ -319,6 +411,7 @@ class SequenceGroup:
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -333,6 +426,7 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
+        self.multi_modal_data = multi_modal_data
 
     @property
     def prompt(self) -> str:
@@ -405,7 +499,25 @@ class SequenceGroup:
     def get_finished_seqs(self) -> List[Sequence]:
         return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
 
+    def update_num_computed_tokens(self, num_new_computed_tokens: int):
+        """Update number of tokens computed so far."""
+        for seq in self.seqs_dict.values():
+            if not seq.is_finished():
+                seq.data.update_num_computed_tokens(num_new_computed_tokens)
+
+    def get_num_uncomputed_tokens(self) -> int:
+        num_uncomputed_tokens = 0
+        for seq in self.get_seqs():
+            if not seq.is_finished():
+                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
+        return num_uncomputed_tokens
+
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+        # Optimization. We don't need to call get_seqs if we don't need to
+        # filter by states.
+        if status is None:
+            return len(self.seqs_dict)
+
         return len(self.get_seqs(status))
 
     def num_unfinished_seqs(self) -> int:
@@ -432,6 +544,10 @@ class SequenceGroup:
     def is_finished(self) -> bool:
         return all(seq.is_finished() for seq in self.get_seqs())
 
+    def is_prefill(self) -> bool:
+        # Every sequence should be in the same stage.
+        return self.get_seqs()[0].is_prefill()
+
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
@@ -439,7 +555,7 @@ class SequenceGroup:
 
 
 class SequenceGroupMetadata:
-    """Metadata for a sequence group. Used to create `InputMetadata`.
+    """Metadata for a sequence group. Used to create `AttentionMetadata`.
 
     Args:
         request_id: The ID of the request.
@@ -448,8 +564,11 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        token_chunk_size: The number of tokens to be processed (per sequence).
+            None if chunking is not required.
         state: Internal state tied to this sequence group.
         lora_request: LoRA request.
+        multi_modal_data: Multi modal data for the request.
         persistent_data: The persistent data of the sequence group.
     """
 
@@ -461,9 +580,11 @@ class SequenceGroupMetadata:
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
         persistent_data: Dict[int, dict],
+        token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
         state: Optional[SequenceGroupState] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -474,11 +595,24 @@ class SequenceGroupMetadata:
         self.lora_request = lora_request
         self.computed_block_nums = computed_block_nums
         self.state = SequenceGroupState() if state is None else state
+        self.multi_modal_data = multi_modal_data
+        self._token_chunk_size = token_chunk_size
+
+        if self._token_chunk_size is None:
+            if is_prompt:
+                self._token_chunk_size = list(seq_data.values())[0].get_len()
+            else:
+                self._token_chunk_size = 1
 
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    @property
+    def token_chunk_size(self) -> int:
+        """Return the number of tokens to be processed (chunk size)."""
+        return self._token_chunk_size
+
 
 class SequenceOutput:
     """The model output associated with a sequence.
@@ -573,3 +707,16 @@ class SamplerOutput:
     def __eq__(self, other: object):
         return (isinstance(other, self.__class__)
                 and self.outputs == other.outputs)
+
+    def __repr__(self) -> str:
+        """Show the shape of a tensor instead of its values to reduce noise.
+        """
+        sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
+                                    else self.sampled_token_probs.shape)
+        sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
+                                  self.sampled_token_ids.shape)
+        return (
+            f"SamplerOutput(outputs={self.outputs}, "
+            f"sampled_token_probs={sampled_token_probs_repr}, "
+            f"sampled_token_ids={sampled_token_ids_repr}, "
+            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")

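A sketch of the chunked-prefill bookkeeping added to `SequenceData`: computed tokens accumulate across chunks until the sequence flips from PREFILL to DECODE. Constructing `SequenceData` from a bare prompt token list follows the constructor shown above; the token values are illustrative.

from aphrodite.common.sequence import SequenceData, SequenceStage

data = SequenceData([1, 2, 3, 4, 5, 6, 7, 8])  # 8 prompt tokens
assert data.stage == SequenceStage.PREFILL
assert data.get_num_uncomputed_tokens() == 8

data.update_num_computed_tokens(5)             # first prefill chunk
assert data.stage == SequenceStage.PREFILL     # 3 prompt tokens still pending
data.update_num_computed_tokens(3)             # rest of the prompt
assert data.stage == SequenceStage.DECODE      # now generating 1 token/step

data.reset_state_for_recompute()               # e.g. after preemption
assert data.stage == SequenceStage.PREFILL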
+ 14 - 9
aphrodite/common/test_utils.py

@@ -1,8 +1,10 @@
 import ray
 
-from aphrodite.common.config import ParallelConfig
 from aphrodite.common.utils import get_open_port
-from aphrodite.task_handler.worker import init_distributed_environment
+from aphrodite.distributed import (
+    ensure_model_parallel_initialized,
+    init_distributed_environment,
+)
 
 
 def init_test_distributed_environment(
@@ -10,13 +12,16 @@ def init_test_distributed_environment(
     tensor_parallel_size: int,
     rank: int,
     distributed_init_port: str,
+    local_rank: int = -1,
 ) -> None:
-    parallel_config = ParallelConfig(pipeline_parallel_size,
-                                     tensor_parallel_size,
-                                     worker_use_ray=True)
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
-    init_distributed_environment(parallel_config, rank,
-                                 distributed_init_method)
+    init_distributed_environment(
+        world_size=pipeline_parallel_size * tensor_parallel_size,
+        rank=rank,
+        distributed_init_method=distributed_init_method,
+        local_rank=local_rank)
+    ensure_model_parallel_initialized(tensor_parallel_size,
+                                      pipeline_parallel_size)
 
 
 def multi_process_tensor_parallel(
@@ -26,7 +31,6 @@ def multi_process_tensor_parallel(
     # Using ray helps debugging the error when it failed
     # as compared to multiprocessing.
     ray.init()
-
     distributed_init_port = get_open_port()
     refs = []
     for rank in range(tensor_parallel_size):
@@ -35,4 +39,5 @@ def multi_process_tensor_parallel(
                                distributed_init_port))
     ray.get(refs)
 
-    ray.shutdown()
+
+    ray.shutdown()

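A hedged sketch of how a test target is expected to plug into the reworked helpers above; the worker body, and the assumption that `multi_process_tensor_parallel(tp_size, test_target)` takes the remote test function as its second argument, are illustrative.

import ray
import torch

from aphrodite.common.test_utils import (init_test_distributed_environment,
                                         multi_process_tensor_parallel)
from aphrodite.distributed import tensor_model_parallel_all_reduce


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_worker(tp_size: int, rank: int, port: str):
    # pipeline_parallel_size=1, so world_size == tp_size
    init_test_distributed_environment(1, tp_size, rank, port)
    expected = torch.full((8, ), float(tp_size), device="cuda")
    out = tensor_model_parallel_all_reduce(torch.ones(8, device="cuda"))
    assert torch.allclose(out, expected)


multi_process_tensor_parallel(2, all_reduce_worker)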
+ 171 - 29
aphrodite/common/utils.py

@@ -1,25 +1,20 @@
+import asyncio
 import enum
+import gc
 import os
 import socket
 import subprocess
 import uuid
-import gc
+from collections import OrderedDict, defaultdict
+from functools import lru_cache, partial
 from platform import uname
-from typing import List, Tuple, Union, Generic
-from packaging.version import parse, Version
+from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
+                    Hashable, List, Optional, Tuple, TypeVar, Union)
 
 import psutil
 import torch
-import asyncio
-from functools import partial, lru_cache
-from typing import (
-    Awaitable,
-    Callable,
-    TypeVar,
-)
-from collections import OrderedDict
-from typing import Any, Hashable, Optional
 from loguru import logger
+from packaging.version import Version, parse
 
 T = TypeVar("T")
 
@@ -27,8 +22,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.half,
     "bfloat16": torch.bfloat16,
     "float": torch.float,
-    "fp8_e5m2": torch.uint8,
-    # "int8": torch.int8,
+    "fp8": torch.uint8,
 }
 
 
@@ -120,6 +114,15 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
+@lru_cache(maxsize=None)
+def is_cpu() -> bool:
+    from importlib.metadata import PackageNotFoundError, version
+    try:
+        return "cpu" in version("aphrodite-engine")
+    except PackageNotFoundError:
+        return False
+
+
 @lru_cache(maxsize=None)
 def is_neuron() -> bool:
     try:
@@ -175,6 +178,48 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
     return _async_wrapper
 
 
+def merge_async_iterators(
+        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
+    """Merge multiple asynchronous iterators into a single iterator.
+    This method handles the case where some iterators finish before others.
+    When it yields, it yields a tuple (i, item) where i is the index of the
+    iterator that yielded the item.
+    """
+    queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
+
+    finished = [False] * len(iterators)
+
+    async def producer(i: int, iterator: AsyncIterator[T]):
+        try:
+            async for item in iterator:
+                await queue.put((i, item))
+        except Exception as e:
+            await queue.put(e)
+        finished[i] = True
+
+    _tasks = [
+        asyncio.create_task(producer(i, iterator))
+        for i, iterator in enumerate(iterators)
+    ]
+
+    async def consumer():
+        try:
+            while not all(finished) or not queue.empty():
+                item = await queue.get()
+                if isinstance(item, Exception):
+                    raise item
+                yield item
+        except (Exception, asyncio.CancelledError) as e:
+            for task in _tasks:
+                # NOTE: Pass the error msg in cancel()
+                # when only Python 3.9+ is supported.
+                task.cancel()
+            raise e
+        await asyncio.gather(*_tasks)
+
+    return consumer()
+
+
 def get_ip() -> str:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -189,7 +234,9 @@ def get_ip() -> str:
 
 
 def get_distributed_init_method(ip: str, port: int) -> str:
-    return f"tcp://{ip}:{port}"
+    # Brackets are not permitted in ipv4 addresses,
+    # see https://github.com/python/cpython/issues/103848
+    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
 
 
 def get_open_port() -> int:
@@ -209,6 +256,16 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
 
 
+def chunk_list(lst, chunk_size):
+    """Yield successive chunk_size chunks from lst."""
+    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def cdiv(a: int, b: int) -> int:
+    """Ceiling division."""
+    return -(a // -b)
+
+
 @lru_cache(maxsize=None)
 def get_nvcc_cuda_version() -> Optional[Version]:
     cuda_home = os.environ.get('CUDA_HOME')
@@ -230,7 +287,7 @@ def get_nvcc_cuda_version() -> Optional[Version]:
     return nvcc_cuda_version
 
 
-def _generate_random_fp8_e5m2(
+def _generate_random_fp8(
     tensor: torch.tensor,
     low: float,
     high: float,
@@ -246,7 +303,7 @@ def _generate_random_fp8_e5m2(
     from aphrodite._C import cache_ops
     tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
     tensor_tmp.uniform_(low, high)
-    cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
+    cache_ops.convert_fp8(tensor_tmp, tensor)
     del tensor_tmp
 
 
@@ -275,7 +332,7 @@ def create_kv_caches_with_random(
                 raise ValueError(f"Invalid model dtype: {model_dtype}")
         elif cache_dtype in ["half", "bfloat16", "float"]:
             torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
-        elif cache_dtype == "fp8_e5m2":
+        elif cache_dtype == "fp8":
             torch_dtype = torch.uint8
         else:
             raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
@@ -292,12 +349,10 @@ def create_kv_caches_with_random(
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
                                 device=device)
-        if cache_dtype == 'fp8_e5m2':
-            _generate_random_fp8_e5m2(key_cache, -scale, scale)
-        # elif cache_dtype == 'int8':
-        #     torch.randint(-128, 127, key_cache.size(), out=key_cache)
-        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
             key_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(key_cache, -scale, scale)
         else:
             raise ValueError(
                 f"Does not support key cache of type {cache_dtype}")
@@ -309,12 +364,10 @@ def create_kv_caches_with_random(
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
                                   device=device)
-        if cache_dtype == 'fp8_e5m2':
-            _generate_random_fp8_e5m2(value_cache, -scale, scale)
-        # elif cache_dtype == 'int8':
-        #     torch.randint(-128, 127, value_cache.size(), out=value_cache)
-        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
             value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(value_cache, -scale, scale)
         else:
             raise ValueError(
                 f"Does not support value cache of type {cache_dtype}")
@@ -322,7 +375,29 @@ def create_kv_caches_with_random(
     return key_caches, value_caches
 
 
-class measure_cuda_memory:
+@lru_cache
+def print_warning_once(msg: str) -> None:
+    logger.warning(msg)
+
+
+@lru_cache(maxsize=None)
+def is_pin_memory_available() -> bool:
+
+    if in_wsl():
+        # Pinning memory in WSL is not supported.
+        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
+        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down the performance.")
+        return False
+    elif is_neuron():
+        print_warning_once("Pin memory is not supported on Neuron.")
+        return False
+    elif is_cpu():
+        return False
+    return True
+
+
+class CudaMemoryProfiler:
 
     def __init__(self, device=None):
         self.device = device
@@ -344,3 +419,70 @@ class measure_cuda_memory:
 
         # Force garbage collection
         gc.collect()
+
+
+def str_to_int_tuple(s: str) -> Tuple[int]:
+    """Convert a string to a tuple of integers."""
+    try:
+        return tuple(map(int, s.split(",")))
+    except ValueError as e:
+        raise ValueError(
+            "String must be a series of integers separated by commas "
+            f"(e.g., 1, 2, 3). Given input: {s}") from e
+
+
+def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
+    assert len(x) <= max_len
+    return x + [pad] * (max_len - len(x))
+
+
+def make_tensor_with_pad(
+    x: List[List[int]],
+    max_len: int,
+    pad: int,
+    dtype: torch.dtype,
+    device: Optional[Union[str, torch.device]],
+) -> torch.Tensor:
+    """Make a padded tensor of a 2D inputs.
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
+    return torch.tensor(padded_x, dtype=dtype, device=device)
+
+
+def async_tensor_h2d(
+    data: list,
+    dtype: torch.dtype,
+    target_device: Union[str, torch.device],
+    pin_memory: bool,
+) -> torch.Tensor:
+    """Asynchronously create a tensor and copy it from host to device."""
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
+    return t.to(device=target_device, non_blocking=True)
+
+
+def maybe_expand_dim(tensor: torch.Tensor,
+                     target_dims: int,
+                     size: int = 1) -> torch.Tensor:
+    """Expand the tensor to the target_dims."""
+    if tensor.ndim < target_dims:
+        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
+    return tensor
+
+
+def merge_dicts(dict1: Dict[Any, List[Any]],
+                dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
+    """Merge 2 dicts that have key -> List of items.
+    
+    When a key conflicts, the values in dict1 is prioritized.
+    """
+    merged_dict = defaultdict(list)
+
+    for key, value in dict1.items():
+        merged_dict[key].extend(value)
+
+    for key, value in dict2.items():
+        merged_dict[key].extend(value)
+
+    return dict(merged_dict)

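Quick, self-contained checks of the helpers added above; nothing here needs a GPU.

import torch

from aphrodite.common.utils import (cdiv, chunk_list, make_tensor_with_pad,
                                    merge_dicts)

assert cdiv(10, 3) == 4                         # ceiling division
assert chunk_list([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]
assert merge_dicts({"a": [1]}, {"a": [2], "b": [3]}) == {"a": [1, 2], "b": [3]}

# Pad ragged token-id lists into a single [batch, max_len] tensor.
t = make_tensor_with_pad([[1, 2, 3], [4]], max_len=3, pad=0,
                         dtype=torch.long, device="cpu")
assert t.tolist() == [[1, 2, 3], [4, 0, 0]]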
+ 3 - 0
aphrodite/distributed/__init__.py

@@ -0,0 +1,3 @@
+from .communication_op import *
+from .parallel_state import *
+from .utils import *

+ 17 - 9
aphrodite/modeling/megatron/communication_op.py → aphrodite/distributed/communication_op.py

@@ -1,17 +1,15 @@
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union
 
-from torch.distributed import ProcessGroup
 import torch
+from torch.distributed import ProcessGroup
 
-from aphrodite.modeling.megatron import cupy_utils
-from aphrodite.modeling.megatron.parallel_state import (
+from .parallel_state import (
+    get_tensor_model_parallel_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_group,
-    is_cupy_nccl_enabled_for_all_reduce,
+    is_pynccl_enabled_for_all_reduce,
 )
-from aphrodite.modeling.megatron.custom_all_reduce import custom_all_reduce
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -25,15 +23,18 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
+    from aphrodite.distributed.device_communicators import pynccl_utils
+    from aphrodite.distributed.device_communicators.custom_all_reduce import (
+        custom_all_reduce)
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_
     out = custom_all_reduce(input_)
     if out is not None:
         return out
-    if is_cupy_nccl_enabled_for_all_reduce():
+    if is_pynccl_enabled_for_all_reduce():
         # TODO: support multiple parallel groups.
-        cupy_utils.all_reduce(input_)
+        pynccl_utils.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
                                      group=get_tensor_model_parallel_group())
@@ -172,10 +173,17 @@ def broadcast_tensor_dict(
         torch.distributed.broadcast_object_list([metadata_list],
                                                 src=src,
                                                 group=group)
+        async_handles = []
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
-                torch.distributed.broadcast(tensor, src=src, group=group)
+                async_handles.append(
+                    torch.distributed.broadcast(tensor,
+                                                src=src,
+                                                group=group,
+                                                async_op=True))
+        for async_handle in async_handles:
+            async_handle.wait()
     else:
         recv_metadata_list = [None]
         torch.distributed.broadcast_object_list(recv_metadata_list,

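The core change to `broadcast_tensor_dict` is issuing every tensor broadcast with `async_op=True` and waiting once at the end, instead of one blocking call per tensor. A generic sketch of that pattern (the helper name here is made up for illustration):

import torch.distributed as dist


def broadcast_all_then_wait(tensors, src, group=None):
    """Kick off every broadcast before waiting on any of them."""
    handles = [
        dist.broadcast(t, src=src, group=group, async_op=True)
        for t in tensors
    ]
    for handle in handles:
        handle.wait()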
+ 0 - 0
aphrodite/modeling/megatron/__init__.py → aphrodite/distributed/device_communicators/__init__.py


+ 36 - 11
aphrodite/modeling/megatron/custom_all_reduce.py → aphrodite/distributed/device_communicators/custom_all_reduce.py

@@ -5,8 +5,7 @@ from loguru import logger
 import torch
 import torch.distributed as dist
 
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank)
+from aphrodite.distributed import gpu_p2p_access_check
 
 try:
     from aphrodite._C import custom_ar
@@ -22,6 +21,8 @@ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
 
 
 def init_custom_ar() -> None:
+    from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                       get_tensor_model_parallel_world_size)
     global _CA_HANDLE
     if _CA_HANDLE is not None:
         return
@@ -32,17 +33,38 @@ def init_custom_ar() -> None:
     if world_size not in _SUPPORTED_WORLD_SIZES:
         logger.warning(
             "Custom allreduce is disabled due to an unsupported world size: "
-            "%d. Supported world sizes: %s. To silence this warning, specify "
-            "disable_custom_all_reduce=True explicitly.", world_size,
+            "%d. Supported world sizes: %s. To silence this warning, specify"
+            " disable_custom_all_reduce=True explicitly.", world_size,
             str(_SUPPORTED_WORLD_SIZES))
         return
+    num_dev = torch.cuda.device_count()
+    # note: num dev can be larger than world_size if we're only using
+    # first few GPUs
+    if num_dev < world_size:
+        logger.warning(
+            "Cannot test GPU P2P because not all GPUs are visible to the "
+            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
+            " is set.")
+        return False
+    # test nvlink first, this will filter out most of the cases
+    # where custom allreduce is not supported
+    full_nvlink = _is_full_nvlink(rank, world_size)
+    if world_size > 2 and not full_nvlink:
+        logger.warning(
+            "Custom allreduce is disabled because it's not supported on more"
+            " than two PCIe-only GPUs. To silence this warning, specify"
+            " disable_custom_all_reduce=True explicitly.")
+        return
+    # test P2P capability
+    # this is expensive to compute at the first time
+    # then we cache the result
     if not _can_p2p(rank, world_size):
         logger.warning(
             "Custom allreduce is disabled because your platform lacks GPU P2P"
-            " capability. To silence this warning, specify "
-            "disable_custom_all_reduce=True explicitly.")
+            " capability or P2P test failed. To silence this warning, specify"
+            " disable_custom_all_reduce=True explicitly.")
         return
-    _CA_HANDLE = CustomAllreduce(rank, world_size)
+    _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
 
 
 def begin_capture() -> None:
@@ -134,7 +156,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
             continue
-        if not torch.cuda.can_device_access_peer(rank, i):
+        if not gpu_p2p_access_check(rank, i):
             return False
     return True
 
@@ -142,7 +164,11 @@ def _can_p2p(rank: int, world_size: int) -> bool:
 class CustomAllreduce:
 
     # max_size: max supported allreduce size
-    def __init__(self, rank, world_size, max_size=8192 * 1024) -> None:
+    def __init__(self,
+                 rank,
+                 world_size,
+                 full_nvlink,
+                 max_size=8192 * 1024) -> None:
         # buffers memory are owned by this Python class and passed to C++
         # meta data composes of two parts: meta data for synchronization
         # (256 bytes) and a temporary buffer for storing intermediate
@@ -164,11 +190,10 @@ class CustomAllreduce:
         self.max_size = max_size
         self.world_size = world_size
         handles, offsets = self._get_ipc_meta(self.meta)
-        self.full_nvlink = _is_full_nvlink(rank, world_size)
+        self.full_nvlink = full_nvlink
         self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                              handles, offsets, rank,
                                              self.full_nvlink)
-        self.fast_cond = self.full_nvlink or world_size <= 2
         self.register_buffer(self.buffer)
 
     def _get_ipc_meta(self, inp: torch.Tensor):

+ 269 - 0
aphrodite/distributed/device_communicators/pynccl.py

@@ -0,0 +1,269 @@
+# This file is a pure Python wrapper for the NCCL library.
+# The main purpose is to use NCCL combined with CUDA graph.
+# Before writing this script, we tried the following approach:
+# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
+#  often gets stuck when initializing the NCCL communicator.
+# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+#  contains many other potential cuda APIs, that are not allowed during
+#  capturing the CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related with nccl versions, and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this. It requires
+# recompilation of the code every time we want to switch between different
+# versions. This current implementation, with a **pure** Python wrapper, is
+# more flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+
+import ctypes
+import datetime
+import logging
+import os
+
+# ===================== import region =====================
+import torch
+import torch.distributed as dist
+from torch.distributed import ReduceOp
+
+logger = logging.getLogger(__name__)
+
+so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
+
+# manually load the nccl library
+if so_file:
+    logger.info(
+        f"Loading nccl from env variable APHRODITE_NCCL_SO_PATH={so_file}")
+else:
+    if torch.version.cuda is not None:
+        so_file = "libnccl.so.2"
+    elif torch.version.hip is not None:
+        so_file = "librccl.so.1"
+    else:
+        raise ValueError("NCCL only supports CUDA and ROCm backends.")
+    logger.debug(f"Loading nccl from library {so_file}")
+
+try:
+    nccl = ctypes.CDLL(so_file)
+except Exception as e:
+    logger.error(
+        f"Failed to load NCCL library from {so_file} ."
+        "It is expected if you are not running on NVIDIA/AMD GPUs."
+        "Otherwise please set the environment variable APHRODITE_NCCL_SO_PATH"
+        " to point to the correct nccl library path. You can install nccl"
+        " with `conda install nccl` or `pip install nvidia-nccl-cu12`")
+    raise e
+
+# === export types and functions from nccl to Python ===
+# for the original nccl definition, please check
+# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
+
+ncclResult_t = ctypes.c_int
+
+# equivalent to c declaration:
+# ncclResult_t  ncclGetVersion(int *version);
+_c_ncclGetVersion = nccl.ncclGetVersion
+_c_ncclGetVersion.restype = ctypes.c_int
+_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
+
+
+def ncclGetVersion() -> str:
+    version = ctypes.c_int()
+    result = _c_ncclGetVersion(ctypes.byref(version))
+    assert result == 0
+    # something like 21903 --> "2.19.3"
+    version_str = str(version.value)
+    major = version_str[0].lstrip("0")
+    minor = version_str[1:3].lstrip("0")
+    patch = version_str[3:].lstrip("0")
+    return f"{major}.{minor}.{patch}"
+
+
+class NcclUniqueId(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+# equivalent to c declaration:
+# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
+_c_ncclGetUniqueId = nccl.ncclGetUniqueId
+_c_ncclGetUniqueId.restype = ctypes.c_int
+_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
+
+
+def ncclGetUniqueId() -> NcclUniqueId:
+    unique_id = NcclUniqueId()
+    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
+    assert result == 0
+    return unique_id
+
+
+# equivalent to c declaration:
+# ncclResult_t  ncclCommInitRank(
+#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
+# note that ncclComm_t is a pointer type, so the first argument
+# is a pointer to a pointer
+_c_ncclCommInitRank = nccl.ncclCommInitRank
+_c_ncclCommInitRank.restype = ctypes.c_int
+_c_ncclCommInitRank.argtypes = [
+    ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
+]
+
+
+# enums
+class ncclDataType_t(ctypes.c_int):
+    ncclInt8 = 0
+    ncclChar = 0
+    ncclUint8 = 1
+    ncclInt32 = 2
+    ncclInt = 2
+    ncclUint32 = 3
+    ncclInt64 = 4
+    ncclUint64 = 5
+    ncclFloat16 = 6
+    ncclHalf = 6
+    ncclFloat32 = 7
+    ncclFloat = 7
+    ncclFloat64 = 8
+    ncclDouble = 8
+    ncclBfloat16 = 9
+    ncclNumTypes = 10
+
+    @classmethod
+    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
+        if dtype == torch.int8:
+            return cls.ncclInt8
+        if dtype == torch.uint8:
+            return cls.ncclUint8
+        if dtype == torch.int32:
+            return cls.ncclInt32
+        if dtype == torch.int64:
+            return cls.ncclInt64
+        if dtype == torch.float16:
+            return cls.ncclFloat16
+        if dtype == torch.float32:
+            return cls.ncclFloat32
+        if dtype == torch.float64:
+            return cls.ncclFloat64
+        if dtype == torch.bfloat16:
+            return cls.ncclBfloat16
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+class ncclRedOp_t(ctypes.c_int):
+    ncclSum = 0
+    ncclProd = 1
+    ncclMax = 2
+    ncclMin = 3
+    ncclAvg = 4
+    ncclNumOps = 5
+
+    @classmethod
+    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
+        if op == ReduceOp.SUM:
+            return cls.ncclSum
+        if op == ReduceOp.PRODUCT:
+            return cls.ncclProd
+        if op == ReduceOp.MAX:
+            return cls.ncclMax
+        if op == ReduceOp.MIN:
+            return cls.ncclMin
+        if op == ReduceOp.AVG:
+            return cls.ncclAvg
+        raise ValueError(f"Unsupported op: {op}")
+
+
+# equivalent to c declaration:
+# ncclResult_t  ncclAllReduce(
+#   const void* sendbuff, void* recvbuff, size_t count,
+#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+#   cudaStream_t stream);
+# note that cudaStream_t is a pointer type, so the last argument is a pointer
+_c_ncclAllReduce = nccl.ncclAllReduce
+_c_ncclAllReduce.restype = ctypes.c_int
+_c_ncclAllReduce.argtypes = [
+    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
+    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
+]
+
+# equivalent to c declaration:
+# ncclResult_t  ncclCommDestroy(ncclComm_t comm);
+_c_ncclCommDestroy = nccl.ncclCommDestroy
+_c_ncclCommDestroy.restype = ctypes.c_int
+_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
+
+
+class NCCLCommunicator:
+
+    def __init__(
+        self,
+        backend=None,
+        init_method=None,
+        timeout=datetime.timedelta(seconds=10),
+        world_size: int = -1,
+        rank: int = -1,
+        store=None,
+        group_name: str = "",
+        pg_options=None,
+        local_rank: int = -1,
+    ):
+        if not dist.is_initialized():
+            backend = backend or "nccl"
+            assert backend == 'nccl', (
+                "only use nccl backend for starting the NCCL communicator")
+            dist.init_process_group(backend=backend,
+                                    init_method=init_method,
+                                    timeout=timeout,
+                                    world_size=world_size,
+                                    rank=rank,
+                                    store=store,
+                                    group_name=group_name,
+                                    pg_options=pg_options)
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        if local_rank == -1:
+            local_rank = self.rank
+        self.local_rank = local_rank
+        # don't use these args, as they can be -1
+        # use `self.rank`, `self.local_rank` and `self.world_size` instead
+        del world_size, rank, local_rank
+        torch.cuda.set_device(self.local_rank)
+        if self.rank == 0:
+            self.unique_id = ncclGetUniqueId()
+        else:
+            self.unique_id = NcclUniqueId()
+        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
+            self.local_rank)
+        dist.broadcast(tensor, src=0)
+        byte_list = tensor.cpu().tolist()
+        for i, byte in enumerate(byte_list):
+            self.unique_id.internal[i] = byte
+        self.comm = ctypes.c_void_p()
+        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+                                     self.unique_id, self.rank)
+        assert result == 0
+        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+
+    def all_reduce(self,
+                   tensor: torch.Tensor,
+                   op: ReduceOp = ReduceOp.SUM,
+                   stream=None):
+        if stream is None:
+            stream = self.stream
+        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+                                  ctypes.c_void_p(tensor.data_ptr()),
+                                  tensor.numel(),
+                                  ncclDataType_t.from_torch(tensor.dtype),
+                                  ncclRedOp_t.from_torch(op), self.comm,
+                                  ctypes.c_void_p(stream.cuda_stream))
+        assert result == 0
+
+    def __del__(self):
+        # `dist` module might have been already destroyed
+        if hasattr(dist, 'destroy_process_group'):
+            dist.destroy_process_group()
+        # function might have been already destroyed
+        if _c_ncclCommDestroy is not None:
+            _c_ncclCommDestroy(self.comm)

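A standalone sketch of the ctypes technique the wrapper above relies on: load the NCCL shared library and query its version without going through torch.distributed. `libnccl.so.2` is the same default the wrapper uses, so this only runs where that library is installed.

import ctypes

nccl = ctypes.CDLL("libnccl.so.2")
nccl.ncclGetVersion.restype = ctypes.c_int
nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

version = ctypes.c_int()
assert nccl.ncclGetVersion(ctypes.byref(version)) == 0
print("raw NCCL version code:", version.value)  # e.g. 21903 for 2.19.3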
+ 68 - 0
aphrodite/distributed/device_communicators/pynccl_utils.py

@@ -0,0 +1,68 @@
+import contextlib
+from typing import Optional
+
+import torch
+from loguru import logger
+from torch.distributed import ReduceOp
+
+try:
+    from aphrodite.distributed.device_communicators.pynccl import (
+        NCCLCommunicator,
+        ncclGetVersion,
+    )
+except Exception as e:
+    # in non-NVIDIA environments, we can't import the nccl module
+    # e.g. when running on machines with AMD GPUs
+    logger.info(f"Failed to import NCCL library: {e}")
+    logger.info("It is expected if you are not running on NVIDIA GPUs.")
+    pass
+
+comm: Optional["NCCLCommunicator"] = None
+
+
+def is_initialized() -> bool:
+    """Returns whether the NCCL backend is initialized."""
+    return comm is not None
+
+
+@contextlib.contextmanager
+def set_pynccl_stream(stream: torch.cuda.Stream):
+    """Set the cuda stream for communication"""
+    try:
+        comm.stream = stream
+        yield
+    finally:
+        pass
+
+
+def init_process_group(world_size: int,
+                       rank: int,
+                       init_method: str,
+                       local_rank: int = -1) -> None:
+    assert not is_initialized()
+    global comm
+    logger.info(f"Aphrodite is using nccl=={ncclGetVersion()}")
+    comm = NCCLCommunicator(init_method=init_method,
+                            world_size=world_size,
+                            local_rank=local_rank,
+                            rank=rank)
+
+
+def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
+    """All-reduces the input tensor across the process group."""
+    assert input_.is_cuda, f"{input_} should be a cuda tensor"
+    comm.all_reduce(input_, op)
+
+
+def destroy_process_group() -> None:
+    global comm
+    comm = None
+
+
+def get_world_size() -> int:
+    """Returns the world size."""
+    return comm.world_size
+
+
+def get_nccl_backend():
+    return comm

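A hedged sketch of how a worker might drive `pynccl_utils` directly; the rank wiring through the RANK environment variable and the rendezvous address are illustrative, and it needs one NVIDIA GPU per rank.

import os

import torch
from torch.distributed import ReduceOp

from aphrodite.distributed.device_communicators import pynccl_utils

rank = int(os.environ.get("RANK", "0"))
pynccl_utils.init_process_group(world_size=2, rank=rank,
                                init_method="tcp://127.0.0.1:29500")

x = torch.ones(16, device="cuda")
with pynccl_utils.set_pynccl_stream(torch.cuda.current_stream()):
    pynccl_utils.all_reduce(x, op=ReduceOp.SUM)  # in-place across ranks
# With world_size == 2 every element of x is now 2.0.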
+ 91 - 22
aphrodite/modeling/megatron/parallel_state.py → aphrodite/distributed/parallel_state.py

@@ -5,24 +5,80 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+from typing import Optional
 
 import torch
-
-from aphrodite.modeling.megatron import cupy_utils
+from loguru import logger
 
 # Tensor model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
+# when people blindly call `torch.distributed.all_reduce` etc,
+# it will use this group. It is initialized with the `backend`
+# parameter of `init_distributed_environment` below.
+# Essentially, this is `torch.distributed.group.WORLD`.
+# We leave a line here to note that this is device-specific.
+# Note that this variable is not safe to use, because when users
+# call `init_distributed_environment` first, and then destroy
+# the process group themselves, this variable will keep a reference to the
+# destroyed process group, which is not useful.
+_DEVICE_WORLD_GROUP = None
+
+# during `init_distributed_environment`, we will also initialize a
+# group with `gloo` backend, to allow direct coordination between
+# processes through the CPU.
+_CPU_WORLD_GROUP = None
+
+# In summary, after calling `init_distributed_environment`, we will
+# always have two groups: one for device-specific (and is the default)
+# and one for CPU. All processes will be part of both groups.
+
 # A list of global ranks for each pipeline group to ease calculation of the
 # source rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None
 
+_LOCAL_RANK = -1
+
+
+def get_local_rank():
+    global _LOCAL_RANK
+    return _LOCAL_RANK
+
+
+def init_distributed_environment(
+    world_size: int = -1,
+    rank: int = -1,
+    distributed_init_method: str = "env://",
+    local_rank: int = -1,
+    backend: str = "nccl",
+):
+    logger.debug(f"{world_size=} {rank=} {local_rank=} "
+                 f"{distributed_init_method=} {backend=}")
+    if not torch.distributed.is_initialized():
+        assert distributed_init_method is not None, (
+            "distributed_init_method must be provided when initializing "
+            "distributed environment")
+        # this backend is used for WORLD
+        torch.distributed.init_process_group(
+            backend=backend,
+            init_method=distributed_init_method,
+            world_size=world_size,
+            rank=rank)
+        global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
+        _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
+        ranks = list(range(torch.distributed.get_world_size()))
+        _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
+                                                       backend="gloo")
+        global _LOCAL_RANK
+        _LOCAL_RANK = local_rank
+
 
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -49,6 +105,8 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
 
     if (world_size !=
             tensor_model_parallel_size * pipeline_model_parallel_size):
@@ -70,7 +128,7 @@ def initialize_model_parallel(
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
             _TENSOR_MODEL_PARALLEL_GROUP = group
 
@@ -81,7 +139,7 @@ def initialize_model_parallel(
         "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks)
+        group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
             _PIPELINE_GLOBAL_RANKS = ranks
@@ -90,14 +148,17 @@ def initialize_model_parallel(
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
     or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
     values if the model parallel groups are initialized.
     """
+    # get the backend of _DEVICE_WORLD_GROUP
+    backend = backend or torch.distributed.get_backend()
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size)
+                                  pipeline_model_parallel_size, backend)
         return
 
     assert (
@@ -118,6 +179,12 @@ def model_parallel_is_initialized():
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
+def get_cpu_world_group():
+    """Get the CPU world group."""
+    assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
+    return _CPU_WORLD_GROUP
+
+
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
     assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
@@ -210,36 +277,38 @@ def destroy_model_parallel():
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
     _PIPELINE_GLOBAL_RANKS = None
-    # Destroy the cupy states if any.
-    cupy_utils.destroy_process_group()
+    from aphrodite.distributed.device_communicators import pynccl_utils
+    # Destroy the pynccl states if any.
+    pynccl_utils.destroy_process_group()
 
 
-# Whether to use cupy for nccl all reduce.
-# We use cupy for all reduce when using CUDA graph, because torch.distributed
+# Whether to use pynccl for nccl all reduce.
+# We use pynccl for all reduce when using CUDA graph, because torch.distributed
 # is not well supported by CUDA graph.
-_ENABLE_CUPY_FOR_ALL_REDUCE = False
+_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
 
 
 @contextlib.contextmanager
-def with_cupy_nccl_for_all_reduce():
-    """use CuPy nccl instead of torch.distributed for all reduce"""
+def with_pynccl_for_all_reduce():
+    """use Pynccl instead of torch.distributed for all reduce"""
+    from aphrodite.distributed.device_communicators import pynccl_utils
     tp_size = get_tensor_model_parallel_world_size()
     if tp_size == 1:
         # No-op.
-        # NOTE: We don't initialize CuPy when tp_size is 1.
+        # NOTE: We don't initialize Pynccl when tp_size is 1.
         yield
     else:
-        global _ENABLE_CUPY_FOR_ALL_REDUCE
-        old = _ENABLE_CUPY_FOR_ALL_REDUCE
-        _ENABLE_CUPY_FOR_ALL_REDUCE = True
+        global _ENABLE_PYNCCL_FOR_ALL_REDUCE
+        old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
+        _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
 
         stream = torch.cuda.current_stream()
-        with cupy_utils.set_cupy_stream(stream):
+        with pynccl_utils.set_pynccl_stream(stream):
             yield
-        _ENABLE_CUPY_FOR_ALL_REDUCE = old
+        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
 
 
-def is_cupy_nccl_enabled_for_all_reduce():
-    """check if CuPy nccl is enabled for all reduce"""
-    global _ENABLE_CUPY_FOR_ALL_REDUCE
-    return _ENABLE_CUPY_FOR_ALL_REDUCE
+def is_pynccl_enabled_for_all_reduce():
+    """check if Pynccl is enabled for all reduce"""
+    global _ENABLE_PYNCCL_FOR_ALL_REDUCE
+    return _ENABLE_PYNCCL_FOR_ALL_REDUCE
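
The switch from CuPy to the in-house pynccl wrapper above is consumed by the tensor-parallel all-reduce path during CUDA graph capture. A minimal sketch of such a call site follows; the `pynccl_utils.all_reduce(tensor)` call is an assumption about the wrapper's interface, while the flag accessor and context manager come from the diff above.

```python
# Hedged sketch of a call site for the pynccl toggle above. The
# pynccl_utils.all_reduce(tensor) signature is an assumption; the flag
# accessor and context manager are taken from parallel_state.py.
import torch
import torch.distributed as dist

from aphrodite.distributed.device_communicators import pynccl_utils
from aphrodite.distributed.parallel_state import (
    get_tensor_model_parallel_group, is_pynccl_enabled_for_all_reduce,
    with_pynccl_for_all_reduce)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce across tensor-parallel ranks, in place."""
    if is_pynccl_enabled_for_all_reduce():
        # Inside CUDA graph capture: torch.distributed is not graph-safe,
        # so route through the NCCL wrapper (assumed interface).
        pynccl_utils.all_reduce(input_)
    else:
        dist.all_reduce(input_, group=get_tensor_model_parallel_group())
    return input_


# During graph capture, the model runner would wrap the forward pass:
#     with with_pynccl_for_all_reduce():
#         graph.capture(...)
```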

+ 131 - 0
aphrodite/distributed/utils.py

@@ -0,0 +1,131 @@
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+import json
+import os
+from typing import Dict, Optional, Sequence
+
+import torch
+import torch.distributed as dist
+from loguru import logger
+
+from .parallel_state import get_cpu_world_group, get_local_rank
+
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(
+        numerator, denominator)
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> Sequence[torch.Tensor]:
+    """ Split a tensor along its last dimension.
+
+        Arguments:
+            tensor: input tensor.
+            num_partitions: number of partitions to split the tensor into.
+            contiguous_split_chunks: If True, make each chunk contiguous
+                                     in memory.
+
+        Returns:
+            A list of Tensors
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # NOTE: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+# code partly borrowed from
+# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
+# License: MIT
+def _can_actually_p2p(idx_a, idx_b):
+    dev_i = f"cuda:{idx_a}"
+    dev_j = f"cuda:{idx_b}"
+    a = torch.randn(5, device=dev_i) + 123.0
+    b = a.to(dev_j)
+    c = b.to(dev_i)
+    return torch.all(a == c).cpu().item()
+
+
+# Why do we need this cache?
+# 1. We can have runtime checks for P2P access, where every process checks
+#  P2P access to all other GPUs. Unfortunately, the test may create many
+#  (world_size * world_size) CUDA contexts and reduce the memory available
+#  for the model.
+# 2. Alternatively, we can have a P2P map that is generated by the master
+#  process and broadcast to all other processes. This still requires
+#  #world_size CUDA contexts, belonging to the master process, on each GPU.
+# 3. We can have a cache file that records the P2P access status. The first
+#  time the master process checks the P2P access, it generates the cache
+#  file, at the cost of #world_size CUDA contexts. Later on, all processes
+#  can read the cache file to check the P2P access status without creating
+#  any additional CUDA contexts.
+# Note that the cache file is suffixed by CUDA_VISIBLE_DEVICES, so that we
+#  can have different cache files for different CUDA_VISIBLE_DEVICES settings,
+#  e.g. those used by different aphrodite engines. The device id in the cache
+#  file is a **local** device id, i.e. from 0 to num_dev-1, where num_dev is
+#  the number of visible devices in the aphrodite engine.
+_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+
+
+def gpu_p2p_access_check(i: int, j: int) -> bool:
+    """Check if GPU i can access GPU j."""
+
+    # if the cache variable is already calculated,
+    # read from the cache instead of checking it again
+    global _gpu_p2p_access_cache
+    if _gpu_p2p_access_cache is not None:
+        return _gpu_p2p_access_cache[f"{i}->{j}"]
+
+    is_distributed = dist.is_initialized()
+
+    num_dev = torch.cuda.device_count()
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    if cuda_visible_devices is None:
+        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+    path = os.path.expanduser(
+        f"~/.config/aphrodite/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
+    )
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    if (not is_distributed or get_local_rank() == 0) \
+        and (not os.path.exists(path)):
+        # only the local master process (with local_rank == 0) can
+        #  enter this block to calculate the cache
+        logger.info(f"generating GPU P2P access cache in {path}")
+        cache = {}
+        for _i in range(num_dev):
+            for _j in range(num_dev):
+                # on some platforms, P2P support might be buggy and we need
+                # additional checks.
+                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
+                    _i, _j) and _can_actually_p2p(_i, _j)
+        with open(path, "w") as f:
+            json.dump(cache, f, indent=4)
+    if is_distributed:
+        cpu_world_group = get_cpu_world_group()
+        dist.barrier(cpu_world_group)
+    logger.info(f"reading GPU P2P access cache from {path}")
+    with open(path, "r") as f:
+        cache = json.load(f)
+    _gpu_p2p_access_cache = cache
+    return _gpu_p2p_access_cache[f"{i}->{j}"]
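
As a quick check on the tensor-splitting helper added in this file, here is a standalone usage sketch; the only assumption beyond the diff is the `aphrodite.distributed.utils` import path.

```python
# Usage sketch for split_tensor_along_last_dim as defined above.
import torch

from aphrodite.distributed.utils import split_tensor_along_last_dim

hidden = torch.randn(2, 8)  # e.g. [num_tokens, hidden_size]
shards = split_tensor_along_last_dim(hidden, num_partitions=4)

assert len(shards) == 4 and all(s.shape == (2, 2) for s in shards)

# torch.split returns views; contiguous_split_chunks=True copies each shard
# into its own memory, which matters before handing shards to other ranks.
contig = split_tensor_along_last_dim(hidden, 4, contiguous_split_chunks=True)
assert all(s.is_contiguous() for s in contig)
```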

+ 31 - 0
aphrodite/endpoints/cli.py

@@ -0,0 +1,31 @@
+import argparse
+
+from aphrodite.endpoints.openai.api_server import run_server
+from aphrodite.endpoints.openai.args import make_arg_parser
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Aphrodite CLI")
+    subparsers = parser.add_subparsers()
+
+    serve_parser = subparsers.add_parser(
+        "run",
+        help="Start the Aphrodite OpenAI Compatible API server",
+        usage="aphrodite run <model_tag> [options]")
+    make_arg_parser(serve_parser)
+    # Override the `--model` optional argument, make it positional.
+    serve_parser.add_argument("model",
+                              type=str,
+                              help="The model tag or path to"
+                              " run.")
+    serve_parser.set_defaults(func=run_server)
+
+    args = parser.parse_args()
+    if hasattr(args, "func"):
+        args.func(args)
+    else:
+        parser.print_help()
+
+
+if __name__ == "__main__":
+    main()
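
The new `aphrodite run <model>` entrypoint relies on argparse's subparser dispatch: each subcommand registers its handler with `set_defaults(func=...)`, and `main()` calls whichever handler the parsed namespace carries. A self-contained sketch of that pattern (with placeholder names, not the real server) is below.

```python
# Standalone sketch of the set_defaults(func=...) dispatch pattern used by
# the CLI above; run_demo_server is a placeholder, not the real entrypoint.
import argparse


def run_demo_server(args: argparse.Namespace) -> None:
    print(f"would serve model: {args.model} on port {args.port}")


def main() -> None:
    parser = argparse.ArgumentParser(description="demo CLI")
    subparsers = parser.add_subparsers()

    run_parser = subparsers.add_parser("run", help="start the demo server")
    run_parser.add_argument("model", type=str, help="model tag or path")
    run_parser.add_argument("--port", type=int, default=2242)
    run_parser.set_defaults(func=run_demo_server)

    args = parser.parse_args()
    if hasattr(args, "func"):
        args.func(args)      # dispatch to the selected subcommand
    else:
        parser.print_help()  # no subcommand given


if __name__ == "__main__":
    main()
```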

File diff suppressed because it is too large
+ 100 - 33
aphrodite/endpoints/kobold/klite.embd


+ 21 - 5
aphrodite/endpoints/llm.py

@@ -1,5 +1,6 @@
 from typing import List, Optional, Union
 
+import torch
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
@@ -8,6 +9,7 @@ from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import MultiModalData
 from aphrodite.common.utils import Counter
 
 
@@ -122,6 +124,7 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
 
@@ -137,6 +140,7 @@ class LLM:
                 use the tokenizer to convert the prompts to token IDs.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi-modal data to use for generation, if any.
 
         Returns:
             A list of `RequestOutput` objects containing the generated
@@ -157,6 +161,9 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()
 
+        if multi_modal_data:
+            multi_modal_data.data = multi_modal_data.data.to(torch.float16)
+
         # Add requests to the engine.
         num_requests = len(prompts) if prompts is not None else len(
             prompt_token_ids)
@@ -164,10 +171,17 @@ class LLM:
             prompt = prompts[i] if prompts is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[
                 i]
-            self._add_request(prompt,
-                              sampling_params,
-                              token_ids,
-                              lora_request=lora_request)
+            self._add_request(
+                prompt,
+                sampling_params,
+                token_ids,
+                lora_request=lora_request,
+                # Get ith image while maintaining the batch dim.
+                multi_modal_data=MultiModalData(
+                    type=multi_modal_data.type,
+                    data=multi_modal_data.data[i].unsqueeze(0))
+                if multi_modal_data else None,
+            )
         return self._run_engine(use_tqdm)
 
     def _add_request(
@@ -176,13 +190,15 @@ class LLM:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(request_id,
                                     prompt,
                                     sampling_params,
                                     prompt_token_ids,
-                                    lora_request=lora_request)
+                                    lora_request=lora_request,
+                                    multi_modal_data=multi_modal_data)
 
     def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
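
For the new `multi_modal_data` plumbing above, a hedged usage sketch follows. The top-level imports, the `MultiModalData.Type.IMAGE` enum value, the LLaVA model tag, the prompt format, and the pixel-tensor layout are assumptions; the diff only shows that `data` is a batched tensor that `generate()` casts to float16 and slices per prompt.

```python
# Hedged sketch of passing image data through LLM.generate(). The Type.IMAGE
# enum, model tag, prompt format and tensor layout are assumptions; only the
# multi_modal_data argument itself comes from the diff above.
import torch

from aphrodite import LLM, SamplingParams
from aphrodite.common.sequence import MultiModalData

# NOTE: extra vision-language engine arguments may also be required; omitted.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # assumed vision-capable model
pixels = torch.zeros(1, 3, 336, 336)         # [batch, C, H, W], assumed layout

outputs = llm.generate(
    prompts=["<image>\nWhat is shown in this picture?"],
    sampling_params=SamplingParams(max_tokens=64),
    multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                    data=pixels),
)
print(outputs[0].outputs[0].text)
```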

+ 165 - 257
aphrodite/endpoints/openai/api_server.py

@@ -1,47 +1,47 @@
-import argparse
 import asyncio
-import json
-from contextlib import asynccontextmanager
-import os
 import importlib
 import inspect
-from typing import List, Tuple, AsyncGenerator, Optional
+import json
+import os
+from contextlib import asynccontextmanager
+from http import HTTPStatus
+from typing import AsyncGenerator, List, Optional, Tuple
 
-from prometheus_client import make_asgi_app
 import fastapi
 import uvicorn
-from http import HTTPStatus
-from fastapi import Request, APIRouter, Header
+from fastapi import APIRouter, Header, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import (JSONResponse, StreamingResponse, Response,
-                               HTMLResponse)
+from fastapi.responses import (HTMLResponse, JSONResponse, Response,
+                               StreamingResponse)
 from loguru import logger
+from prometheus_client import make_asgi_app
 
 import aphrodite
-from aphrodite.engine.args_tools import AsyncEngineArgs
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
-from aphrodite.endpoints.openai.protocol import (CompletionRequest,
-                                                 ChatCompletionRequest,
-                                                 ErrorResponse, Prompt,
-                                                 EmbeddingsResponse,
-                                                 EmbeddingsRequest)
+import aphrodite.endpoints.openai.embeddings as OAIembeddings
 from aphrodite.common.logger import UVICORN_LOG_CONFIG
 from aphrodite.common.outputs import RequestOutput
-from aphrodite.common.sampling_params import SamplingParams, _SAMPLING_EPS
+from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
 from aphrodite.common.utils import random_uuid
+from aphrodite.endpoints.openai.args import make_arg_parser
+from aphrodite.endpoints.openai.protocol import (
+    ChatCompletionRequest, CompletionRequest, EmbeddingsRequest,
+    EmbeddingsResponse, ErrorResponse, KAIGenerationInputSchema, Prompt)
 from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
-from aphrodite.endpoints.openai.serving_completions import (
-    OpenAIServingCompletion)
-from aphrodite.endpoints.openai.protocol import KAIGenerationInputSchema
-from aphrodite.endpoints.openai.serving_engine import LoRA
+from aphrodite.endpoints.openai.serving_completions import \
+    OpenAIServingCompletion
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
-import aphrodite.endpoints.openai.embeddings as OAIembeddings
+from aphrodite.endpoints.openai.serving_engine import LoRA
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
+engine: Optional[AsyncAphrodite] = None
+engine_args: Optional[AsyncEngineArgs] = None
 openai_serving_chat: OpenAIServingChat = None
 openai_serving_completion: OpenAIServingCompletion = None
+router = APIRouter()
 kai_api = APIRouter()
 extra_api = APIRouter()
 kobold_lite_ui = ""
@@ -63,149 +63,27 @@ async def lifespan(app: fastapi.FastAPI):
     yield
 
 
-app = fastapi.FastAPI(title="Aphrodite Engine",
-                      summary="Serving language models at scale",
-                      description=("A RESTful API server compatible with "
-                                   "OpenAI and KoboldAI clients. "),
-                      lifespan=lifespan)
-
-
-class LoRAParserAction(argparse.Action):
-
-    def __call__(self, parser, namespace, values, option_string=None):
-        lora_list = []
-        for item in values:
-            name, path = item.split('=')
-            lora_list.append(LoRA(name, path))
-        setattr(namespace, self.dest, lora_list)
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Aphrodite OpenAI-Compatible RESTful API server.")
-    parser.add_argument("--host", type=str, default=None, help="host name")
-    parser.add_argument("--port", type=int, default=2242, help="port number")
-    parser.add_argument("--allow-credentials",
-                        action="store_true",
-                        help="allow credentials")
-    parser.add_argument("--allowed-origins",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed origins")
-    parser.add_argument("--allowed-methods",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed methods")
-    parser.add_argument("--allowed-headers",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed headers")
-    parser.add_argument(
-        "--api-keys",
-        type=str,
-        default=None,
-        help=
-        "If provided, the server will require this key to be presented in the "
-        "header.")
-    parser.add_argument(
-        "--admin-key",
-        type=str,
-        default=None,
-        help=
-        "If provided, the server will require this key to be presented in the "
-        "header for admin operations.")
-    parser.add_argument(
-        "--launch-kobold-api",
-        action="store_true",
-        help=
-        "Launch the Kobold API server in addition to the OpenAI API server.")
-    parser.add_argument("--max-length",
-                        type=int,
-                        default=256,
-                        help="The maximum length of the generated response. "
-                        "For use with Kobold Horde.")
-    parser.add_argument("--served-model-name",
-                        type=str,
-                        default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
-    parser.add_argument(
-        "--lora-modules",
-        type=str,
-        default=None,
-        nargs='+',
-        action=LoRAParserAction,
-        help=
-        "LoRA module configurations in the format name=path. Multiple modules "
-        "can be specified.")
-    parser.add_argument("--chat-template",
-                        type=str,
-                        default=None,
-                        help="The file path to the chat template, "
-                        "or the template in single-line form "
-                        "for the specified model")
-    parser.add_argument("--response-role",
-                        type=str,
-                        default="assistant",
-                        help="The role name to return if "
-                        "`request.add_generation_prompt=true`.")
-    parser.add_argument("--ssl-keyfile",
-                        type=str,
-                        default=None,
-                        help="The file path to the SSL key file")
-    parser.add_argument("--ssl-certfile",
-                        type=str,
-                        default=None,
-                        help="The file path to the SSL cert file")
-    parser.add_argument(
-        "--root-path",
-        type=str,
-        default=None,
-        help="FastAPI root_path when app is behind a path based routing proxy")
-    parser.add_argument(
-        "--middleware",
-        type=str,
-        action="append",
-        default=[],
-        help="Additional ASGI middleware to apply to the app. "
-        "We accept multiple --middleware arguments. "
-        "The value should be an import path. "
-        "If a function is provided, Aphrodite will add it to the server using "
-        "@app.middleware('http'). "
-        "If a class is provided, Aphrodite will add it to the server using "
-        "app.add_middleware(). ")
-
-    parser = AsyncEngineArgs.add_cli_args(parser)
-    return parser.parse_args()
-
-
 # Add prometheus asgi middleware to route /metrics requests
 metrics_app = make_asgi_app()
-app.mount("/metrics/", metrics_app)
-
+router.mount("/metrics", metrics_app)
 
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
 
-
-@app.get("/health")
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
+    await openai_serving_completion.engine.check_health()
     return Response(status_code=200)
 
 
-@app.get("/v1/models")
+@router.get("/v1/models")
 async def show_available_models(x_api_key: Optional[str] = Header(None)):
     models = await openai_serving_chat.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
-@app.post("/v1/tokenize")
-@app.post("/v1/token/encode")
+@router.post("/v1/tokenize")
+@router.post("/v1/token/encode")
 async def tokenize(request: Request,
                    prompt: Prompt,
                    x_api_key: Optional[str] = Header(None)):
@@ -213,8 +91,8 @@ async def tokenize(request: Request,
     return JSONResponse(content=tokenized)
 
 
-@app.post("/v1/detokenize")
-@app.post("/v1/token/decode")
+@router.post("/v1/detokenize")
+@router.post("/v1/token/decode")
 async def detokenize(request: Request,
                      token_ids: List[int],
                      x_api_key: Optional[str] = Header(None)):
@@ -222,7 +100,7 @@ async def detokenize(request: Request,
     return JSONResponse(content=detokenized)
 
 
-@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
+@router.post("/v1/embeddings", response_model=EmbeddingsResponse)
 async def handle_embeddings(request: EmbeddingsRequest,
                             x_api_key: Optional[str] = Header(None)):
     input = request.input
@@ -237,13 +115,13 @@ async def handle_embeddings(request: EmbeddingsRequest,
     return JSONResponse(response)
 
 
-@app.get("/version", description="Fetch the Aphrodite Engine version.")
+@router.get("/version", description="Fetch the Aphrodite Engine version.")
 async def show_version(x_api_key: Optional[str] = Header(None)):
     ver = {"version": aphrodite.__version__}
     return JSONResponse(content=ver)
 
 
-@app.get("/v1/samplers")
+@router.get("/v1/samplers")
 async def show_samplers(x_api_key: Optional[str] = Header(None)):
     """Get the available samplers."""
     global sampler_json
@@ -259,7 +137,7 @@ async def show_samplers(x_api_key: Optional[str] = Header(None)):
     return sampler_json
 
 
-@app.post("/v1/lora/load")
+@router.post("/v1/lora/load")
 async def load_lora(lora: LoRA, x_api_key: Optional[str] = Header(None)):
     openai_serving_chat.add_lora(lora)
     openai_serving_completion.add_lora(lora)
@@ -270,14 +148,14 @@ async def load_lora(lora: LoRA, x_api_key: Optional[str] = Header(None)):
     return JSONResponse(content={"result": "success"})
 
 
-@app.delete("/v1/lora/unload")
+@router.delete("/v1/lora/unload")
 async def unload_lora(lora_name: str, x_api_key: Optional[str] = Header(None)):
     openai_serving_chat.remove_lora(lora_name)
     openai_serving_completion.remove_lora(lora_name)
     return JSONResponse(content={"result": "success"})
 
 
-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request,
                                  x_api_key: Optional[str] = Header(None)):
@@ -293,7 +171,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest,
                             raw_request: Request,
                             x_api_key: Optional[str] = Header(None)):
@@ -466,9 +344,11 @@ async def count_tokens(request: Request):
     """Tokenize string and return token count"""
 
     request_dict = await request.json()
-    tokenizer_result = await openai_serving_chat.tokenize(Prompt(**request_dict))
+    tokenizer_result = await openai_serving_chat.tokenize(
+        Prompt(**request_dict))
     return JSONResponse({"value": tokenizer_result["value"]})
 
+
 @kai_api.get("/info/version")
 async def get_version():
     """Impersonate KAI"""
@@ -520,10 +400,10 @@ async def get_preloaded_story() -> JSONResponse:
 @extra_api.get("/version")
 async def get_extra_version():
     """Impersonate KoboldCpp"""
-    return JSONResponse({"result": "KoboldCpp", "version": "1.60.1"})
+    return JSONResponse({"result": "KoboldCpp", "version": "1.63"})
 
 
-@app.get("/")
+@router.get("/")
 async def get_kobold_lite_ui():
     """Serves a cached copy of the Kobold Lite UI, loading it from disk
     on demand if needed."""
@@ -542,103 +422,123 @@ async def get_kobold_lite_ui():
 
 # ============ KoboldAI API ============ #
 
-if __name__ == "__main__":
-    try:
-        args = parse_args()
-
-        if args.launch_kobold_api:
-            logger.warning(
-                "Launching Kobold API server in addition to OpenAI. "
-                "Keep in mind that the Kobold API routes are NOT "
-                "protected via the API key.")
-            app.include_router(kai_api, prefix="/api/v1")
-            app.include_router(kai_api,
-                               prefix="/api/latest",
-                               include_in_schema=False)
-            app.include_router(extra_api, prefix="/api/extra")
-
-        app.add_middleware(
-            CORSMiddleware,
-            allow_origins=args.allowed_origins,
-            allow_credentials=args.allow_credentials,
-            allow_methods=args.allowed_methods,
-            allow_headers=args.allowed_headers,
-        )
-
-        if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
-            admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
-
-            if admin_key is None:
-                logger.warning("Admin key not provided. Admin operations will "
-                               "be disabled.")
-
-            @app.middleware("http")
-            async def authentication(request: Request, call_next):
-                excluded_paths = ["/api"]
-                if any(
-                        request.url.path.startswith(path)
-                        for path in excluded_paths):
-                    return await call_next(request)
-                if not request.url.path.startswith("/v1"):
-                    return await call_next(request)
 
-                auth_header = request.headers.get("Authorization")
-                api_key_header = request.headers.get("x-api-key")
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
+    if args.launch_kobold_api:
+        logger.warning("Launching Kobold API server in addition to OpenAI. "
+                       "Keep in mind that the Kobold API routes are NOT "
+                       "protected via the API key.")
+        app.include_router(kai_api, prefix="/api/v1")
+        app.include_router(kai_api,
+                           prefix="/api/latest",
+                           include_in_schema=False)
+        app.include_router(extra_api, prefix="/api/extra")
+
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=args.allowed_origins,
+        allow_credentials=args.allow_credentials,
+        allow_methods=args.allowed_methods,
+        allow_headers=args.allowed_headers,
+    )
 
-                if request.url.path.startswith("/v1/lora"):
-                    if admin_key is not None and api_key_header == admin_key:
-                        return await call_next(request)
-                    return JSONResponse(content={"error": "Unauthorized"},
-                                        status_code=401)
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_completion.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
+    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
+        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
+
+        if admin_key is None:
+            logger.warning("Admin key not provided. Admin operations will "
+                           "be disabled.")
+
+        @app.middleware("http")
+        async def authentication(request: Request, call_next):
+            excluded_paths = ["/api"]
+            if any(
+                    request.url.path.startswith(path)
+                    for path in excluded_paths):
+                return await call_next(request)
+            if not request.url.path.startswith("/v1"):
+                return await call_next(request)
 
-                if auth_header != "Bearer " + token and api_key_header != token:
-                    return JSONResponse(content={"error": "Unauthorized"},
-                                        status_code=401)
+            # Browsers may send OPTIONS requests to check CORS headers
+            # before sending the actual request. We should allow these
+            # requests to pass through without authentication.
+            # See https://github.com/PygmalionAI/aphrodite-engine/issues/434
+            if request.method == "OPTIONS":
                 return await call_next(request)
 
-        for middleware in args.middleware:
-            module_path, object_name = middleware.rsplit(".", 1)
-            imported = getattr(importlib.import_module(module_path),
-                               object_name)
-            if inspect.isclass(imported):
-                app.add_middleware(imported)
-            elif inspect.iscoroutinefunction(imported):
-                app.middleware("http")(imported)
-            else:
-                raise ValueError(f"Invalid middleware {middleware}. Must be a "
-                                 "function or a class.")
-
-        logger.debug(f"args: {args}")
-
-        if args.served_model_name is not None:
-            served_model = args.served_model_name
+            auth_header = request.headers.get("Authorization")
+            api_key_header = request.headers.get("x-api-key")
+
+            if request.url.path.startswith("/v1/lora"):
+                if admin_key is not None and api_key_header == admin_key:
+                    return await call_next(request)
+                return JSONResponse(content={"error": "Unauthorized"},
+                                    status_code=401)
+
+            if auth_header != "Bearer " + token and api_key_header != token:
+                return JSONResponse(content={"error": "Unauthorized"},
+                                    status_code=401)
+            return await call_next(request)
+
+    for middleware in args.middleware:
+        module_path, object_name = middleware.rsplit(".", 1)
+        imported = getattr(importlib.import_module(module_path), object_name)
+        if inspect.isclass(imported):
+            app.add_middleware(imported)
+        elif inspect.iscoroutinefunction(imported):
+            app.middleware("http")(imported)
         else:
-            served_model = args.model
-
-        engine_args = AsyncEngineArgs.from_cli_args(args)
-        engine = AsyncAphrodite.from_engine_args(engine_args)
-        tokenizer = get_tokenizer(
-            engine_args.tokenizer,
-            tokenizer_mode=engine_args.tokenizer_mode,
-            trust_remote_code=engine_args.trust_remote_code,
-        )
-
-        chat_template = args.chat_template
-        if chat_template is None and tokenizer.chat_template is not None:
-            chat_template = tokenizer.chat_template
-
-        openai_serving_chat = OpenAIServingChat(engine, served_model,
-                                                args.response_role,
-                                                args.lora_modules,
-                                                args.chat_template)
-        openai_serving_completion = OpenAIServingCompletion(
-            engine, served_model, args.lora_modules)
-        engine_model_config = asyncio.run(engine.get_model_config())
-
-        if args.launch_kobold_api:
-            _set_badwords(tokenizer, engine_model_config.hf_config)
-
-        app.root_path = args.root_path
+            raise ValueError(f"Invalid middleware {middleware}. "
+                             f"Must be a function or a class.")
+
+    return app
+
+
+def run_server(args):
+    app = build_app(args)
+
+    logger.debug(f"args: {args}")
+
+    global engine, engine_args, openai_serving_chat, openai_serving_completion,\
+        tokenizer, served_model
+    if args.served_model_name is not None:
+        served_model = args.served_model_name
+    else:
+        served_model = args.model
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncAphrodite.from_engine_args(engine_args)
+    tokenizer = get_tokenizer(
+        engine_args.tokenizer,
+        tokenizer_mode=engine_args.tokenizer_mode,
+        trust_remote_code=engine_args.trust_remote_code,
+        revision=engine_args.revision,
+    )
+
+    chat_template = args.chat_template
+    if chat_template is None and tokenizer.chat_template is not None:
+        chat_template = tokenizer.chat_template
+
+    openai_serving_chat = OpenAIServingChat(engine, served_model,
+                                            args.response_role,
+                                            args.lora_modules,
+                                            args.chat_template)
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, served_model, args.lora_modules)
+    engine_model_config = asyncio.run(engine.get_model_config())
+
+    if args.launch_kobold_api:
+        _set_badwords(tokenizer, engine_model_config.hf_config)
+    try:
         uvicorn.run(app,
                     host=args.host,
                     port=args.port,
@@ -648,7 +548,15 @@ if __name__ == "__main__":
                     ssl_certfile=args.ssl_certfile,
                     log_config=UVICORN_LOG_CONFIG)
     except KeyboardInterrupt:
-        logger.info("API server stopped by user. Exiting gracefully.")
+        logger.info("API server stopped by user. Exiting.")
     except asyncio.exceptions.CancelledError:
-        logger.info("API server stopped due to a cancelled request. "
-                    "Exiting gracefully.")
+        logger.info("API server stopped due to a cancelled request. Exiting.")
+
+
+if __name__ == "__main__":
+    # NOTE:
+    # This section should be in sync with aphrodite/endpoints/cli.py
+    # for CLI entrypoints.
+    parser = make_arg_parser()
+    args = parser.parse_args()
+    run_server(args)
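
Since the `__main__` block above now reduces to `make_arg_parser()` plus `run_server(args)`, the server can also be launched programmatically. A hedged sketch with placeholder model tag and key:

```python
# Hedged sketch: launching the OpenAI-compatible server programmatically with
# the same parser/run_server pair the __main__ block uses. The model tag and
# API key are placeholders; the flags come from the argparser in this diff.
from aphrodite.endpoints.openai.api_server import run_server
from aphrodite.endpoints.openai.args import make_arg_parser

parser = make_arg_parser()
args = parser.parse_args([
    "--model", "meta-llama/Llama-2-7b-chat-hf",  # placeholder model tag
    "--port", "2242",
    "--api-keys", "example-key",                 # enables the auth middleware
])
run_server(args)  # blocking call; serves /v1/* (and Kobold routes if enabled)
```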

+ 122 - 0
aphrodite/endpoints/openai/args.py

@@ -0,0 +1,122 @@
+"""
+This file contains the command line arguments for Aphrodite's
+OpenAI-compatible server. It is kept in a separate file for documentation
+purposes.
+"""
+
+import argparse
+import json
+
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.endpoints.openai.serving_engine import LoRA
+
+
+class LoRAParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        lora_list = []
+        for item in values:
+            name, path = item.split('=')
+            lora_list.append(LoRA(name, path))
+        setattr(namespace, self.dest, lora_list)
+
+
+def make_arg_parser(parser=None):
+    if parser is None:
+        parser = argparse.ArgumentParser(
+            description="Aphrodite OpenAI-Compatible RESTful API server.")
+    parser.add_argument("--host", type=str, default=None, help="host name")
+    parser.add_argument("--port", type=int, default=2242, help="port number")
+    parser.add_argument("--allow-credentials",
+                        action="store_true",
+                        help="allow credentials")
+    parser.add_argument("--allowed-origins",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed origins")
+    parser.add_argument("--allowed-methods",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed methods")
+    parser.add_argument("--allowed-headers",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed headers")
+    parser.add_argument(
+        "--api-keys",
+        type=str,
+        default=None,
+        help=
+        "If provided, the server will require this key to be presented in the "
+        "header.")
+    parser.add_argument(
+        "--admin-key",
+        type=str,
+        default=None,
+        help=
+        "If provided, the server will require this key to be presented in the "
+        "header for admin operations.")
+    parser.add_argument(
+        "--launch-kobold-api",
+        action="store_true",
+        help=
+        "Launch the Kobold API server in addition to the OpenAI API server.")
+    parser.add_argument("--max-length",
+                        type=int,
+                        default=256,
+                        help="The maximum length of the generated response. "
+                        "For use with Kobold Horde.")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The model name used in the API. If not "
+                        "specified, the model name will be the same as "
+                        "the huggingface name.")
+    parser.add_argument(
+        "--lora-modules",
+        type=str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help=
+        "LoRA module configurations in the format name=path. Multiple modules "
+        "can be specified.")
+    parser.add_argument("--chat-template",
+                        type=str,
+                        default=None,
+                        help="The file path to the chat template, "
+                        "or the template in single-line form "
+                        "for the specified model")
+    parser.add_argument("--response-role",
+                        type=str,
+                        default="assistant",
+                        help="The role name to return if "
+                        "`request.add_generation_prompt=true`.")
+    parser.add_argument("--ssl-keyfile",
+                        type=str,
+                        default=None,
+                        help="The file path to the SSL key file")
+    parser.add_argument("--ssl-certfile",
+                        type=str,
+                        default=None,
+                        help="The file path to the SSL cert file")
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
+    parser.add_argument(
+        "--middleware",
+        type=str,
+        action="append",
+        default=[],
+        help="Additional ASGI middleware to apply to the app. "
+        "We accept multiple --middleware arguments. "
+        "The value should be an import path. "
+        "If a function is provided, Aphrodite will add it to the server using "
+        "@app.middleware('http'). "
+        "If a class is provided, Aphrodite will add it to the server using "
+        "app.add_middleware(). ")
+
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    return parser
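
To illustrate the `--lora-modules` handling above: `LoRAParserAction` splits each `name=path` pair into a `LoRA` record before the engine arguments are parsed. A hedged sketch with placeholder adapter paths:

```python
# Hedged sketch of the name=path parsing done by LoRAParserAction above.
# The model tag and adapter paths are placeholders.
from aphrodite.endpoints.openai.args import make_arg_parser

parser = make_arg_parser()
args = parser.parse_args([
    "--model", "mistralai/Mistral-7B-Instruct-v0.2",  # placeholder
    "--lora-modules", "sql=/loras/sql-adapter", "chat=/loras/chat-adapter",
])
# args.lora_modules is now a list of two LoRA(name, path) records,
# later handed to OpenAIServingChat / OpenAIServingCompletion.
print(args.lora_modules)
```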

+ 44 - 6
aphrodite/endpoints/openai/protocol.py

@@ -3,11 +3,11 @@
 import time
 from typing import Dict, List, Literal, Optional, Union
 
-from pydantic import (AliasChoices, BaseModel, Field, model_validator,
+from pydantic import (AliasChoices, BaseModel, Field, conint, model_validator,
                       root_validator)
 
-from aphrodite.common.utils import random_uuid
 from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.utils import random_uuid
 from aphrodite.common.logits_processor import BiasLogitsProcessor
 
 
@@ -31,7 +31,7 @@ class ModelPermission(BaseModel):
     allow_fine_tuning: bool = False
     organization: str = "*"
     group: Optional[str] = None
-    is_blocking: str = False
+    is_blocking: bool = False
 
 
 class ModelCard(BaseModel):
@@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
     typical_p: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = 0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     include_stop_str_in_output: Optional[bool] = False
@@ -109,15 +110,23 @@ class ChatCompletionRequest(BaseModel):
     guided_choice: Optional[List[str]] = None
     guided_grammar: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
+    guided_decoding_backend: Optional[str] = Field(
+        default="outlines",
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be one of "
+            "'outlines' / 'lm-format-enforcer'"))
 
     def to_sampling_params(self, vocab_size: int) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
             raise ValueError("Top logprobs must be set when logprobs is.")
+        if self.top_k == 0:
+            self.top_k = -1
 
         logits_processors = []
         if self.logit_bias:
             biases = {
-                int(tok): min(100, max(float(bias), -100))
+                int(tok): max(-100, min(float(bias), 100))
                 for tok, bias in self.logit_bias.items()
                 if 0 < int(tok) < vocab_size
             }
@@ -126,6 +135,7 @@ class ChatCompletionRequest(BaseModel):
         return SamplingParams(
             n=self.n,
             max_tokens=self.max_tokens,
+            min_tokens=self.min_tokens,
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             temperature=self.temperature,
@@ -182,6 +192,7 @@ class CompletionRequest(BaseModel):
     prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
+    min_tokens: Optional[int] = 0
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
     tfs: Optional[float] = 1.0
@@ -226,6 +237,7 @@ class CompletionRequest(BaseModel):
     custom_token_bans: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    truncate_prompt_tokens: Optional[conint(ge=1)] = None
     grammar: Optional[str] = None
     length_penalty: Optional[float] = 1.0
     guided_json: Optional[Union[str, dict, BaseModel]] = None
@@ -233,14 +245,22 @@ class CompletionRequest(BaseModel):
     guided_choice: Optional[List[str]] = None
     guided_grammar: Optional[str] = None
     response_format: Optional[ResponseFormat] = None
+    guided_decoding_backend: Optional[str] = Field(
+        default="outlines",
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be one of "
+            "'outlines' / 'lm-format-enforcer'"))
 
     def to_sampling_params(self, vocab_size: int) -> SamplingParams:
         echo_without_generation = self.echo and self.max_tokens == 0
+        if self.top_k == 0:
+            self.top_k = -1
 
         logits_processors = []
         if self.logit_bias:
             biases = {
-                int(tok): min(100, max(float(bias), -100))
+                int(tok): max(-100, min(float(bias), 100))
                 for tok, bias in self.logit_bias.items()
                 if 0 < int(tok) < vocab_size
             }
@@ -249,6 +269,7 @@ class CompletionRequest(BaseModel):
         return SamplingParams(
             n=self.n,
             max_tokens=self.max_tokens if not echo_without_generation else 1,
+            min_tokens=self.min_tokens,
             temperature=self.temperature,
             top_p=self.top_p,
             tfs=self.tfs,
@@ -282,6 +303,7 @@ class CompletionRequest(BaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output,
             seed=self.seed,
             logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
         )
 
     @model_validator(mode="before")
@@ -303,7 +325,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
+    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -311,6 +333,13 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
 
 
 class CompletionResponse(BaseModel):
@@ -327,6 +356,13 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
 
 
 class CompletionStreamResponse(BaseModel):
@@ -348,6 +384,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -369,6 +406,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None
 
 
 class ChatCompletionStreamResponse(BaseModel):
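
The new `min_tokens` and `truncate_prompt_tokens` request fields and the `top_k == 0` normalization above all flow through `to_sampling_params()`. A hedged sketch of that path, with a placeholder model name and vocab size:

```python
# Hedged sketch of the new CompletionRequest fields flowing into
# SamplingParams; the model name and vocab size are placeholders.
from aphrodite.endpoints.openai.protocol import CompletionRequest

req = CompletionRequest(
    model="served-model-name",
    prompt="Once upon a time",
    max_tokens=64,
    min_tokens=8,                # new: force at least 8 generated tokens
    top_k=0,                     # OpenAI-style "disabled"; normalized to -1
    logit_bias={"42": 250.0},    # clamped into [-100, 100] for the processor
    truncate_prompt_tokens=512,  # new: keep only the last 512 prompt tokens
)
params = req.to_sampling_params(vocab_size=32000)
print(params.min_tokens, params.truncate_prompt_tokens, params.top_k)
```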

+ 40 - 30
aphrodite/endpoints/openai/serving_chat.py

@@ -1,19 +1,26 @@
-import time
 import codecs
+import time
+from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
 from loguru import logger
 
+from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.utils import random_uuid
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.endpoints.openai.protocol import (
-    ChatCompletionRequest, ChatCompletionResponse,
-    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
-    UsageInfo)
-from aphrodite.common.outputs import RequestOutput
-from aphrodite.endpoints.openai.serving_engine import OpenAIServing, LoRA
-from aphrodite.modeling.outlines_decoding import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    DeltaMessage,
+    ErrorResponse,
+    UsageInfo,
+)
+from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.modeling.guided_decoding import (
     get_guided_decoding_logits_processor)
 
 
@@ -37,9 +44,9 @@ class OpenAIServingChat(OpenAIServing):
                ChatCompletionResponse]:
         """Completion API similar to OpenAI's API.
 
-        See  https://platform.openai.com/docs/api-reference/chat/create
-        for the API specification. This API mimics the OpenAI ChatCompletion
-        API.
+        See https://platform.openai.com/docs/api-reference/chat/create
+        for the API specification. This API mimics the OpenAI
+        ChatCompletion API.
 
         NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
@@ -60,24 +67,24 @@ class OpenAIServingChat(OpenAIServing):
 
         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
+            # Tokenize/detokenize depending on prompt format (string/token list)
+            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+                request, prompt=prompt)
             sampling_params = request.to_sampling_params(
                 self.tokenizer.vocab_size)
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, await self.engine.get_tokenizer()))
+                    request.guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
             if guided_decode_logits_processor:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
                     guided_decode_logits_processor)
         except ValueError as e:
             return self.create_error_response(str(e))
 
-        result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids,
+        result_generator = self.engine.generate(prompt_text, sampling_params,
+                                                request_id, prompt_ids,
                                                 lora_request)
         # Streaming response
         if request.stream:
@@ -103,7 +110,7 @@ class OpenAIServingChat(OpenAIServing):
     ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
 
         model_name = request.model
-        created_time = int(time.monotonic())
+        created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
         first_iteration = True
 
@@ -136,8 +143,8 @@ class OpenAIServingChat(OpenAIServing):
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
 
-                    # Send response to echo the input portion of the last
-                    # message
+                    # Send response to echo the input portion of the
+                    # last message
                     if request.echo:
                         last_msg_content = ""
                         if request.messages and isinstance(
@@ -149,11 +156,12 @@ class OpenAIServingChat(OpenAIServing):
 
                         if last_msg_content:
                             for i in range(request.n):
-                                choice_data = ChatCompletionResponseStreamChoice(  # noqa
-                                    index=i,
-                                    delta=DeltaMessage(
-                                        content=last_msg_content),
-                                    finish_reason=None)
+                                choice_data = (
+                                    ChatCompletionResponseStreamChoice(
+                                        index=i,
+                                        delta=DeltaMessage(
+                                            content=last_msg_content),
+                                        finish_reason=None))
                                 chunk = ChatCompletionStreamResponse(
                                     id=request_id,
                                     object=chunk_object_type,
@@ -217,7 +225,8 @@ class OpenAIServingChat(OpenAIServing):
                             index=i,
                             delta=DeltaMessage(content=delta_text),
                             logprobs=logprobs,
-                            finish_reason=output.finish_reason)
+                            finish_reason=output.finish_reason,
+                            stop_reason=output.stop_reason)
                         chunk = ChatCompletionStreamResponse(
                             id=request_id,
                             object=chunk_object_type,
@@ -243,7 +252,7 @@ class OpenAIServingChat(OpenAIServing):
             request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = request.model
-        created_time = int(time.monotonic())
+        created_time = int(time.time())
         final_res: RequestOutput = None
 
         async for res in result_generator:
@@ -275,6 +284,7 @@ class OpenAIServingChat(OpenAIServing):
                 message=ChatMessage(role=role, content=output.text),
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)
 

+ 62 - 91
aphrodite/endpoints/openai/serving_completions.py

@@ -1,11 +1,18 @@
-import asyncio
 import time
+from typing import (
+    AsyncGenerator,
+    AsyncIterator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+)
+
 from fastapi import Request
-from typing import (AsyncGenerator, AsyncIterator, Callable, List, Optional,
-                    Dict, Tuple)
 
-from aphrodite.common.utils import random_uuid
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.common.outputs import RequestOutput
+from aphrodite.common.utils import merge_async_iterators, random_uuid
 from aphrodite.endpoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
@@ -15,9 +22,9 @@ from aphrodite.endpoints.openai.protocol import (
     LogProbs,
     UsageInfo,
 )
-from aphrodite.common.outputs import RequestOutput
-from aphrodite.endpoints.openai.serving_engine import OpenAIServing, LoRA
-from aphrodite.modeling.outlines_decoding import (
+from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.modeling.guided_decoding import (
     get_guided_decoding_logits_processor)
 
 TypeTokenIDs = List[int]
@@ -44,47 +51,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
             prompt_is_tokens = True
             prompts = prompt  # case 4: array of token arrays
         else:
-            raise ValueError(
-                "prompt must be a string, array of strings, array of tokens, "
-                "or array of token arrays")
+            raise ValueError("prompt must be a string, array of strings, "
+                             "array of tokens, or array of token arrays")
     return prompt_is_tokens, prompts
 
 
-def merge_async_iterators(*iterators):
-    """Merge multiple asynchronous iterators into a single iterator.
-
-    This method handle the case where some iterators finish before others.
-    When it yields, it yields a tuple (i, item) where i is the index of the
-    iterator that yields the item.
-    """
-    queue = asyncio.Queue()
-
-    finished = [False] * len(iterators)
-
-    async def producer(i, iterator):
-        try:
-            async for item in iterator:
-                await queue.put((i, item))
-        except Exception as e:
-            await queue.put(e)
-        finished[i] = True
-
-    _tasks = [
-        asyncio.create_task(producer(i, iterator))
-        for i, iterator in enumerate(iterators)
-    ]
-
-    async def consumer():
-        while not all(finished) or not queue.empty():
-            item = await queue.get()
-            if isinstance(item, Exception):
-                raise item
-            yield item
-        await asyncio.gather(*_tasks)
-
-    return consumer()
-
-
 class OpenAIServingCompletion(OpenAIServing):
 
     def __init__(self,
@@ -117,7 +88,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
-        created_time = int(time.monotonic())
+        created_time = int(time.time())
 
         # Schedule the request and get the result generator.
         generators = []
@@ -125,39 +96,49 @@ class OpenAIServingCompletion(OpenAIServing):
             sampling_params = request.to_sampling_params(
                 self.tokenizer.vocab_size)
             lora_request = self._maybe_get_lora(request)
+            decoding_config = self.engine.engine.decoding_config
+            guided_decoding_backend = request.guided_decoding_backend \
+                or decoding_config.guided_decoding_backend
             guided_decode_logit_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, await self.engine.get_tokenizer()))
+                    guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
             if guided_decode_logit_processor is not None:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
                     guided_decode_logit_processor)
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt_ids=prompt)
+                    prompt_formats = self._validate_prompt_and_tokenize(
+                        request,
+                        prompt_ids=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)
                 else:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt=prompt)
+                    prompt_formats = self._validate_prompt_and_tokenize(
+                        request,
+                        prompt=prompt,
+                        truncate_prompt_tokens=sampling_params.
+                        truncate_prompt_tokens)
+                prompt_ids, prompt_text = prompt_formats
 
                 generators.append(
-                    self.engine.generate(prompt,
+                    self.engine.generate(prompt_text,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids,
+                                         prompt_token_ids=prompt_ids,
                                          lora_request=lora_request))
         except ValueError as e:
+            # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
 
         result_generator: AsyncIterator[Tuple[
             int, RequestOutput]] = merge_async_iterators(*generators)
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
-        # results. In addition, we do not stream the results when use beam
-        # search.
+        # results. In addition, we do not stream the results when using
+        # beam search.
         stream = (request.stream
                   and (request.best_of is None or request.n == request.best_of)
                   and not request.use_beam_search)
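
The hunk above resolves the guided decoding backend per request, falling back to the engine-wide default from DecodingConfig, and appends the returned logits processor to the request's sampling params. A minimal sketch of the contract such a processor is assumed to satisfy here (this is not the outlines/lm-format-enforcer implementation, only an illustration of the callable shape):

# Sketch only: illustrates the (token_ids, logits) -> logits contract that
# entries in sampling_params.logits_processors are assumed to follow.
from typing import List

import torch


class AllowOnlyTokens:
    """Masks every vocabulary entry outside `allowed_ids` with -inf."""

    def __init__(self, allowed_ids: List[int]):
        self.allowed_ids = allowed_ids

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(logits, float("-inf"))
        mask[self.allowed_ids] = 0.0
        return logits + mask


# sampling_params.logits_processors.append(AllowOnlyTokens([0, 1, 2]))
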
@@ -224,8 +205,8 @@ class OpenAIServingCompletion(OpenAIServing):
 
                 for output in res.outputs:
                     i = output.index + prompt_idx * request.n
-                    # TODO: optimize the performance by avoiding full text
-                    # O(n^2) sending.
+                    # TODO: optimize the performance by avoiding full
+                    # text O(n^2) sending.
 
                     if request.echo and request.max_tokens == 0:
                         # only return the prompt
@@ -251,8 +232,6 @@ class OpenAIServingCompletion(OpenAIServing):
                             i]:] if output.logprobs else None
 
                     if request.logprobs is not None:
-                        assert top_logprobs is not None, "top_logprobs must " \
-                            "be provided when logprobs is requested"
                         logprobs = self._create_logprobs(
                             token_ids=delta_token_ids,
                             top_logprobs=top_logprobs,
@@ -265,6 +244,17 @@ class OpenAIServingCompletion(OpenAIServing):
                     previous_texts[i] = output.text
                     previous_num_tokens[i] = len(output.token_ids)
                     finish_reason = output.finish_reason
+                    stop_reason = output.stop_reason
+                    if output.finish_reason is not None:  # return final usage
+                        prompt_tokens = len(res.prompt_token_ids)
+                        completion_tokens = len(output.token_ids)
+                        final_usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        )
+                    else:
+                        final_usage = None
                     response_json = CompletionStreamResponse(
                         id=request_id,
                         created=created_time,
@@ -275,40 +265,16 @@ class OpenAIServingCompletion(OpenAIServing):
                                 text=delta_text,
                                 logprobs=logprobs,
                                 finish_reason=finish_reason,
+                                stop_reason=stop_reason,
                             )
-                        ]).model_dump_json()
+                        ],
+                        usage=final_usage,
+                    ).model_dump_json(exclude_unset=True)
                     yield f"data: {response_json}\n\n"
-
-                    if output.finish_reason is not None:  # return final usage
-                        logprobs = LogProbs(
-                        ) if request.logprobs is not None else None
-                        prompt_tokens = len(res.prompt_token_ids)
-                        completion_tokens = len(output.token_ids)
-                        final_usage = UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        )
-                        response_json = CompletionStreamResponse(
-                            id=request_id,
-                            created=created_time,
-                            model=model_name,
-                            choices=[
-                                CompletionResponseStreamChoice(
-                                    index=i,
-                                    text="",
-                                    logprobs=logprobs,
-                                    finish_reason=output.finish_reason,
-                                )
-                            ],
-                            usage=final_usage,
-                        ).model_dump_json()
-                        yield f"data: {response_json}\n\n"
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             data = self.create_streaming_error_response(str(e))
             yield f"data: {data}\n\n"
-
         yield "data: [DONE]\n\n"
 
     def request_output_to_completion_response(
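
With the streaming path above, the chunk that carries finish_reason now also carries stop_reason and the final usage, and exclude_unset=True keeps earlier chunks small. A hedged client-side sketch; the host, port, route, and model id are placeholders, and only fields added in this diff are read:

# Sketch: consume the completion stream and pick up usage/stop_reason from
# the final chunk. Endpoint and model id are placeholders.
import json

import requests

with requests.post(
        "http://localhost:2242/v1/completions",
        json={"model": "some-model", "prompt": "Hello", "max_tokens": 16,
              "stream": True},
        stream=True) as resp:
    usage = stop_reason = None
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        choice = chunk["choices"][0]
        print(choice.get("text", ""), end="", flush=True)
        if choice.get("finish_reason") is not None:
            stop_reason = choice.get("stop_reason")
            usage = chunk.get("usage")
print()
print("usage:", usage, "stop_reason:", stop_reason)
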
@@ -335,7 +301,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     output_text = prompt_text
                 elif request.echo and request.max_tokens > 0:
                     token_ids = prompt_token_ids + output.token_ids
-                    top_logprobs = prompt_logprobs + output.logprobs
+                    top_logprobs = (prompt_logprobs + output.logprobs
+                                    if request.logprobs else None)
                     output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
@@ -343,6 +310,9 @@ class OpenAIServingCompletion(OpenAIServing):
                     output_text = output.text
 
                 if request.logprobs is not None:
+                    assert top_logprobs is not None, (
+                        "top_logprobs must be provided when logprobs "
+                        "is requested")
                     logprobs = self._create_logprobs(
                         token_ids=token_ids,
                         top_logprobs=top_logprobs,
@@ -356,6 +326,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     text=output_text,
                     logprobs=logprobs,
                     finish_reason=output.finish_reason,
+                    stop_reason=output.stop_reason,
                 )
                 choices.append(choice_data)
 

+ 71 - 42
aphrodite/endpoints/openai/serving_engine.py

@@ -2,19 +2,25 @@ import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from loguru import logger
+from pydantic import conint
 
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
+from aphrodite.common.sequence import Logprob
+from aphrodite.endpoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    ErrorResponse,
+    LogProbs,
+    ModelCard,
+    ModelList,
+    ModelPermission,
+    Prompt,
+)
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
-from aphrodite.endpoints.openai.protocol import (CompletionRequest,
-                                                 ChatCompletionRequest,
-                                                 ErrorResponse, LogProbs,
-                                                 ModelCard, ModelList,
-                                                 ModelPermission, Prompt)
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.sequence import Logprob
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
 
 
 @dataclass
@@ -50,11 +56,12 @@ class OpenAIServing:
         except RuntimeError:
             event_loop = None
 
-        if event_loop is not None and event_loop.is_running(
-        ):  # If the current is instanced by Ray Serve, there is already
-            # a running event loop
+        if event_loop is not None and event_loop.is_running():
+            # If the current process is instantiated by Ray Serve,
+            # there is already a running event loop.

             event_loop.create_task(self._post_init())
-        else:  # When using single Aphrodite without engine_use_ray
+        else:
+            # When using single Aphrodite without engine_use_ray
             asyncio.run(self._post_init())
 
     async def _post_init(self):
@@ -65,7 +72,9 @@ class OpenAIServing:
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code)
+            trust_remote_code=engine_model_config.trust_remote_code,
+            revision=engine_model_config.revision,
+            truncation_side="left")
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -107,33 +116,40 @@ class OpenAIServing:
         last_token_len = 0
         if num_output_top_logprobs:
             logprobs.top_logprobs = []
+
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id].logprob
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(None)
+                logprobs.top_logprobs.append(None)
             else:
-                token_logprob = None
-            token = step_top_logprobs[token_id].decoded_token
-            logprobs.tokens.append(token)
-            logprobs.token_logprobs.append(token_logprob)
+                token_logprob = step_top_logprobs[token_id].logprob
+                token = step_top_logprobs[token_id].decoded_token
+                logprobs.tokens.append(token)
+                logprobs.token_logprobs.append(token_logprob)
+
+                if num_output_top_logprobs:
+                    logprobs.top_logprobs.append({
+                        p.decoded_token: p.logprob
+                        for i, p in step_top_logprobs.items()
+                    } if step_top_logprobs else None)
+
+            if logprobs.top_logprobs:
+                logprobs.top_logprobs = [{
+                    k: v if v > -1000 else -1000
+                    for k, v in top_logprob.items()
+                } for top_logprob in logprobs.top_logprobs
+                                         if top_logprob is not None
+                                         ]  # noqa: E501
+
             if len(logprobs.text_offset) == 0:
                 logprobs.text_offset.append(initial_text_offset)
             else:
                 logprobs.text_offset.append(logprobs.text_offset[-1] +
                                             last_token_len)
             last_token_len = len(token)
-
-            if num_output_top_logprobs:
-                logprobs.top_logprobs.append({
-                    p.decoded_token: p.logprob
-                    for i, p in step_top_logprobs.items()
-                } if step_top_logprobs else None)
-
-        logprobs.top_logprobs = [{
-            k: v if v > -1000 else -1000
-            for k, v in top_logprob.items()
-        } for top_logprob in logprobs.top_logprobs if top_logprob is not None]
-
         return logprobs
 
     def create_error_response(
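
The rewritten loop above treats a missing per-position entry in top_logprobs (for example the first prompt token when echoing) as an explicit None instead of indexing into it, clamps returned log-probabilities at -1000, and keeps a running text_offset. Roughly, for a two-token case the resulting OpenAI-style payload has this shape (values are illustrative only):

# Illustrative layout of the LogProbs object built above (made-up numbers):
logprobs_payload = {
    "tokens": ["Hello", ","],
    "token_logprobs": [None, -0.12],   # None where no logprobs were produced
    "top_logprobs": [{",": -0.12, "!": -1000}],  # clamped at -1000; None rows filtered out
    "text_offset": [0, 5],             # cumulative character offsets
}
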
@@ -196,18 +212,31 @@ class OpenAIServing:
         raise ValueError(f"The model `{request.model}` does not exist.")
 
     def _validate_prompt_and_tokenize(
-            self,
-            request: Union[ChatCompletionRequest, CompletionRequest],
-            prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None) -> List[int]:
+        self,
+        request: Union[ChatCompletionRequest, CompletionRequest],
+        prompt: Optional[str] = None,
+        prompt_ids: Optional[List[int]] = None,
+        truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
         if (prompt and prompt_ids):
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
+        if prompt_ids is None:
+            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
+                "truncation": True,
+                "max_length": truncate_prompt_tokens,
+            }
+            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+        elif truncate_prompt_tokens is not None:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+        else:
+            input_ids = prompt_ids
+
+        input_text = prompt if prompt is not None else self.tokenizer.decode(
+            prompt_ids)
         token_num = len(input_ids)
 
         if request.max_tokens is None:
@@ -215,11 +244,11 @@ class OpenAIServing:
 
         if token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
-                f"This model's maximum context length is {self.max_model_len} "
-                "tokens. "
-                f"However, you requested {request.max_tokens + token_num} "
-                f"tokens ({token_num} in the messages, "
+                f"This model's maximum context length is "
+                f"{self.max_model_len} tokens. However, you requested "
+                f"{request.max_tokens + token_num} tokens "
+                f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids
+            return input_ids, input_text
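
Note that truncate_prompt_tokens truncates from the left, both for token-id prompts (keep the trailing ids) and for text prompts (the tokenizer is created with truncation_side="left" in _post_init), so the most recent context is preserved. A small sketch of the two paths with illustrative values:

# Sketch of the left-truncation semantics used above (illustrative values).
truncate_prompt_tokens = 4

# Path 1: prompt supplied as token ids -> keep only the trailing ids.
prompt_ids = [11, 22, 33, 44, 55, 66]
assert prompt_ids[-truncate_prompt_tokens:] == [33, 44, 55, 66]

# Path 2: prompt supplied as text -> the left-truncating tokenizer handles it:
# input_ids = tokenizer(prompt, truncation=True,
#                       max_length=truncate_prompt_tokens).input_ids
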

+ 196 - 412
aphrodite/engine/aphrodite_engine.py

@@ -1,44 +1,49 @@
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Type, Union
+
 from loguru import logger
-from transformers import PreTrainedTokenizer
+from transformers import GenerationConfig, PreTrainedTokenizer
 
 import aphrodite
-from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import (
-    CacheConfig,
-    DeviceConfig,
-    ModelConfig,
-    ParallelConfig,
-    SchedulerConfig,
-    LoRAConfig,
-)
-from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
-from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.executor.executor_base import ExecutorBase
-from aphrodite.engine.metrics import StatLogger, Stats
-from aphrodite.engine.ray_tools import (initialize_ray_cluster)
+from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
+                                     SchedulerConfig, SpeculativeConfig,
+                                     VisionLanguageConfig)
+from aphrodite.common.logger import setup_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (
-    Logprob,
-    SamplerOutput,
-    Sequence,
-    SequenceGroup,
-    SequenceGroupOutput,
-    SequenceOutput,
-    SequenceStatus,
-)
-from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally)
+from aphrodite.common.sequence import (MultiModalData, SamplerOutput, Sequence,
+                                       SequenceGroup, SequenceStage)
+from aphrodite.common.utils import Counter
+from aphrodite.engine.args_tools import EngineArgs
+from aphrodite.engine.metrics import StatLogger, Stats
+from aphrodite.engine.output_processor.interfaces import \
+    SequenceGroupOutputProcessor
+from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.engine.output_processor.util import \
+    create_output_by_sequence_group
+from aphrodite.engine.ray_tools import initialize_ray_cluster
+from aphrodite.executor.executor_base import ExecutorBase
+from aphrodite.lora.request import LoRARequest
+from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
+from aphrodite.transformers_utils.detokenizer import Detokenizer
 from aphrodite.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                           get_tokenizer_group)
-from aphrodite.common.utils import (
-    Counter, )
-from aphrodite.common.logger import setup_logger
 
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 
+def _load_generation_config_dict(model_config: ModelConfig):
+    try:
+        return GenerationConfig.from_pretrained(
+            model_config.model,
+            revision=model_config.revision,
+        ).to_diff_dict()
+    except OSError:
+        # Not found.
+        return {}
+
+
 class AphroditeEngine:
     """An LLM engine that receives requests and generates texts.
 
@@ -62,7 +67,11 @@ class AphroditeEngine:
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
         device_config: The configuration related to the device.
-        lora_config: The configuration related to LoRA.
+        lora_config (Optional): The configuration related to serving multi-LoRA.
+        vision_language_config (Optional): The configuration related to vision
+            language models.
+        speculative_config (Optional): The configuration related to speculative
+            decoding.
         executor_class: The model executor class for managing distributed
             execution.
         log_stats: Whether to log statistics.
@@ -76,6 +85,9 @@ class AphroditeEngine:
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
+        decoding_config: Optional[DecodingConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
     ) -> None:
@@ -83,6 +95,7 @@ class AphroditeEngine:
             f"Initializing the Aphrodite Engine (v{aphrodite.__version__}) "
             "with the following config:\n"
             f"Model = {model_config.model!r}\n"
+            f"Speculative Config = {speculative_config!r}\n"
             f"DataType = {model_config.dtype}\n"
             f"Model Load Format = {model_config.load_format}\n"
             f"Number of GPUs = {parallel_config.tensor_parallel_size}\n"
@@ -92,25 +105,41 @@ class AphroditeEngine:
             f"Context Length = {model_config.max_model_len}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
             f"KV Cache Data Type = {cache_config.cache_dtype}\n"
-            # f"KV Cache Params Path = {cache_config.cache_quant_params_path}\n"
-            f"Device = {device_config.device}")
+            f"KV Cache Params Path = {model_config.quantization_param_path}\n"
+            f"Device = {device_config.device}\n"
+            f"Guided Decoding Backend = {decoding_config!r}\n")
         # TODO: Print more configs in debug mode.
 
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
+        self.speculative_config = speculative_config
+        self.decoding_config = decoding_config or DecodingConfig()
         self.log_stats = log_stats
         self._verify_args()
 
         self._init_tokenizer()
+        self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
+        self.generation_config_fields = _load_generation_config_dict(
+            model_config)
+
+        self.model_executor = executor_class(
+            model_config=model_config,
+            cache_config=cache_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            lora_config=lora_config,
+            vision_language_config=vision_language_config,
+            speculative_config=speculative_config,
+        )
 
-        self.model_executor = executor_class(model_config, cache_config,
-                                             parallel_config, scheduler_config,
-                                             device_config, lora_config)
+        self._initialize_kv_caches()
 
         # Ping the tokenizer to ensure it is loaded if
         # it runs on a separate process.
@@ -129,26 +158,65 @@ class AphroditeEngine:
             )
             self.stat_logger.info("cache_config", self.cache_config)
 
+        # Create sequence output processor, e.g. for beam search or
+        # speculative decoding.
+        self.output_processor = (
+            SequenceGroupOutputProcessor.create_output_processor(
+                self.scheduler_config,
+                self.detokenizer,
+                self.scheduler,
+                self.seq_counter,
+                self.get_tokenizer_for_seq,
+                stop_checker=StopChecker(
+                    self.scheduler_config.max_model_len,
+                    self.get_tokenizer_for_seq,
+                ),
+            ))
+
+    def _initialize_kv_caches(self) -> None:
+        """Initialize the KV cache in the worker(s).
+        The workers will determine the number of blocks in both the GPU cache
+        and the swap CPU cache.
+        """
+        num_gpu_blocks, num_cpu_blocks = (
+            self.model_executor.determine_num_available_blocks())
+
+        if self.cache_config.num_gpu_blocks_override is not None:
+            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
+            logger.info(f"Overriding {num_gpu_blocks=} with "
+                        f"{num_gpu_blocks_override=}")
+            num_gpu_blocks = num_gpu_blocks_override
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
     @classmethod
     def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
+        engine_config = engine_args.create_engine_config()
 
         # Initialize the cluster and specify the executor class.
-        if parallel_config.worker_use_ray:
-            initialize_ray_cluster(parallel_config)
+        if engine_config.device_config.device_type == "neuron":
+            from aphrodite.executor.neuron_executor import NeuronExecutor
+            executor_class = NeuronExecutor
+        elif engine_config.device_config.device_type == "cpu":
+            from aphrodite.executor.cpu_executor import CPUExecutor
+            executor_class = CPUExecutor
+        elif engine_config.parallel_config.worker_use_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
             from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
             executor_class = RayGPUExecutor
         else:
-            assert parallel_config.world_size == 1, (
+            assert engine_config.parallel_config.world_size == 1, (
                 "Ray is required if parallel_config.world_size > 1.")
             from aphrodite.executor.gpu_executor import GPUExecutor
             executor_class = GPUExecutor
 
         # Create the LLM engine.
-        engine = cls(*engine_configs,
+        engine = cls(**engine_config.to_dict(),
                      executor_class=executor_class,
                      log_stats=not engine_args.disable_log_stats)
         return engine
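
from_engine_args now builds a single EngineConfig and selects the executor by device type (Neuron, CPU, Ray multi-GPU, or single GPU). A minimal offline-construction sketch; the model id is a placeholder and the add_request keyword names and RequestOutput fields are assumed from this diff:

# Sketch of offline engine usage (model id is a placeholder).
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import EngineArgs

engine = AphroditeEngine.from_engine_args(
    EngineArgs(model="some-org/some-model", device="auto"))
engine.add_request(request_id="req-0",
                   prompt="Hello, my name is",
                   sampling_params=SamplingParams(max_tokens=16))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)
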
@@ -159,7 +227,7 @@ class AphroditeEngine:
         raise RuntimeError("AphroditeEngine should not be pickled!")
 
     def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer()
+        return self.tokenizer.get_lora_tokenizer(None)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
@@ -173,12 +241,19 @@ class AphroditeEngine:
             max_input_length=None,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code,
-            revision=self.model_config.tokenizer_revision,
-        )
+            revision=self.model_config.revision)
         init_kwargs.update(tokenizer_init_kwargs)
         self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
             self.parallel_config.tokenizer_pool_config, **init_kwargs)
 
+        if len(self.get_tokenizer()) != self.model_config.get_vocab_size():
+            logger.warning(
+                f"The tokenizer's vocabulary size {len(self.get_tokenizer())}"
+                f" does not match the model's vocabulary size "
+                f"{self.model_config.get_vocab_size()}.")
+        self.model_config.hf_config.tokenizer_vocab_size = len(
+            self.get_tokenizer())
+
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -209,6 +284,7 @@ class AphroditeEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         """Add a request to the engine's request pool.
 
@@ -225,6 +301,7 @@ class AphroditeEngine:
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+            multi_modal_data: The multimodal data for the request.
 
         Details:
             - Set arrival_time to the current time if it is None.
@@ -287,10 +364,15 @@ class AphroditeEngine:
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
+        # Inject the eos token id into the sampling_params to support min_tokens
+        # processing
+        sampling_params.eos_token_id = seq.eos_token_id
+        sampling_params.update_from_generation_config(
+            self.generation_config_fields)
 
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
-                                  arrival_time, lora_request)
+                                  arrival_time, lora_request, multi_modal_data)
 
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
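
With generation_config_fields loaded once at engine construction (see _load_generation_config_dict at the top of this file), every request's cloned SamplingParams is seeded with the eos token id for min_tokens handling and with the model author's generation_config defaults. A hedged sketch of what the loader returns, using the public transformers API; the model id is a placeholder and the keys vary per model:

# Sketch: what _load_generation_config_dict effectively yields.
from transformers import GenerationConfig

try:
    fields = GenerationConfig.from_pretrained(
        "some-org/some-model").to_diff_dict()
except OSError:
    fields = {}  # model ships no generation_config.json
# e.g. fields could look like {"eos_token_id": 2, "temperature": 0.6}
print(fields)
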
@@ -326,276 +408,46 @@ class AphroditeEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
-    def _check_beam_search_early_stopping(
-        self,
-        early_stopping: Union[bool, str],
-        sampling_params: SamplingParams,
-        best_running_seq: Sequence,
-        current_worst_seq: Sequence,
-    ) -> bool:
-        assert sampling_params.use_beam_search
-        length_penalty = sampling_params.length_penalty
-        if early_stopping is True:
-            return True
-
-        current_worst_score = current_worst_seq.get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=current_worst_seq.eos_token_id,
-        )
-        if early_stopping is False:
-            highest_attainable_score = best_running_seq.get_beam_search_score(
-                length_penalty=length_penalty,
-                eos_token_id=best_running_seq.eos_token_id,
-            )
-        else:
-            assert early_stopping == "never"
-            if length_penalty > 0.0:
-                # If length_penalty > 0.0, beam search will prefer longer
-                # sequences. The highest attainable score calculation is
-                # based on the longest possible sequence length in this case.
-                max_possible_length = max(
-                    best_running_seq.get_prompt_len() +
-                    sampling_params.max_tokens,
-                    self.scheduler_config.max_model_len,
-                )
-                highest_attainable_score = (
-                    best_running_seq.get_beam_search_score(
-                        length_penalty=length_penalty,
-                        eos_token_id=best_running_seq.eos_token_id,
-                        seq_len=max_possible_length,
-                    ))
-            else:
-                # Otherwise, beam search will prefer shorter sequences. The
-                # highest attainable score calculation is based on the current
-                # sequence length.
-                highest_attainable_score = (
-                    best_running_seq.get_beam_search_score(
-                        length_penalty=length_penalty,
-                        eos_token_id=best_running_seq.eos_token_id,
-                    ))
-        return current_worst_score >= highest_attainable_score
-
-    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
-                                        outputs: SequenceGroupOutput) -> None:
-        # Process prompt logprobs
-        prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None:
-            # We can pick any sequence for the prompt.
-            seq = next(iter(seq_group.seqs_dict.values()))
-            all_token_ids = seq.get_token_ids()
-            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
-                self._decode_logprobs(
-                    seq,
-                    seq_group.sampling_params,
-                    prompt_logprobs_for_token,
-                    all_token_ids[:i],
-                )
-            seq_group.prompt_logprobs = prompt_logprobs
-
-        # Process samples
-        samples = outputs.samples
-        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
-        existing_finished_seqs = seq_group.get_finished_seqs()
-        parent_child_dict = {
-            parent_seq.seq_id: []
-            for parent_seq in parent_seqs
-        }
-        for sample in samples:
-            parent_child_dict[sample.parent_seq_id].append(sample)
-        # List of (child, parent)
-        child_seqs: List[Tuple[Sequence, Sequence]] = []
-
-        # Process the child samples for each parent sequence
-        for parent in parent_seqs:
-            child_samples: List[SequenceOutput] = parent_child_dict[
-                parent.seq_id]
-            if len(child_samples) == 0:
-                # This parent sequence has no children samples. Remove
-                # the parent sequence from the sequence group since it will
-                # not be used in the future iterations.
-                parent.status = SequenceStatus.FINISHED_ABORTED
-                seq_group.remove(parent.seq_id)
-                self.scheduler.free_seq(parent)
-                continue
-            # Fork the parent sequence if there are multiple child samples.
-            for child_sample in child_samples[:-1]:
-                new_child_seq_id = next(self.seq_counter)
-                child = parent.fork(new_child_seq_id)
-                child.append_token_id(child_sample.output_token,
-                                      child_sample.logprobs)
-                child.persistent_data = child_sample.persistent_data
-                child_seqs.append((child, parent))
-            # Continue the parent sequence for the last child sample.
-            # We reuse the parent sequence here to reduce redundant memory
-            # copies, especially when using non-beam search sampling methods.
-            last_child_sample = child_samples[-1]
-            parent.append_token_id(last_child_sample.output_token,
-                                   last_child_sample.logprobs)
-            parent.persistent_data = last_child_sample.persistent_data
-            child_seqs.append((parent, parent))
-
-        for seq, _ in child_seqs:
-            self._decode_sequence(seq, seq_group.sampling_params)
-            self._check_stop(seq, seq_group.sampling_params)
-
-        # Non-beam search case
-        if not seq_group.sampling_params.use_beam_search:
-            # For newly created child sequences, add them to the sequence group
-            # and fork them in block manager if they are not finished.
-            for seq, parent in child_seqs:
-                if seq is not parent:
-                    seq_group.add(seq)
-                    if not seq.is_finished():
-                        self.scheduler.fork_seq(parent, seq)
-
-            # Free the finished and selected parent sequences' memory in block
-            # manager. Keep them in the sequence group as candidate output.
-            # NOTE: we need to fork the new sequences before freeing the
-            # old sequences.
-            for seq, parent in child_seqs:
-                if seq is parent and seq.is_finished():
-                    self.scheduler.free_seq(seq)
-            return
-
-        # Beam search case
-        # Select the child sequences to keep in the sequence group.
-        selected_child_seqs = []
-        unselected_child_seqs = []
-        beam_width = seq_group.sampling_params.best_of
-        length_penalty = seq_group.sampling_params.length_penalty
-
-        # Select the newly finished sequences with the highest scores
-        # to replace existing finished sequences.
-        # Tuple of (seq, parent, is_new)
-        existing_finished_seqs = [(seq, None, False)
-                                  for seq in existing_finished_seqs]
-        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
-                             if seq.is_finished()]
-        all_finished_seqs = existing_finished_seqs + new_finished_seqs
-        # Sort the finished sequences by their scores.
-        all_finished_seqs.sort(
-            key=lambda x: x[0].get_beam_search_score(
-                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
-            reverse=True,
-        )
-        for seq, parent, is_new in all_finished_seqs[:beam_width]:
-            if is_new:
-                # A newly generated child sequence finishes and has a high
-                # score, so we will add it into the sequence group.
-                selected_child_seqs.append((seq, parent))
-        for seq, parent, is_new in all_finished_seqs[beam_width:]:
-            if is_new:
-                # A newly generated child sequence finishes but has a low
-                # score, so we will not add it into the sequence group.
-                # Additionally, if this sequence is a continuation of a
-                # parent sequence, we will need remove the parent sequence
-                # from the sequence group.
-                unselected_child_seqs.append((seq, parent))
-            else:
-                # An existing finished sequence has a low score, so we will
-                # remove it from the sequence group.
-                seq_group.remove(seq.seq_id)
-
-        # select the top beam_width sequences from the running
-        # sequences for the next iteration to continue the beam
-        # search.
-        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
-                              if not seq.is_finished()]
-        # Sort the running sequences by their scores.
-        running_child_seqs.sort(
-            key=lambda x: x[0].get_beam_search_score(
-                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
-            reverse=True,
-        )
-
-        # Check if we can stop the beam search.
-        if len(running_child_seqs) == 0:
-            # No running sequences, stop the beam search.
-            stop_beam_search = True
-        elif len(all_finished_seqs) < beam_width:
-            # Not enough finished sequences, continue the beam search.
-            stop_beam_search = False
-        else:
-            # Check the early stopping criteria
-            best_running_seq = running_child_seqs[0][0]
-            current_worst_seq = all_finished_seqs[beam_width - 1][0]
-            stop_beam_search = self._check_beam_search_early_stopping(
-                seq_group.sampling_params.early_stopping,
-                seq_group.sampling_params,
-                best_running_seq,
-                current_worst_seq,
-            )
-
-        if stop_beam_search:
-            # Stop the beam search and remove all the running sequences from
-            # the sequence group.
-            unselected_child_seqs.extend(running_child_seqs)
-        else:
-            # Continue the beam search and select the top beam_width sequences
-            # to continue the beam search.
-            selected_child_seqs.extend(running_child_seqs[:beam_width])
-            # The remaining running sequences will not be used in the next
-            # iteration. Again, if these sequences are continuations of
-            # parent sequences, we will need to remove the parent sequences
-            # from the sequence group.
-            unselected_child_seqs.extend(running_child_seqs[beam_width:])
-
-        # For newly created child sequences, add them to the sequence group
-        # and fork them in block manager if they are not finished.
-        for seq, parent in selected_child_seqs:
-            if seq is not parent:
-                seq_group.add(seq)
-                if not seq.is_finished():
-                    self.scheduler.fork_seq(parent, seq)
-
-        # Free the finished and selected parent sequences' memory in block
-        # manager. Keep them in the sequence group as candidate output.
-        for seq, parent in selected_child_seqs:
-            if seq is parent and seq.is_finished():
-                self.scheduler.free_seq(seq)
-
-        # Remove the unselected parent sequences from the sequence group and
-        # free their memory in block manager.
-        for seq, parent in unselected_child_seqs:
-            if seq is parent:
-                # Remove the parent sequence if it is not selected for next
-                # iteration
-                seq_group.remove(seq.seq_id)
-                self.scheduler.free_seq(seq)
-
     def _process_model_outputs(
-            self, output: SamplerOutput,
-            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+            self, output: List[SamplerOutput],
+            scheduled_seq_groups: List[SequenceGroup],
+            ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
+        """Apply the model output to the sequences in the scheduled seq groups.
+        
+        Returns RequestOutputs that can be returned to the client.
+        """
         now = time.time()
+        # Organize outputs by [sequence group][step] instead of
+        # [step][sequence group].
+        output_by_sequence_group = create_output_by_sequence_group(
+            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
         # Update the scheduled sequence groups with the model outputs.
-        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
-
-        # If prefix caching is enabled, mark all blocks in the sequence groups
-        # as completed so that future requests don't attempt to recompute them
-        if self.cache_config.context_shift:
-            for seq_group in scheduled_seq_groups:
-                self.scheduler.mark_blocks_as_computed(seq_group)
-
-        for seq_group, outputs in zip(scheduled_seq_groups, output):
-            self._process_sequence_group_outputs(seq_group, outputs)
+        for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
+                                                output_by_sequence_group):
+            seq_group = scheduled_seq_group.seq_group
+            seq_group.update_num_computed_tokens(
+                scheduled_seq_group.token_chunk_size)
+            # If all sequences in the sequence group are in DECODE, then we can
+            # process the output tokens. Otherwise, they are (chunked) prefill
+            # samples and should not be processed.
+            stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
+            if all(stage == SequenceStage.DECODE for stage in stages):
+                self.output_processor.process_outputs(seq_group, outputs)
 
         # Free the finished sequence groups.
         self.scheduler.free_finished_seq_groups()
 
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in scheduled_seq_groups:
+        for scheduled_seq_group in scheduled_seq_groups:
+            seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
-        for seq_group in scheduler_outputs.ignored_seq_groups:
+        for seq_group in ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
 
-        # Log stats.
-        if self.log_stats:
-            self.stat_logger.log(self._get_stats(scheduler_outputs))
-
         return request_outputs
 
     def step(self) -> List[RequestOutput]:
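
_process_model_outputs now receives one SamplerOutput per step (more than one when speculative decoding produces lookahead tokens) and regroups them per sequence group before handing them to the pluggable output processor; chunked-prefill groups that are not yet in the DECODE stage are skipped. A simplified sketch of the regrouping performed by create_output_by_sequence_group (the real helper lives in aphrodite/engine/output_processor/util.py):

# Simplified sketch of the [step][seq_group] -> [seq_group][step] regrouping.
from typing import Any, List


def regroup_by_sequence_group(sampler_outputs: List[List[Any]],
                              num_seq_groups: int) -> List[List[Any]]:
    output_by_sequence_group: List[List[Any]] = [
        [] for _ in range(num_seq_groups)
    ]
    for step_output in sampler_outputs:  # one entry per decode step
        for group_idx, seq_group_output in enumerate(step_output):
            output_by_sequence_group[group_idx].append(seq_group_output)
    return output_by_sequence_group


# regroup_by_sequence_group([["a0", "b0"], ["a1", "b1"]], num_seq_groups=2)
# -> [["a0", "a1"], ["b0", "b1"]]
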
@@ -653,22 +505,41 @@ class AphroditeEngine:
 
         if not scheduler_outputs.is_empty():
             output = self.model_executor.execute_model(
-                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
-                scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy)
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
         else:
             output = []
 
-        return self._process_model_outputs(output, scheduler_outputs)
+        request_outputs = self._process_model_outputs(
+            output, scheduler_outputs.scheduled_seq_groups,
+            scheduler_outputs.ignored_seq_groups)
+
+        # Log stats.
+        if self.log_stats:
+            self.stat_logger.log(
+                self._get_stats(scheduler_outputs, model_output=output))
+
+        return request_outputs
 
     def do_log_stats(self) -> None:
         """Forced log when no requests active."""
         if self.log_stats:
             self.stat_logger.log(self._get_stats(scheduler_outputs=None))
 
-    def _get_stats(self,
-                   scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
-        """Get Stats to be Logged to Prometheus."""
+    def _get_stats(
+            self,
+            scheduler_outputs: Optional[SchedulerOutputs],
+            model_output: Optional[List[SamplerOutput]] = None) -> Stats:
+        """Get Stats to be Logged to Prometheus.
+        Args:
+            scheduler_outputs: Optional, used to populate metrics related to
+                the scheduled batch,
+            model_output: Optional, used to emit speculative decoding metrics
+                which are created by the workers.
+        """
         now = time.monotonic()
 
         # KV Cache Usage in %.
@@ -695,22 +566,25 @@ class AphroditeEngine:
         time_per_output_tokens = []
         time_e2e_requests = []
         if scheduler_outputs is not None:
-            prompt_run = scheduler_outputs.prompt_run
+            prompt_run = scheduler_outputs.num_prefill_groups > 0
 
             # Number of Tokens.
             if prompt_run:
                 num_prompt_tokens = sum(
-                    len(seq_group.prompt_token_ids)
-                    for seq_group in scheduler_outputs.scheduled_seq_groups)
+                    len(scheduled_seq_group.seq_group.prompt_token_ids)
+                    for scheduled_seq_group in
+                    scheduler_outputs.scheduled_seq_groups)
                 num_generation_tokens = sum(
-                    seq_group.num_seqs()
-                    for seq_group in scheduler_outputs.scheduled_seq_groups)
+                    scheduled_seq_group.seq_group.num_seqs()
+                    for scheduled_seq_group in
+                    scheduler_outputs.scheduled_seq_groups)
             else:
                 num_generation_tokens = scheduler_outputs.num_batched_tokens
 
             # Latency Timings.
             time_last_iters = []
-            for seq_group in scheduler_outputs.scheduled_seq_groups:
+            for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+                seq_group = scheduled_seq_group.seq_group
                 # Time since last token.
                 # (n.b. updates seq_group.metrics.last_token_time)
                 time_last_iters.append(seq_group.get_last_latency(now))
@@ -722,6 +596,14 @@ class AphroditeEngine:
             time_to_first_tokens = time_last_iters if prompt_run else []
             time_per_output_tokens = [] if prompt_run else time_last_iters
 
+        # Spec decode, if enabled, emits specialized metrics from the worker in
+        # sampler output.
+        if model_output and (model_output[0].spec_decode_worker_metrics
+                             is not None):
+            spec_decode_metrics = model_output[0].spec_decode_worker_metrics
+        else:
+            spec_decode_metrics = None
+
         return Stats(
             now=now,
             num_running=num_running,
@@ -734,107 +616,9 @@ class AphroditeEngine:
             time_to_first_tokens=time_to_first_tokens,
             time_per_output_tokens=time_per_output_tokens,
             time_e2e_requests=time_e2e_requests,
+            spec_decode_metrics=spec_decode_metrics,
         )
 
-    def _decode_logprobs(
-        self,
-        seq: Sequence,
-        prms: SamplingParams,
-        logprobs: Dict[int, Logprob],
-        all_input_ids: List[int],
-    ) -> None:
-        if not logprobs:
-            return
-        for token_id, sample_logprob in logprobs.items():
-            if sample_logprob.decoded_token is None and token_id != -1:
-                all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
-                (
-                    _,
-                    new_text,
-                    prefix_offset,
-                    read_offset,
-                ) = detokenize_incrementally(
-                    self.get_tokenizer_for_seq(seq),
-                    all_input_ids=all_input_ids_with_logprob,
-                    prev_tokens=seq.tokens,
-                    prefix_offset=seq.prefix_offset,
-                    read_offset=seq.read_offset,
-                    skip_special_tokens=prms.skip_special_tokens,
-                    spaces_between_special_tokens=prms.
-                    spaces_between_special_tokens,
-                )
-                sample_logprob.decoded_token = new_text
-
-    def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
-        """Decodes the new token for a sequence."""
-        all_input_ids = seq.get_token_ids()
-        self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
-                              all_input_ids)
-
-        (
-            new_tokens,
-            new_output_text,
-            prefix_offset,
-            read_offset,
-        ) = detokenize_incrementally(
-            self.get_tokenizer_for_seq(seq),
-            all_input_ids=all_input_ids,
-            prev_tokens=seq.tokens,
-            prefix_offset=seq.prefix_offset,
-            read_offset=seq.read_offset,
-            skip_special_tokens=prms.skip_special_tokens,
-            spaces_between_special_tokens=prms.spaces_between_special_tokens,
-        )
-        if seq.tokens is None:
-            seq.tokens = new_tokens
-        else:
-            seq.tokens.extend(new_tokens)
-        seq.prefix_offset = prefix_offset
-        seq.read_offset = read_offset
-        seq.output_text += new_output_text
-
-    def _check_stop(self, seq: Sequence,
-                    sampling_params: SamplingParams) -> None:
-        """Stop the finished sequences."""
-        for stop_str in sampling_params.stop:
-            if seq.output_text.endswith(stop_str):
-                self._finalize_sequence(seq, sampling_params, stop_str)
-                seq.status = SequenceStatus.FINISHED_STOPPED
-                return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
-            stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
-            self._finalize_sequence(seq, sampling_params, stop_str)
-            seq.status = SequenceStatus.FINISHED_STOPPED
-            return
-
-        # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.scheduler_config.max_model_len:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
-        # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
-        # Check if the sequence has generated the EOS token.
-        if (not sampling_params.ignore_eos
-            ) and seq.get_last_token_id() == seq.eos_token_id:
-            seq.status = SequenceStatus.FINISHED_STOPPED
-            return
-
-    def _finalize_sequence(self, seq: Sequence,
-                           sampling_params: SamplingParams,
-                           stop_string: str) -> None:
-        if sampling_params.include_stop_str_in_output:
-            return
-
-        if stop_string and seq.output_text.endswith(stop_string):
-            # Truncate the output text so that the stop string is
-            # not included in the output.
-            seq.output_text = seq.output_text[:-len(stop_string)]
-
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
 

+ 188 - 41
aphrodite/engine/args_tools.py

@@ -1,17 +1,14 @@
 import argparse
 import dataclasses
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional
 
-from aphrodite.common.config import (
-    CacheConfig,
-    ModelConfig,
-    ParallelConfig,
-    SchedulerConfig,
-    LoRAConfig,
-    DeviceConfig,
-    TokenizerPoolConfig,
-)
+from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
+                                     EngineConfig, LoRAConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     SpeculativeConfig, TokenizerPoolConfig,
+                                     VisionLanguageConfig)
+from aphrodite.common.utils import str_to_int_tuple
 
 
 @dataclass
@@ -26,7 +23,7 @@ class EngineArgs:
     load_format: str = "auto"
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
-    # kv_quant_params_path: str = None
+    quantization_param_path: Optional[str] = None
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
@@ -35,6 +32,7 @@ class EngineArgs:
     max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
     context_shift: bool = False
+    use_v2_block_manager: bool = False
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
@@ -62,6 +60,22 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = None
     device: str = "auto"
     ray_workers_use_nsight: bool = False
+    num_gpu_blocks_override: Optional[int] = None
+    num_lookahead_slots: int = 0
+    # Related to Vision-language models such as llava
+    image_input_type: Optional[str] = None
+    image_token_id: Optional[int] = None
+    image_input_shape: Optional[str] = None
+    image_feature_size: Optional[int] = None
+    scheduler_delay_factor: float = 0.0
+    enable_chunked_prefill: bool = False
+    guided_decoding_backend: str = 'outlines'
+    # Speculative decoding config
+    speculative_model: Optional[str] = None
+    num_speculative_tokens: Optional[int] = None
+    speculative_max_model_len: Optional[int] = None
+    ngram_prompt_lookup_max: Optional[int] = None
+    ngram_prompt_lookup_min: Optional[int] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -163,23 +177,25 @@ class EngineArgs:
             "for BF16 models.",
         )
         parser.add_argument(
-            "--kv-cache-dtype",
+            '--kv-cache-dtype',
             type=str,
-            # choices=["auto", "fp8_e5m2", "int8"],
-            choices=['auto', 'fp8_e5m2'],
+            choices=['auto', 'fp8'],
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
-            "data type. Note FP8 is not supported when cuda version is "
-            "lower than 11.8.",
-        )
-        # parser.add_argument(
-        #     "--kv-quant-params-path",
-        #     type=str,
-        #     default=EngineArgs.kv_quant_params_path,
-        #     help="Path to scales and zero points of KV cache "
-        #     "quantization. Only applicable when kv-cache-dtype "
-        #     "is int8.",
-        # )
+            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
+            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
+            'supported for common inference criteria. ')
+        parser.add_argument(
+            '--quantization-param-path',
+            type=str,
+            default=None,
+            help='Path to the JSON file containing the KV cache '
+            'scaling factors. This should generally be supplied when '
+            'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
+            'default to 1.0, which may cause accuracy issues. '
+            'FP8_E5M2 (without scaling) is only supported on cuda version '
+            'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
+            'supported for common inference criteria. ')
         parser.add_argument(
             "--max-model-len",
             type=int,
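
For FP8 KV cache the two new flags are meant to be used together: --kv-cache-dtype fp8 selects the storage format and --quantization-param-path supplies the per-layer scaling factors; without the latter the scales default to 1.0. The programmatic equivalent via EngineArgs, with placeholder model id and file path:

# Sketch: programmatic equivalent of
#   --kv-cache-dtype fp8 --quantization-param-path kv_cache_scales.json
from aphrodite.engine.args_tools import EngineArgs

args = EngineArgs(
    model="some-org/some-model",
    kv_cache_dtype="fp8",
    quantization_param_path="kv_cache_scales.json",
)
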
@@ -187,6 +203,13 @@ class EngineArgs:
             help="model context length. If unspecified, "
             "will be automatically derived from the model.",
         )
+        parser.add_argument(
+            '--guided-decoding-backend',
+            type=str,
+            default='outlines',
+            choices=['outlines', 'lm-format-enforcer'],
+            help='Which engine will be used for guided decoding'
+            ' (JSON schema / regex etc.).')
         # Parallel arguments
         parser.add_argument(
             "--worker-use-ray",
@@ -226,7 +249,7 @@ class EngineArgs:
             "--block-size",
             type=int,
             default=EngineArgs.block_size,
-            choices=[8, 16, 32, 128],
+            choices=[8, 16, 32],
             help="token block size",
         )
         parser.add_argument(
@@ -234,6 +257,17 @@ class EngineArgs:
             action="store_true",
             help="Enable context shifting.",
         )
+        parser.add_argument("--use-v2-block-manager",
+                            action="store_true",
+                            help="Use the v2 block manager.")
+        parser.add_argument(
+            "--num-lookahead-slots",
+            type=int,
+            default=EngineArgs.num_lookahead_slots,
+            help="Experimental scheduling config necessary for "
+            "speculative decoding. This will be replaced by "
+            "speculative decoding config in the future; it is "
+            "present for testing purposes until then.")
         parser.add_argument("--seed",
                             type=int,
                             default=EngineArgs.seed,
@@ -253,6 +287,12 @@ class EngineArgs:
             "the model executor, which can range from 0 to 1."
             "If unspecified, will use the default value of 0.9.",
         )
+        parser.add_argument(
+            "--num-gpu-blocks-override",
+            type=int,
+            default=None,
+            help="If specified, ignore GPU profiling result and use this "
+            "number of GPU blocks. Used for testing preemption.")
         parser.add_argument(
             "--max-num-batched-tokens",
             type=int,
@@ -287,6 +327,7 @@ class EngineArgs:
                 "aqlm",
                 "awq",
                 "bnb",
+                "eetq",
                 "exl2",
                 "gguf",
                 "gptq",
@@ -409,10 +450,80 @@ class EngineArgs:
             "--device",
             type=str,
             default=EngineArgs.device,
-            choices=["cuda"],
-            help=("Device to use for model execution. "
-                  'Currently, only "cuda" is supported.'),
+            choices=["auto", "cuda", "neuron", "cpu"],
+            help=("Device to use for model execution."),
         )
+        # Related to Vision-language models such as llava
+        parser.add_argument(
+            "--image-input-type",
+            type=str,
+            default=None,
+            choices=[
+                t.name.lower() for t in VisionLanguageConfig.ImageInputType
+            ],
+            help=("The image input type passed into Aphrodite. "
+                  "Should be one of `pixel_values` or `image_features`"))
+        parser.add_argument("--image-token-id",
+                            type=int,
+                            default=None,
+                            help=("Input id for image token."))
+        parser.add_argument(
+            '--image-input-shape',
+            type=str,
+            default=None,
+            help=(
+                'The biggest image input shape (worst for memory footprint) '
+                'given an input type. Only used for Aphrodite\'s profile_run.'
+            ))
+        parser.add_argument(
+            '--image-feature-size',
+            type=int,
+            default=None,
+            help=('The image feature size along the context dimension.'))
+        parser.add_argument(
+            "--scheduler-delay-factor",
+            "-sdf",
+            type=float,
+            default=EngineArgs.scheduler_delay_factor,
+            help="Apply a delay (of delay factor multiplied by previous "
+            "prompt latency) before scheduling next prompt.")
+        parser.add_argument(
+            "--enable-chunked-prefill",
+            action="store_true",
+            help="If True, the prefill requests can be chunked based on the "
+            "max_num_batched_tokens.")
+        parser.add_argument(
+            "--speculative-model",
+            type=str,
+            default=EngineArgs.speculative_model,
+            help=
+            "The name of the draft model to be used in speculative decoding.")
+        parser.add_argument(
+            "--num-speculative-tokens",
+            type=int,
+            default=EngineArgs.num_speculative_tokens,
+            help="The number of speculative tokens to sample from "
+            "the draft model in speculative decoding")
+        parser.add_argument(
+            "--speculative-max-model-len",
+            type=int,
+            default=EngineArgs.speculative_max_model_len,
+            help="The maximum sequence length supported by the "
+            "draft model. Sequences over this length will skip "
+            "speculation.")
+        parser.add_argument(
+            "--ngram-prompt-lookup-max",
+            type=int,
+            default=None,
+            help='Max size of window for ngram prompt lookup in speculative '
+            'decoding.')
+
+        parser.add_argument(
+            '--ngram-prompt-lookup-min',
+            type=int,
+            default=None,
+            help='Min size of window for ngram prompt lookup in speculative '
+            'decoding.')
         return parser
 
     @classmethod
@@ -423,10 +534,7 @@ class EngineArgs:
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args
 
-    def create_engine_configs(
-        self,
-    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               DeviceConfig, Optional[LoRAConfig], ]:
+    def create_engine_config(self, ) -> EngineConfig:
         device_config = DeviceConfig(self.device)
         model_config = ModelConfig(
             self.model,
@@ -445,6 +553,7 @@ class EngineArgs:
             self.load_in_4bit,
             self.load_in_8bit,
             self.load_in_smooth,
+            self.quantization_param_path,
             self.enforce_eager,
             self.max_context_len_to_capture,
             self.max_log_probs,
@@ -455,6 +564,7 @@ class EngineArgs:
             self.swap_space,
             self.kv_cache_dtype,
             # self.kv_quant_params_path,
+            self.num_gpu_blocks_override,
             model_config.get_sliding_window(),
             self.context_shift,
         )
@@ -471,10 +581,28 @@ class EngineArgs:
             ),
             self.ray_workers_use_nsight,
         )
+        speculative_config = SpeculativeConfig.maybe_create_spec_config(
+            target_model_config=model_config,
+            target_parallel_config=parallel_config,
+            target_dtype=self.dtype,
+            speculative_model=self.speculative_model,
+            num_speculative_tokens=self.num_speculative_tokens,
+            speculative_max_model_len=self.speculative_max_model_len,
+            enable_chunked_prefill=self.enable_chunked_prefill,
+            use_v2_block_manager=self.use_v2_block_manager,
+            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
+        )
         scheduler_config = SchedulerConfig(
             self.max_num_batched_tokens,
             self.max_num_seqs,
             model_config.max_model_len,
+            self.use_v2_block_manager,
+            num_lookahead_slots=(self.num_lookahead_slots
+                                 if speculative_config is None else
+                                 speculative_config.num_lookahead_slots),
+            delay_factor=self.scheduler_delay_factor,
+            enable_chunked_prefill=self.enable_chunked_prefill,
         )
         lora_config = (LoRAConfig(
             max_lora_rank=self.max_lora_rank,
@@ -484,14 +612,33 @@ class EngineArgs:
             max_cpu_loras=self.max_cpu_loras
             if self.max_cpu_loras and self.max_cpu_loras > 0 else None,
         ) if self.enable_lora else None)
-        return (
-            model_config,
-            cache_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            lora_config,
-        )
+        if self.image_input_type:
+            if (not self.image_token_id or not self.image_input_shape
+                    or not self.image_feature_size):
+                raise ValueError(
+                    "Specify `image_token_id`, `image_input_shape` and "
+                    "`image_feature_size` together with `image_input_type`.")
+            vision_language_config = VisionLanguageConfig(
+                image_input_type=VisionLanguageConfig.
+                get_image_input_enum_type(self.image_input_type),
+                image_token_id=self.image_token_id,
+                image_input_shape=str_to_int_tuple(self.image_input_shape),
+                image_feature_size=self.image_feature_size,
+            )
+        else:
+            vision_language_config = None
+
+        decoding_config = DecodingConfig(
+            guided_decoding_backend=self.guided_decoding_backend)
+        return EngineConfig(model_config=model_config,
+                            cache_config=cache_config,
+                            parallel_config=parallel_config,
+                            scheduler_config=scheduler_config,
+                            device_config=device_config,
+                            lora_config=lora_config,
+                            vision_language_config=vision_language_config,
+                            speculative_config=speculative_config,
+                            decoding_config=decoding_config)
 
 
 @dataclass
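
For illustration, a minimal, hypothetical sketch of exercising the new
EngineArgs fields above; the model identifiers are placeholders and the import
path assumes EngineArgs lives in aphrodite/engine/args_tools.py.

# Hedged usage sketch; placeholder model ids, assumed module path.
from aphrodite.engine.args_tools import EngineArgs

engine_args = EngineArgs(
    model="example-org/target-model",             # placeholder target model
    kv_cache_dtype="fp8",                         # new choice replacing 'fp8_e5m2'
    quantization_param_path=None,                 # KV scales then default to 1.0
    use_v2_block_manager=True,                    # enables the new v2 block manager
    speculative_model="example-org/draft-model",  # placeholder draft model
    num_speculative_tokens=4,
    guided_decoding_backend="outlines",
)
# A single EngineConfig now replaces the old tuple of per-subsystem configs.
engine_config = engine_args.create_engine_config()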

+ 36 - 20
aphrodite/engine/async_aphrodite.py

@@ -14,6 +14,7 @@ from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.ray_tools import initialize_ray_cluster, ray
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import MultiModalData
 
 ENGINE_ITERATION_TIMEOUT_S = int(
     os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "120"))
@@ -214,11 +215,20 @@ class _AsyncAphrodite(AphroditeEngine):
             output = await self.model_executor.execute_model_async(
                 seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
                 scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy)
+                scheduler_outputs.blocks_to_copy,
+                scheduler_outputs.num_lookahead_slots)
         else:
             output = []
 
-        return self._process_model_outputs(output, scheduler_outputs)
+        request_outputs = self._process_model_outputs(
+            output, scheduler_outputs.scheduled_seq_groups,
+            scheduler_outputs.ignored_seq_groups)
+
+        # Log stats.
+        if self.log_stats:
+            self.stat_logger.log(self._get_stats(scheduler_outputs))
+
+        return request_outputs
 
     async def encode_request_async(
         self,
@@ -243,6 +253,7 @@ class _AsyncAphrodite(AphroditeEngine):
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -255,14 +266,13 @@ class _AsyncAphrodite(AphroditeEngine):
             prompt_token_ids=prompt_token_ids,
             lora_request=lora_request)
 
-        return self.add_request(
-            request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            sampling_params=sampling_params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-        )
+        return self.add_request(request_id,
+                                prompt=prompt,
+                                prompt_token_ids=prompt_token_ids,
+                                sampling_params=sampling_params,
+                                arrival_time=arrival_time,
+                                lora_request=lora_request,
+                                multi_modal_data=multi_modal_data)
 
     async def check_health_async(self) -> None:
         self.model_executor.check_health()
@@ -327,22 +337,28 @@ class AsyncAphrodite:
                          start_engine_loop: bool = True) -> "AsyncAphrodite":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
-            initialize_ray_cluster(parallel_config)
+        engine_config = engine_args.create_engine_config()
+
+        if engine_config.device_config.device_type == "neuron":
+            raise NotImplementedError("Neuron is not supported for "
+                                      "async engine yet.")
+        elif engine_config.device_config.device_type == "cpu":
+            from aphrodite.executor.cpu_executor import CPUExecutor
+            executor_class = CPUExecutor
+        elif engine_config.parallel_config.worker_use_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
             from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         else:
-            assert parallel_config.world_size == 1, (
+            assert engine_config.parallel_config.world_size == 1, (
                 "Ray is required if parallel_config.world_size > 1.")
             from aphrodite.executor.gpu_executor import GPUExecutorAsync
             executor_class = GPUExecutorAsync
         # Create the async LLM engine.
-        engine = cls(parallel_config.worker_use_ray,
+        engine = cls(engine_config.parallel_config.worker_use_ray,
                      engine_args.engine_use_ray,
-                     *engine_configs,
-                     executor_class,
+                     **engine_config.to_dict(),
+                     executor_class=executor_class,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats,
                      max_log_len=engine_args.max_log_len,
@@ -402,8 +418,8 @@ class AsyncAphrodite:
         else:
             # FIXME: This is a bit hacky. Be careful when changing the
             # order of the arguments.
-            cache_config = args[1]
-            parallel_config = args[2]
+            cache_config = kwargs["cache_config"]
+            parallel_config = kwargs["parallel_config"]
             if parallel_config.tensor_parallel_size == 1:
                 num_gpus = cache_config.gpu_memory_utilization
             else:

+ 27 - 12
aphrodite/engine/metrics.py

@@ -1,17 +1,14 @@
-from loguru import logger
-from prometheus_client import (
-    Counter,
-    Gauge,
-    Histogram,
-    Info,
-    REGISTRY,
-    disable_created_metrics,
-)
-
 import time
-import numpy as np
-from typing import Dict, List
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+import numpy as np
+from loguru import logger
+from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
+                               disable_created_metrics)
+
+if TYPE_CHECKING:
+    from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 disable_created_metrics()
 
@@ -158,6 +155,8 @@ class Stats:
     time_per_output_tokens: List[float]
     time_e2e_requests: List[float]
 
+    spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
 
 class StatLogger:
     """StatLogger is used AphroditeEngine to log to Promethus and Stdout."""
@@ -270,3 +269,19 @@ class StatLogger:
             self.num_prompt_tokens = []
             self.num_generation_tokens = []
             self.last_local_log = stats.now
+
+            if stats.spec_decode_metrics is not None:
+                logger.info(
+                    self._format_spec_decode_metrics_str(
+                        stats.spec_decode_metrics))
+
+    def _format_spec_decode_metrics_str(
+            self, metrics: "SpecDecodeWorkerMetrics") -> str:
+
+        return ("Speculative metrics: "
+                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
+                f"System efficiency: {metrics.system_efficiency:.3f}, "
+                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
+                f"Number of accepted tokens: {metrics.accepted_tokens}, "
+                f"Number of draft tokens tokens: {metrics.draft_tokens}, "
+                f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")

+ 0 - 0
aphrodite/engine/output_processor/__init__.py


+ 68 - 0
aphrodite/engine/output_processor/interfaces.py

@@ -0,0 +1,68 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Iterable, List
+
+from transformers import PreTrainedTokenizer
+
+from aphrodite.common.config import SchedulerConfig
+from aphrodite.common.sequence import (Sequence, SequenceGroup,
+                                       SequenceGroupOutput)
+from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.processing.scheduler import Scheduler
+from aphrodite.transformers_utils.detokenizer import Detokenizer
+
+
+class SequenceGroupOutputProcessor(ABC):
+    """Interface for logic that processes new token ids in sequence groups,
+    managing detokenization, stop checking, and freeing/forking sequences with
+    the scheduler.
+    This is highly coupled with the AphroditeEngine and should be seen as an
+    extension of it. The logic is separated to simplify the AphroditeEngine
+    class and allow
+    separate implementations for single-step decoding (which supports beam
+    search sequence forking) and multi-step decoding (which does not support
+    beam search, but does support speculative decoding).
+    """
+
+    @staticmethod
+    def create_output_processor(
+        scheduler_config: SchedulerConfig,
+        detokenizer: Detokenizer,
+        scheduler: Scheduler,
+        seq_counter: Iterable[int],
+        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
+        stop_checker: "StopChecker",
+    ):
+        """Create an output processor.
+        This returns a single-step output processor if num_lookahead_slots is
+        zero, else returns a multi-step output processor.
+        """
+        if scheduler_config.num_lookahead_slots == 0:
+            # Importing here to avoid cycle.
+            from aphrodite.engine.output_processor.single_step import \
+                SingleStepOutputProcessor
+            return SingleStepOutputProcessor(
+                scheduler_config,
+                detokenizer,
+                scheduler,
+                seq_counter,
+                stop_checker,
+            )
+        else:
+            # Importing here to avoid cycle.
+            from aphrodite.engine.output_processor.multi_step import \
+                MultiStepOutputProcessor
+            return MultiStepOutputProcessor(
+                detokenizer,
+                scheduler,
+                seq_counter,
+                get_tokenizer_for_seq,
+                stop_checker,
+            )
+
+    @abstractmethod
+    def process_outputs(self, sequence_group: SequenceGroup,
+                        outputs: List[SequenceGroupOutput]) -> None:
+        """Process new token ids for the sequence group. Handles logic such as
+        detokenization, stop checking, and freeing/forking sequences in the
+        scheduler.
+        """
+        pass
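
A hedged sketch of how the factory might be wired up. The None values and the
SimpleNamespace config are stand-ins for the Detokenizer, Scheduler and
SchedulerConfig objects that AphroditeEngine normally supplies; the dispatch
rule is exactly the num_lookahead_slots check above.

import itertools
from types import SimpleNamespace

from aphrodite.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from aphrodite.engine.output_processor.stop_checker import StopChecker

# Zero lookahead slots (no speculative decoding) selects the single-step path.
scheduler_config = SimpleNamespace(num_lookahead_slots=0, max_model_len=4096)
stop_checker = StopChecker(max_model_len=4096,
                           get_tokenizer_for_seq=lambda seq: None)
processor = SequenceGroupOutputProcessor.create_output_processor(
    scheduler_config,
    detokenizer=None,               # real Detokenizer in the engine
    scheduler=None,                 # real Scheduler in the engine
    seq_counter=itertools.count(),  # satisfies the Iterable[int] parameter
    get_tokenizer_for_seq=lambda seq: None,
    stop_checker=stop_checker,
)
# -> SingleStepOutputProcessor; with num_lookahead_slots > 0 the factory
#    returns MultiStepOutputProcessor instead.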

+ 122 - 0
aphrodite/engine/output_processor/multi_step.py

@@ -0,0 +1,122 @@
+from typing import Callable, Iterable, List
+
+from transformers import PreTrainedTokenizer
+
+from aphrodite.processing.scheduler import Scheduler
+from aphrodite.engine.output_processor.interfaces import (
+    SequenceGroupOutputProcessor)
+from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import (Logprob, Sequence, SequenceGroup,
+                                       SequenceGroupOutput, SequenceOutput,
+                                       SequenceStatus)
+from aphrodite.transformers_utils.detokenizer import Detokenizer
+
+
+class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
+    """SequenceGroupOutputProcessor which handles logic related to
+    detokenization and stopping conditions. It specializes to "multi-step
+    decoding", where Aphrodite's worker may generate multiple tokens per
+    invocation.
+    This is currently mutually exclusive with advanced sampling techniques like
+    beam search, which motivates the separation of this logic from the single
+    step output processor.
+    This class is responsible for things such as correctly appending all new
+    token ids to their sequence, detokenizing new token ids, truncating new
+    output tokens after an eos token, and correctly handling the case where the
+    number of new output tokens per sequence differs in a single batch.
+    """
+
+    def __init__(
+        self,
+        detokenizer: Detokenizer,
+        scheduler: Scheduler,
+        seq_counter: Iterable[int],
+        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
+        stop_checker: StopChecker,
+    ):
+        self.detokenizer = detokenizer
+        self.scheduler = scheduler
+        self.seq_counter = seq_counter
+        self.get_tokenizer_for_seq = get_tokenizer_for_seq
+        self.stop_checker = stop_checker
+
+    def process_outputs(self, sequence_group: SequenceGroup,
+                        outputs: List[SequenceGroupOutput]) -> None:
+        """Append new tokens in the outputs to sequences in the sequence group.
+        This only supports sequence groups of size 1. It supports greater than
+        one new token per sequence.
+        This applies logic like stop condition checking and detokenization,
+        including freeing finished sequences. It also handles cases where there
+        are tokens emitted after the EOS token.
+        """
+        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
+
+        assert seqs, "expected running sequences"
+        assert len(seqs) == 1, (
+            "Beam search not supported in multi-step decoding.")
+        seq = seqs[0]
+
+        # Since there's only one sequence per sequence group, we can take the
+        # first sample.
+        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+
+        # -1 means the output token is not valid (eg. due to spec decode
+        # rejecting tokens).
+        valid_samples = [
+            sample for sample in samples if sample.output_token != -1
+        ]
+        assert valid_samples
+
+        self._process_seq_outputs(seq, valid_samples,
+                                  sequence_group.sampling_params)
+
+    def _process_seq_outputs(self, seq: Sequence,
+                             valid_samples: List[SequenceOutput],
+                             sampling_params: SamplingParams) -> None:
+        output_token_ids = [sample.output_token for sample in valid_samples]
+
+        # Truncate to max_tokens if necessary.
+        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
+                                                         len(output_token_ids))
+        if remaining_tokens < 0:
+            valid_samples = valid_samples[:remaining_tokens]
+            output_token_ids = output_token_ids[:remaining_tokens]
+
+        # Truncate any tokens after EOS. This is required as spec decode
+        # generates a fixed number of tokens without evaluating stopping
+        # conditions within the block. This can cause an eos token to be
+        # unintentionally ignored.
+        if not sampling_params.ignore_eos:
+            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
+            # Avoiding .index calls as exception throwing in the happy path
+            # is expensive.
+            for i in range(len(output_token_ids)):
+                if output_token_ids[i] == eos_token_id:
+                    output_token_ids = output_token_ids[:i + 1]
+                    valid_samples = valid_samples[:i + 1]
+                    break
+
+        # Incrementally append tokens to the sequence, as if we had only one new
+        # token.
+        for output_token_id in output_token_ids:
+            seq.append_token_id(
+                token_id=output_token_id,
+                # TODO emit logprobs in multi-step decoding.
+                logprobs={output_token_id: Logprob(0.0)},
+            )
+
+            new_char_count = 0
+            if sampling_params.detokenize:
+                new_char_count = self.detokenizer.decode_sequence_inplace(
+                    seq, sampling_params)
+
+            self.stop_checker.maybe_stop_sequence(
+                seq,
+                new_char_count=new_char_count,
+                sampling_params=sampling_params)
+            if seq.is_finished():
+                break
+
+        if seq.is_finished():
+            self.scheduler.free_seq(seq)
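
The negative-slice truncation in _process_seq_outputs is easy to misread, so
here is a small worked example with invented numbers.

# max_tokens=16, 14 tokens already generated, 4 newly proposed tokens.
output_token_ids = [101, 102, 103, 104]                # placeholder ids
remaining_tokens = 16 - (14 + len(output_token_ids))   # == -2
if remaining_tokens < 0:
    # The negative bound drops the overflow, keeping exactly the two tokens
    # that still fit within the max_tokens budget.
    output_token_ids = output_token_ids[:remaining_tokens]
assert output_token_ids == [101, 102]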

+ 272 - 0
aphrodite/engine/output_processor/single_step.py

@@ -0,0 +1,272 @@
+from typing import Iterable, List, Tuple, Union
+
+from aphrodite.common.config import SchedulerConfig
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import (Sequence, SequenceGroup,
+                                       SequenceGroupOutput, SequenceOutput,
+                                       SequenceStatus)
+from aphrodite.engine.output_processor.interfaces import \
+    SequenceGroupOutputProcessor
+from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.processing.scheduler import Scheduler
+from aphrodite.transformers_utils.detokenizer import Detokenizer
+
+
+class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
+    """SequenceGroupOutputProcessor which handles "output processing" logic,
+    which happens after the model returns generated token ids and before
+    scheduling of the next batch. Output processing logic includes
+    detokenization, and determining if a sequence is finished (e.g. via max len
+    or eos token).
+    The SingleStepOutputProcessor is specialized to the case where the model
+    emits at most a single token per invocation, which precludes configurations
+    such as speculative decoding or multi-step decoding. This enables beam
+    search sampling, which requires forking/finishing/freeing sequences in a way
+    that is currently difficult to schedule multiple steps ahead of time.
+    """
+
+    def __init__(
+        self,
+        scheduler_config: SchedulerConfig,
+        detokenizer: Detokenizer,
+        scheduler: Scheduler,
+        seq_counter: Iterable[int],
+        stop_checker: StopChecker,
+    ):
+        self.scheduler_config = scheduler_config
+        self.detokenizer = detokenizer
+        self.scheduler = scheduler
+        self.seq_counter = seq_counter
+        self.stop_checker = stop_checker
+
+    def process_outputs(self, sequence_group: SequenceGroup,
+                        outputs: List[SequenceGroupOutput]) -> None:
+        """Append all new tokens to sequences in the sequence group. Fork any
+        surviving beam candidates; free any non-surviving ones.
+        Invokes detokenizer to detokenize new tokens, and also marks sequences
+        as finished if they meet stop conditions.
+        """
+        assert (len(outputs) == 1
+                ), f"{type(self)} does not support multiple outputs per step"
+        return self._process_sequence_group_outputs(sequence_group, outputs[0])
+
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutput) -> None:
+
+        # Process prompt logprobs
+        prompt_logprobs = outputs.prompt_logprobs
+        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
+            self.detokenizer.decode_prompt_logprobs_inplace(
+                seq_group, prompt_logprobs)
+            seq_group.prompt_logprobs = prompt_logprobs
+
+        # Process samples
+        samples = outputs.samples
+        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        existing_finished_seqs = seq_group.get_finished_seqs()
+        parent_child_dict = {
+            parent_seq.seq_id: []
+            for parent_seq in parent_seqs
+        }
+        for sample in samples:
+            parent_child_dict[sample.parent_seq_id].append(sample)
+        # List of (child, parent)
+        child_seqs: List[Tuple[Sequence, Sequence]] = []
+
+        # Process the child samples for each parent sequence
+        for parent in parent_seqs:
+            child_samples: List[SequenceOutput] = parent_child_dict[
+                parent.seq_id]
+            if len(child_samples) == 0:
+                # This parent sequence has no children samples. Remove
+                # the parent sequence from the sequence group since it will
+                # not be used in the future iterations.
+                parent.status = SequenceStatus.FINISHED_ABORTED
+                seq_group.remove(parent.seq_id)
+                self.scheduler.free_seq(parent)
+                continue
+            # Fork the parent sequence if there are multiple child samples.
+            for child_sample in child_samples[:-1]:
+                new_child_seq_id = next(self.seq_counter)
+                child = parent.fork(new_child_seq_id)
+                child.append_token_id(child_sample.output_token,
+                                      child_sample.logprobs)
+                child_seqs.append((child, parent))
+            # Continue the parent sequence for the last child sample.
+            # We reuse the parent sequence here to reduce redundant memory
+            # copies, especially when using non-beam search sampling methods.
+            last_child_sample = child_samples[-1]
+            parent.append_token_id(last_child_sample.output_token,
+                                   last_child_sample.logprobs)
+            child_seqs.append((parent, parent))
+
+        for seq, _ in child_seqs:
+            if seq_group.sampling_params.detokenize:
+                new_char_count = self.detokenizer.decode_sequence_inplace(
+                    seq, seq_group.sampling_params)
+            else:
+                new_char_count = 0
+            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
+                                                  seq_group.sampling_params)
+
+        # Non-beam search case
+        if not seq_group.sampling_params.use_beam_search:
+            # For newly created child sequences, add them to the sequence group
+            # and fork them in block manager if they are not finished.
+            for seq, parent in child_seqs:
+                if seq is not parent:
+                    seq_group.add(seq)
+                    if not seq.is_finished():
+                        self.scheduler.fork_seq(parent, seq)
+
+            # Free the finished and selected parent sequences' memory in block
+            # manager. Keep them in the sequence group as candidate output.
+            # NOTE: we need to fork the new sequences before freeing the
+            # old sequences.
+            for seq, parent in child_seqs:
+                if seq is parent and seq.is_finished():
+                    self.scheduler.free_seq(seq)
+            return
+
+        # Beam search case
+        # Select the child sequences to keep in the sequence group.
+        selected_child_seqs = []
+        unselected_child_seqs = []
+        beam_width = seq_group.sampling_params.best_of
+        length_penalty = seq_group.sampling_params.length_penalty
+
+        # Select the newly finished sequences with the highest scores
+        # to replace existing finished sequences.
+        # Tuple of (seq, parent, is_new)
+        existing_finished_seqs = [(seq, None, False)
+                                  for seq in existing_finished_seqs]
+        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
+                             if seq.is_finished()]
+        all_finished_seqs = existing_finished_seqs + new_finished_seqs
+        # Sort the finished sequences by their scores.
+        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+                               reverse=True)
+        for seq, parent, is_new in all_finished_seqs[:beam_width]:
+            if is_new:
+                # A newly generated child sequence finishes and has a high
+                # score, so we will add it into the sequence group.
+                selected_child_seqs.append((seq, parent))
+        for seq, parent, is_new in all_finished_seqs[beam_width:]:
+            if is_new:
+                # A newly generated child sequence finishes but has a low
+                # score, so we will not add it into the sequence group.
+                # Additionally, if this sequence is a continuation of a
+                # parent sequence, we will need to remove the parent sequence
+                # from the sequence group.
+                unselected_child_seqs.append((seq, parent))
+            else:
+                # An existing finished sequence has a low score, so we will
+                # remove it from the sequence group.
+                seq_group.remove(seq.seq_id)
+
+        # select the top beam_width sequences from the running
+        # sequences for the next iteration to continue the beam
+        # search.
+        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
+                              if not seq.is_finished()]
+        # Sort the running sequences by their scores.
+        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+                                reverse=True)
+
+        # Check if we can stop the beam search.
+        if len(running_child_seqs) == 0:
+            # No running sequences, stop the beam search.
+            stop_beam_search = True
+        elif len(all_finished_seqs) < beam_width:
+            # Not enough finished sequences, continue the beam search.
+            stop_beam_search = False
+        else:
+            # Check the early stopping criteria
+            best_running_seq = running_child_seqs[0][0]
+            current_worst_seq = all_finished_seqs[beam_width - 1][0]
+            stop_beam_search = self._check_beam_search_early_stopping(
+                seq_group.sampling_params.early_stopping,
+                seq_group.sampling_params, best_running_seq, current_worst_seq)
+
+        if stop_beam_search:
+            # Stop the beam search and remove all the running sequences from
+            # the sequence group.
+            unselected_child_seqs.extend(running_child_seqs)
+        else:
+            # Continue the beam search and select the top beam_width sequences
+            # to continue the beam search.
+            selected_child_seqs.extend(running_child_seqs[:beam_width])
+            # The remaining running sequences will not be used in the next
+            # iteration. Again, if these sequences are continuations of
+            # parent sequences, we will need to remove the parent sequences
+            # from the sequence group.
+            unselected_child_seqs.extend(running_child_seqs[beam_width:])
+
+        # For newly created child sequences, add them to the sequence group
+        # and fork them in block manager if they are not finished.
+        for seq, parent in selected_child_seqs:
+            if seq is not parent:
+                seq_group.add(seq)
+                if not seq.is_finished():
+                    self.scheduler.fork_seq(parent, seq)
+
+        # Free the finished and selected parent sequences' memory in block
+        # manager. Keep them in the sequence group as candidate output.
+        for seq, parent in selected_child_seqs:
+            if seq is parent and seq.is_finished():
+                self.scheduler.free_seq(seq)
+
+        # Remove the unselected parent sequences from the sequence group and
+        # free their memory in block manager.
+        for seq, parent in unselected_child_seqs:
+            if seq is parent:
+                # Remove the parent sequence if it is not selected for next
+                # iteration
+                seq_group.remove(seq.seq_id)
+                self.scheduler.free_seq(seq)
+
+    def _check_beam_search_early_stopping(
+        self,
+        early_stopping: Union[bool, str],
+        sampling_params: SamplingParams,
+        best_running_seq: Sequence,
+        current_worst_seq: Sequence,
+    ) -> bool:
+        assert sampling_params.use_beam_search
+        length_penalty = sampling_params.length_penalty
+        if early_stopping is True:
+            return True
+
+        current_worst_score = current_worst_seq.get_beam_search_score(
+            length_penalty=length_penalty,
+            eos_token_id=current_worst_seq.eos_token_id)
+        if early_stopping is False:
+            highest_attainable_score = best_running_seq.get_beam_search_score(
+                length_penalty=length_penalty,
+                eos_token_id=best_running_seq.eos_token_id)
+        else:
+            assert early_stopping == "never"
+            if length_penalty > 0.0:
+                # If length_penalty > 0.0, beam search will prefer longer
+                # sequences. The highest attainable score calculation is
+                # based on the longest possible sequence length in this case.
+                max_possible_length = max(
+                    best_running_seq.get_prompt_len() +
+                    sampling_params.max_tokens,
+                    self.scheduler_config.max_model_len)
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=best_running_seq.eos_token_id,
+                        seq_len=max_possible_length))
+            else:
+                # Otherwise, beam search will prefer shorter sequences. The
+                # highest attainable score calculation is based on the current
+                # sequence length.
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=best_running_seq.eos_token_id))
+        return current_worst_score >= highest_attainable_score

+ 99 - 0
aphrodite/engine/output_processor/stop_checker.py

@@ -0,0 +1,99 @@
+from typing import Callable, Optional
+
+from transformers import PreTrainedTokenizer
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import Sequence, SequenceStatus
+
+
+class StopChecker:
+    """AphroditeEngine helper class which separates out the logic involving
+    stop checking. This checks things such as: whether the eos token was
+    emitted, whether the max_tokens has been consumed, whether a stop string
+    has been emitted, or if we have exceeded the max model len.
+    """
+
+    def __init__(self, max_model_len: int,
+                 get_tokenizer_for_seq: Callable[[Sequence],
+                                                 PreTrainedTokenizer]):
+        self.max_model_len = max_model_len
+        self.get_tokenizer_for_seq = get_tokenizer_for_seq
+
+    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> None:
+        """Stop the finished sequences.
+        new_char_count is the number of chars added to the
+        sequence's output text for the newly generated token.
+        """
+
+        # Check if the minimum number of tokens has been generated yet;
+        # skip the stop string/token checks if not
+        if seq.get_output_len() < sampling_params.min_tokens:
+            return
+
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return
+
+        # Check if a stop token was encountered.
+        # This assumes a single token produced per step.
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                # Remove last token
+                seq.output_text = seq.output_text[:-new_char_count]
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
+            return
+
+        # Check if any stop strings are matched.
+        stop_str = self._check_stop_strings(seq, new_char_count,
+                                            sampling_params)
+        if stop_str is not None:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = stop_str
+            return
+
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+    @staticmethod
+    def _check_stop_strings(seq: Sequence, new_char_count: int,
+                            sampling_params: SamplingParams) -> Optional[str]:
+        """Check if any stop strings are matched and truncate sequence
+        output text accordingly.
+        Returns the stop string if matched or else None.
+        """
+        if not new_char_count:
+            return None
+
+        for stop_str in sampling_params.stop:
+            stop_string_len = len(stop_str)
+            # Avoid searching already-searched text.
+            stop_index = seq.output_text.find(
+                stop_str, -new_char_count - stop_string_len)
+            if stop_index == -1:
+                continue
+
+            if sampling_params.include_stop_str_in_output:
+                # Truncate to end of stop string.
+                stop_index += stop_string_len
+                if stop_index >= len(seq.output_text):
+                    # No truncation required.
+                    return stop_str
+
+            # Truncate the output text to either the beginning
+            # or end of the stop string.
+            seq.output_text = seq.output_text[:stop_index]
+            return stop_str
+        return None
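
A hedged illustration of the static stop-string check; SimpleNamespace objects
stand in for the real Sequence and SamplingParams, exposing only the attributes
the method reads.

from types import SimpleNamespace

from aphrodite.engine.output_processor.stop_checker import StopChecker

seq = SimpleNamespace(output_text="Hello world###")
params = SimpleNamespace(stop=["###"], include_stop_str_in_output=False)

# The last token added three characters ("###"), so only the tail of
# output_text is searched for stop strings.
matched = StopChecker._check_stop_strings(seq, new_char_count=3,
                                          sampling_params=params)
assert matched == "###"
assert seq.output_text == "Hello world"   # stop string trimmed from the output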

+ 16 - 0
aphrodite/engine/output_processor/util.py

@@ -0,0 +1,16 @@
+from typing import List
+
+from aphrodite.common.sequence import SamplerOutput
+
+
+def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
+                                    num_seq_groups: int):
+    """Helper method which transforms a 2d list organized by
+    [step][sequence group] into [sequence group][step].
+    """
+    output_by_sequence_group = [[] for _ in range(num_seq_groups)]
+    for step in sampler_outputs:
+        for i, sequence_group_output in enumerate(step):
+            output_by_sequence_group[i].append(sequence_group_output)
+
+    return output_by_sequence_group
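
A quick usage sketch; plain strings stand in for SequenceGroupOutput objects.

from aphrodite.engine.output_processor.util import (
    create_output_by_sequence_group)

# Two steps of sampler output covering two sequence groups.
per_step = [["g0_step0", "g1_step0"],
            ["g0_step1", "g1_step1"]]
per_group = create_output_by_sequence_group(per_step, num_seq_groups=2)
assert per_group == [["g0_step0", "g0_step1"], ["g1_step0", "g1_step1"]]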

+ 162 - 0
aphrodite/executor/cpu_executor.py

@@ -0,0 +1,162 @@
+import os
+from typing import Dict, List, Set
+
+import torch
+from loguru import logger
+
+from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import (get_distributed_init_method, get_ip,
+                                    get_open_port, make_async)
+from aphrodite.executor.executor_base import ExecutorBase
+from aphrodite.lora.request import LoRARequest
+
+
+class CPUExecutor(ExecutorBase):
+
+    def _init_executor(self) -> None:
+        assert self.device_config.device_type == "cpu"
+        assert self.lora_config is None, "cpu backend doesn't support LoRA"
+        self.model_config = _verify_and_get_model_config(self.model_config)
+        self.cache_config = _verify_and_get_cache_config(self.cache_config)
+        self.scheduler_config = _verify_and_get_scheduler_config(
+            self.scheduler_config)
+
+        # Instantiate the worker and load the model to CPU.
+        self._init_worker()
+
+    def _init_worker(self):
+        from aphrodite.task_handler.cpu_worker import CPUWorker
+
+        assert (self.parallel_config.world_size == 1
+                ), "CPUExecutor only supports single CPU socket currently."
+
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        self.driver_worker = CPUWorker(
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=True,
+        )
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
+
+    def determine_num_available_blocks(self) -> tuple[int, int]:
+        """Determine the number of available KV blocks by invoking the
+        underlying worker.
+        """
+        return self.driver_worker.determine_num_available_blocks()
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Initialize the KV cache by invoking the underlying worker.
+        """
+        # NOTE: We log here to avoid multiple logs when number of workers is
+        # greater than one. We could log in the engine, but not all executors
+        # have GPUs.
+
+        # NOTE: `cpu block` for CPU backend is located on CPU memory but is
+        # referred to as `gpu block`, because we want to reuse the existing
+        # block management procedure.
+        logger.info(f"# CPU blocks: {num_gpu_blocks}")
+        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]],
+                      num_lookahead_slots: int) -> List[SamplerOutput]:
+        output = self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            num_lookahead_slots=num_lookahead_slots,
+        )
+        return output
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int,
+    ) -> SamplerOutput:
+        output = await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            num_lookahead_slots=num_lookahead_slots,
+        )
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.driver_worker.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.driver_worker.list_loras()
+
+    def check_health(self) -> None:
+        # CPUExecutor will always be healthy as long as
+        # it's running.
+        return
+
+
+def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
+    if config.dtype == torch.float16:
+        logger.warning("float16 is not supported on CPU, casting to bfloat16.")
+        config.dtype = torch.bfloat16
+    if not config.enforce_eager:
+        logger.warning(
+            "CUDA graph is not supported on CPU, fallback to the eager "
+            "mode.")
+        config.enforce_eager = True
+    return config
+
+
+def _verify_and_get_scheduler_config(
+        config: SchedulerConfig) -> SchedulerConfig:
+    if config.chunked_prefill_enabled:
+        logger.warning("Chunked prefill is not supported on CPU, disable it.")
+        config.chunked_prefill_enabled = False
+
+    return config
+
+
+def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
+    _GB = 1 << 30
+    if config.context_shift:
+        logger.warning("Prefix caching is not supported on CPU, disable it.")
+        config.context_shift = False
+
+    kv_cache_space_str = os.getenv("APHRODITE_CPU_KVCACHE_SPACE", "0")
+    kv_cache_space = int(kv_cache_space_str)
+
+    if kv_cache_space >= 0:
+        if kv_cache_space == 0:
+            config.cpu_kvcache_space_bytes = 4 * _GB  # type: ignore
+            logger.warning(
+                "Environment variable APHRODITE_CPU_KVCACHE_SPACE (GB) "
+                "for CPU backend is not set, using 4 by default.")
+        else:
+            config.cpu_kvcache_space_bytes = kv_cache_space * _GB  # type: ignore
+    else:
+        raise RuntimeError(
+            "Invalid environment variable APHRODITE_CPU_KVCACHE_SPACE"
+            f" {kv_cache_space}, expect a positive integer value.")
+
+    return config

+ 47 - 15
aphrodite/executor/executor_base.py

@@ -1,14 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
-
-from aphrodite.common.config import (
-    CacheConfig,
-    DeviceConfig,
-    ModelConfig,
-    ParallelConfig,
-    SchedulerConfig,
-    LoRAConfig,
-)
+from typing import Dict, List, Optional, Set, Tuple
+
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig, VisionLanguageConfig,
+                                     SpeculativeConfig)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 
@@ -20,7 +16,6 @@ class ExecutorBase(ABC):
     that can execute the model on multiple devices.
     """
 
-    @abstractmethod
     def __init__(
         self,
         model_config: ModelConfig,
@@ -29,7 +24,43 @@ class ExecutorBase(ABC):
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
     ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.vision_language_config = vision_language_config
+        self.speculative_config = speculative_config
+
+        self._init_executor()
+
+    @abstractmethod
+    def _init_executor(self) -> None:
+        pass
+
+    @abstractmethod
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.
+        Normally, this should simply delegate to the underlying Worker. Some
+        ExecutorBase implementations may require modification of the result,
+        e.g. to ensure the selected cache sizes are compatible with all
+        workers.
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Initialize the KV cache with the given size in blocks.
+        """
         raise NotImplementedError
 
     @abstractmethod
@@ -39,7 +70,8 @@ class ExecutorBase(ABC):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+        num_lookahead_slots: int,
+    ) -> List[SamplerOutput]:
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
@@ -52,7 +84,7 @@ class ExecutorBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         raise NotImplementedError
 
     @abstractmethod
@@ -71,12 +103,12 @@ class ExecutorAsyncBase(ExecutorBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int,
     ) -> SamplerOutput:
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
-    @abstractmethod
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy. If not, it should raise an
         exception."""
-        raise NotImplementedError
+        self.check_health()
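
A hedged sketch of the calling order an engine follows against the new
interface; executor is any concrete ExecutorBase and seq_group_metadata_list
comes from the scheduler (both are placeholders here).

# Two-phase cache setup replaces the old profile-and-init inside the executor.
num_gpu_blocks, num_cpu_blocks = executor.determine_num_available_blocks()
executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

# execute_model now also carries the lookahead-slot count for spec decode.
outputs = executor.execute_model(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
    num_lookahead_slots=0,  # > 0 only when speculative decoding is active
)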

+ 74 - 91
aphrodite/executor/gpu_executor.py

@@ -1,19 +1,9 @@
-import importlib
-from typing import Dict, List, Optional
+from typing import Dict, List, Set
 
 from loguru import logger
 
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import (
-    CacheConfig,
-    DeviceConfig,
-    ModelConfig,
-    ParallelConfig,
-    SchedulerConfig,
-    LoRAConfig,
-)
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
-from aphrodite.executor.utils import check_block_size_valid
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.common.utils import (
     get_ip,
@@ -22,48 +12,23 @@ from aphrodite.common.utils import (
     make_async,
 )
 
-# A map between the device type (in device config) to its worker module.
-DEVICE_TO_WORKER_MODULE_MAP = {
-    "cuda": "aphrodite.task_handler.worker",
-    "neuron": "aphrodite.task_handler.neuron_worker",
-}
-
 
 class GPUExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-
-        # Instantiate the worker and load the model to GPU.
-        self._init_worker()
-
-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
-
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
-
-    def _init_worker(self):
+    def _init_executor(self) -> None:
+        """Initialize the worker and load the model.
+        If speculative decoding is enabled, we instead create the speculative
+        worker.
+        """
+        if self.speculative_config is None:
+            self._init_non_spec_worker()
+        else:
+            self._init_spec_worker()
+
+    def _init_non_spec_worker(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
+        from aphrodite.task_handler.worker import Worker
 
         assert (self.parallel_config.world_size == 1
                 ), "GPUExecutor only supports single GPU."
@@ -71,39 +36,71 @@ class GPUExecutor(ExecutorBase):
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
             local_rank=0,
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
+            vision_language_config=self.vision_language_config,
             is_driver_worker=True,
         )
-        self.driver_worker.init_model()
+        self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache.
-        The engine first profiles the existing memory usage.
-        Then, it allocates the remaining memory for KV blocks.
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
+    def _init_spec_worker(self):
+        """Initialize a SpecDecodeWorker, using a draft model for proposals.
         """
-        # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        (
-            num_gpu_blocks,
-            num_cpu_blocks,
-        ) = self.driver_worker.profile_num_available_blocks(
-            block_size=self.cache_config.block_size,
-            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space_bytes,
-            cache_dtype=self.cache_config.cache_dtype,
+        assert self.speculative_config is not None
+
+        from aphrodite.spec_decode.spec_decode_worker import SpecDecodeWorker
+        from aphrodite.task_handler.worker import Worker
+
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+
+        target_worker = Worker(
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
+            is_driver_worker=True,
         )
 
+        spec_decode_worker = SpecDecodeWorker.create_worker(
+            scorer_worker=target_worker,
+            speculative_config=self.speculative_config,
+        )
+
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = spec_decode_worker
+
+        # Load model handled in spec decode worker.
+        self.driver_worker.init_device()
+
+    def determine_num_available_blocks(self) -> tuple[int, int]:
+        """Determine the number of available KV blocks by invoking the
+        underlying worker.
+        """
+        return self.driver_worker.determine_num_available_blocks()
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Initialize the KV cache by invoking the underlying worker.
+        """
+        # NOTE: This is logged in the executor because there can be >1 worker
+        # with other executors. We could log in the engine level, but work
+        # remains to abstract away the device for non-GPU configurations.
         logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                     f"# CPU blocks: {num_cpu_blocks}")
 
@@ -111,20 +108,7 @@ class GPUExecutor(ExecutorBase):
             f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
         )
 
-        check_block_size_valid(
-            num_gpu_blocks,
-            self.cache_config.block_size,
-            self.model_config.max_model_len,
-        )
-
-        self.cache_config.num_gpu_blocks = num_gpu_blocks
-        self.cache_config.num_cpu_blocks = num_cpu_blocks
-
-        # Initialize the cache.
-        self.driver_worker.init_cache_engine(cache_config=self.cache_config)
-        # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is False.
-        self.driver_worker.warm_up_model()
+        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
     def execute_model(
         self,
@@ -132,12 +116,14 @@ class GPUExecutor(ExecutorBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+        num_lookahead_slots: int,
+    ) -> List[SamplerOutput]:
         output = self.driver_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            num_lookahead_slots=num_lookahead_slots,
         )
         return output
 
@@ -149,7 +135,7 @@ class GPUExecutor(ExecutorBase):
         assert lora_id > 0, "lora_id must be greater than 0."
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
@@ -166,16 +152,13 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int,
     ) -> SamplerOutput:
         output = await make_async(self.driver_worker.execute_model)(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            num_lookahead_slots=num_lookahead_slots,
         )
         return output
-
-    async def check_health_async(self) -> None:
-        # GPUExecutor will always be healthy as long as
-        # it's running.
-        return
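The main structural change in `gpu_executor.py` is that `_init_executor` now branches on `speculative_config`: without it a plain `Worker` is built; with it, that worker becomes the scorer inside a `SpecDecodeWorker`, which then handles device init and model loading itself. Below is a minimal sketch of that branching pattern with stand-in classes rather than the real workers.

```python
from typing import Optional


class PlainWorker:
    def init_device(self) -> None:
        print("worker: init device, then load model")


class SpecDecodeWrapper:
    """Stand-in for a speculative-decoding worker wrapping a scorer worker."""

    def __init__(self, scorer_worker: PlainWorker) -> None:
        self.scorer_worker = scorer_worker

    def init_device(self) -> None:
        # In the real code, model loading is handled inside this worker.
        print("spec-decode worker: init scorer and draft model")


def build_driver_worker(speculative_config: Optional[dict]):
    target_worker = PlainWorker()
    if speculative_config is None:
        driver = target_worker
    else:
        driver = SpecDecodeWrapper(scorer_worker=target_worker)
    driver.init_device()
    return driver


build_driver_worker(None)                                    # plain path
build_driver_worker({"draft_model": "a-small-draft-model"})  # speculative path
```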

+ 27 - 33
aphrodite/executor/neuron_executor.py

@@ -1,36 +1,17 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Set
 
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     LoRAConfig)
 from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 
 
 class NeuronExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        assert lora_config is None, "LoRA is not supported for Neuron backend."
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-
-        # Set the number of GPU blocks to be the same as the maximum number of
-        # sequences that can be processed in a single batch. This is equivalent
-        # to schedule without PagedAttention.
-        self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
-        self.cache_config.num_cpu_blocks = 0
+    def _init_executor(self) -> None:
+        assert (self.lora_config is
+                None), "LoRA is not supported for Neuron backend."
+        assert (not self.speculative_config
+                ), "Speculative decoding not yet supported for Neuron backend."
 
         # Instantiate the worker and load the model to the device.
         self._init_worker()
@@ -43,34 +24,47 @@ class NeuronExecutor(ExecutorBase):
             self.parallel_config,
             self.scheduler_config,
             self.device_config,
+            self.cache_config,
         )
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
+    def determine_num_available_blocks(self) -> tuple[int, int]:
+        """Determine the number of available KV blocks by invoking the
+        underlying worker.
+        """
+        return self.driver_worker.determine_num_available_blocks()
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Initialize the KV cache by invoking the underlying worker.
+        """
+        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
     def execute_model(self,
                       seq_group_metadata_list: List[SequenceGroupMetadata],
                       blocks_to_swap_in: Dict[int, int],
                       blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+                      blocks_to_copy: Dict[int, List[int]],
+                      num_lookahead_slots: int) -> List[SamplerOutput]:
         assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
                 and blocks_to_copy == {}), (
                     "Cache operations are not supported for Neuron backend.")
+        assert num_lookahead_slots == 0, (
+            "lookahead not supported for Neuron backend.")
 
         output = self.driver_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list)
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
-        raise NotImplementedError(
-            "LoRA is not implemented for neuron backend.")
+        return self.driver_worker.add_lora(lora_request)
 
     def remove_lora(self, lora_id: int) -> bool:
-        raise NotImplementedError(
-            "LoRA is not implemented for neuron backend.")
+        return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
-        raise NotImplementedError(
-            "LoRA is not implemented for neuron backend.")
+    def list_loras(self) -> Set[int]:
+        return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
         # NeuronExecutor will always be healthy as long as

+ 70 - 104
aphrodite/executor/ray_gpu_executor.py

@@ -3,17 +3,12 @@ import copy
 from collections import defaultdict
 import os
 import pickle
-import importlib
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Set, Optional
 
 from loguru import logger
 
-from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
-                                     ParallelConfig, SchedulerConfig,
-                                     LoRAConfig)
 from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
 from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
-from aphrodite.executor.utils import check_block_size_valid
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.common.utils import (set_cuda_visible_devices, get_ip,
@@ -26,12 +21,6 @@ if ray is not None:
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
-# A map between the device type (in device config) to its worker module.
-DEVICE_TO_WORKER_MODULE_MAP = {
-    "cuda": "aphrodite.task_handler.worker",
-    "neuron": "aphrodite.task_handler.neuron_worker",
-}
-
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -40,21 +29,9 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
 
 class RayGPUExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
+    def _init_executor(self) -> None:
+        assert (not self.speculative_config
+                ), "Speculative decoding not yet supported for RayGPU backend."
 
         assert self.parallel_config.worker_use_ray
         placement_group = self.parallel_config.placement_group
@@ -67,19 +44,24 @@ class RayGPUExecutor(ExecutorBase):
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)
 
-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
-
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
 
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
+    def _configure_ray_workers_use_nsight(self,
+                                          ray_remote_kwargs) -> Dict[str, Any]:
+        # If nsight profiling is enabled, we need to set the profiling
+        # configuration for the ray workers as runtime env.
+        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
+        runtime_env.update({
+            "nsight": {
+                "t": "cuda,cudnn,cublas",
+                "o": "'worker_process_%p'",
+                "cuda-graph-trace": "node",
+            }
+        })
+
+        return ray_remote_kwargs
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
@@ -96,6 +78,10 @@ class RayGPUExecutor(ExecutorBase):
         # The remaining workers are the actual ray actors.
         self.workers: List[RayWorkerAphrodite] = []
 
+        if self.parallel_config.ray_workers_use_nsight:
+            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
+                ray_remote_kwargs)
+
         # Create the workers.
         driver_ip = get_ip()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
@@ -156,14 +142,15 @@ class RayGPUExecutor(ExecutorBase):
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
+        from aphrodite.task_handler.worker import Worker
 
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
         device_config = copy.deepcopy(self.device_config)
         lora_config = copy.deepcopy(self.lora_config)
-        kv_cache_dtype = self.cache_config.cache_dtype
+        cache_config = copy.deepcopy(self.cache_config)
+        vision_language_config = copy.deepcopy(self.vision_language_config)
 
         # Initialize the actual workers with the Worker class.
         for rank, (worker, (node_id, _)) in enumerate(
@@ -173,80 +160,68 @@ class RayGPUExecutor(ExecutorBase):
             local_rank = node_workers[node_id].index(rank)
             worker.init_worker.remote(
                 lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config,
-                    parallel_config,
-                    scheduler_config,
-                    device_config,
-                    local_rank,
-                    rank,
-                    distributed_init_method,
+                    model_config=model_config,
+                    parallel_config=parallel_config,
+                    scheduler_config=scheduler_config,
+                    device_config=device_config,
+                    cache_config=cache_config,
+                    local_rank=local_rank,
+                    rank=rank,
+                    distributed_init_method=distributed_init_method,
                     lora_config=lora_config,
-                    kv_cache_dtype=kv_cache_dtype,
+                    vision_language_config=vision_language_config,
                 ))
 
         # Initialize the driver worker with the Worker class.
         driver_rank = 0
         driver_local_rank = node_workers[driver_node_id].index(driver_rank)
         self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            driver_local_rank,
-            driver_rank,
-            distributed_init_method,
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            local_rank=driver_local_rank,
+            rank=driver_rank,
+            distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
-            kv_cache_dtype=kv_cache_dtype,
+            vision_language_config=self.vision_language_config,
             is_driver_worker=True,
         )
 
-        # FIXME: We are not properly initializing cupy NCCL when
-        # we have multiple nodes.
-        self._run_workers("init_model",
-                          cupy_port=get_open_port()
-                          if not model_config.enforce_eager else None)
+        self._run_workers("init_device")
         self._run_workers(
             "load_model",
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers,
         )
 
-    def _init_cache(self) -> None:
-        # ruff: noqa: E501
-        """Profiles the memory usage and initializes the KV cache.
-
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        More details can be found in the
-        :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
-        from class :class:`~aphrodite.task_handler.Worker`.
-
-        Afterwards, as there may be multiple workers,
-        we take the minimum number of blocks across all workers
-        to ensure this can be applied to all of them.
-
-        Finally, the engine will initialize the KV cache
-        with the calculated number of blocks.
-
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
+    def determine_num_available_blocks(self) -> tuple[int, int]:
+        """Determine the number of available KV blocks.
+        This invokes `determine_num_available_blocks` on each worker and takes
+        the min of the results, guaranteeing that the selected cache sizes are
+        compatible with all workers.
+        Returns:
+            - tuple[num_gpu_blocks, num_cpu_blocks]
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_blocks = self._run_workers(
-            "profile_num_available_blocks",
-            block_size=self.cache_config.block_size,
-            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space_bytes,
-            cache_dtype=self.cache_config.cache_dtype,
-        )
+        num_blocks = self._run_workers("determine_num_available_blocks", )
 
         # Since we use a shared centralized controller, we take the minimum
         # number of blocks across all workers to make sure all the memory
         # operators can be applied to all workers.
         num_gpu_blocks = min(b[0] for b in num_blocks)
         num_cpu_blocks = min(b[1] for b in num_blocks)
+        return num_gpu_blocks, num_cpu_blocks
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Initialize the KV cache in all workers.
+        """
+
+        # NOTE: We log here to avoid multiple logs when number of workers is
+        # greater than one. We could log in the engine, but not all executors
+        # have GPUs.
         logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                     f"# CPU blocks: {num_cpu_blocks}")
 
@@ -254,26 +229,19 @@ class RayGPUExecutor(ExecutorBase):
             f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
         )
 
-        check_block_size_valid(
-            num_gpu_blocks,
-            self.cache_config.block_size,
-            self.model_config.max_model_len,
-        )
-
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-        # Initialize the cache.
-        self._run_workers("init_cache_engine", cache_config=self.cache_config)
-        # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is False.
-        self._run_workers("warm_up_model")
+        self._run_workers("initialize_cache",
+                          num_gpu_blocks=num_gpu_blocks,
+                          num_cpu_blocks=num_cpu_blocks)
 
     def execute_model(self,
                       seq_group_metadata_list: List[SequenceGroupMetadata],
                       blocks_to_swap_in: Dict[int, int],
                       blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+                      blocks_to_copy: Dict[int, List[int]],
+                      num_lookahead_slots: int = 0) -> List[SamplerOutput]:
         all_outputs = self._run_workers(
             "execute_model",
             driver_kwargs={
@@ -302,7 +270,7 @@ class RayGPUExecutor(ExecutorBase):
             lora_id=lora_id,
         )
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
 
     def _run_workers(
@@ -431,6 +399,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int = 0,
     ) -> SamplerOutput:
         all_outputs = await self._run_workers_async(
             "execute_model",
@@ -439,12 +408,9 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
                 "blocks_to_swap_in": blocks_to_swap_in,
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
+                "num_lookahead_slots": num_lookahead_slots,
             })
 
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
         return output
-
-    async def check_health_async(self) -> None:
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
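With the Ray executor, `determine_num_available_blocks` makes the cross-worker agreement explicit: every worker profiles its own GPU and the executor keeps the minimum block counts, so the chosen cache size fits the most constrained worker. A small self-contained sketch of that reduction, assuming the per-worker results have already been gathered:

```python
from typing import List, Tuple


def pick_cache_size(per_worker_blocks: List[Tuple[int, int]]) -> Tuple[int, int]:
    """Take the minimum (gpu, cpu) block counts across workers so the KV
    cache fits on every GPU in the tensor-parallel group."""
    num_gpu_blocks = min(b[0] for b in per_worker_blocks)
    num_cpu_blocks = min(b[1] for b in per_worker_blocks)
    return num_gpu_blocks, num_cpu_blocks


# e.g. one worker shares its GPU with another process and reports fewer blocks
print(pick_cache_size([(4096, 1024), (3800, 1024), (4096, 1024)]))  # (3800, 1024)
```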

+ 0 - 18
aphrodite/executor/utils.py

@@ -1,18 +0,0 @@
-from loguru import logger
-
-
-def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
-    if num_gpu_blocks <= 0:
-        raise ValueError("No available memory for the cache blocks. "
-                         "Try increasing `gpu_memory_utilization` when "
-                         "initializing the engine.")
-    max_seq_len = block_size * num_gpu_blocks
-    logger.info(f"Maximum sequence length allowed in the cache: "
-                f"{max_seq_len}")
-    if max_model_len > max_seq_len:
-        raise ValueError(
-            f"The model's max seq len ({max_model_len}) "
-            "is larger than the maximum number of tokens that can be "
-            f"stored in KV cache ({max_seq_len}). Try increasing "
-            "`gpu_memory_utilization` or decreasing `max_model_len` when "
-            "initializing the engine.")

+ 180 - 54
aphrodite/lora/layers.py

@@ -1,7 +1,8 @@
 # pylint: disable=unused-argument
+import inspect
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -9,22 +10,20 @@ import torch.nn.functional as F
 from transformers import PretrainedConfig
 
 from aphrodite.common.config import LoRAConfig
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   split_tensor_along_last_dim,
+                                   tensor_model_parallel_all_gather,
+                                   tensor_model_parallel_all_reduce,
+                                   tensor_model_parallel_gather)
 from aphrodite.lora.punica import add_lora, add_lora_slice, bgmv
-from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_gather,
-    tensor_model_parallel_all_reduce,
-    tensor_model_parallel_gather,
-)
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              RowParallelLinear,
+                                              MergedColumnParallelLinear,
                                               QKVParallelLinear,
-                                              MergedColumnParallelLinear)
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from aphrodite.modeling.megatron.utils import split_tensor_along_last_dim
+    ParallelLMHead, VocabParallelEmbedding)
 
 if TYPE_CHECKING:
     pass
@@ -108,8 +107,8 @@ def _apply_lora_packed_nslice(
         lora_b_stacked:    3 element tuple of (num_loras, output_dim, lora_rank)
         indices:           (batch_size)
         output:            (batch_size, q_slice_size + 2*kv_slice_size)
-        output_slices:     n-1 element tuple of (slice_size...), where n is
-                           number of slices
+        output_slices:     n-1 element tuple of (slice_size...),
+                           where n is number of slices
     """
     org_output = output
     x = x.view(-1, x.shape[-1])
@@ -138,8 +137,11 @@ class LoRAMapping:
 
 class BaseLayerWithLoRA(nn.Module):
 
-    def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
-                            model_config: PretrainedConfig) -> None:
+    def create_lora_weights(
+            self,
+            max_loras: int,
+            lora_config: LoRAConfig,
+            model_config: Optional[PretrainedConfig] = None) -> None:
         """Initializes lora matrices."""
         ...
 
@@ -168,6 +170,13 @@ class BaseLayerWithLoRA(nn.Module):
         """Sets the mapping indices."""
         ...
 
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        """Returns True if the layer can be replaced by this LoRA layer."""
+        raise NotImplementedError
+
 
 class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
 
@@ -281,12 +290,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         added_tokens_mask = x > self.base_layer.org_vocab_size - 1
-        indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
+        embedding_len = self.indices_len[3]
+        indices = self.embeddings_indices[1][:embedding_len].view_as(x)
         full_lora_a_embeddings = F.embedding(
             x + indices,
             self.lora_a_stacked_2d,
         )
-        indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
+        indices = self.embeddings_indices[0][:embedding_len].view_as(x)
         full_output = self.base_layer.forward(
             x.add_(indices * added_tokens_mask))
 
@@ -302,12 +312,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
              self.indices[:self.indices_len[0]], 0, 1.0)
         return full_output.view_as(full_output_org)
 
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        return type(source_layer) is VocabParallelEmbedding
+
 
 class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
 
     def __init__(self, base_layer: ColumnParallelLinear) -> None:
         super().__init__()
         self.base_layer = base_layer
+        self.tp_size = get_tensor_model_parallel_world_size()
 
     def create_lora_weights(
             self,
@@ -333,7 +350,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
 
         self.indices: Optional[torch.Tensor] = None
         self.indices_len: Optional[List[int]] = None
-        self.output_dim = self.lora_b_stacked.shape[1]
+        self.output_dim = self.lora_b_stacked.shape[2]
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
@@ -347,7 +364,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         embeddings_tensor: Optional[torch.Tensor],
     ):
         self.reset_lora(index)
-
+        if self.tp_size > 1:
+            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+            shard_size = self.output_dim
+            start_idx = tensor_model_parallel_rank * shard_size
+            end_idx = (tensor_model_parallel_rank + 1) * shard_size
+            lora_b = lora_b[:, start_idx:end_idx]
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                 lora_a.T, non_blocking=True)
@@ -407,6 +429,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def linear_weights(self):
         return self.base_layer.linear_weights
 
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        return type(source_layer) is ColumnParallelLinear or (
+            type(source_layer) is MergedColumnParallelLinear
+            and len(packed_modules_list) == 1)
+
 
 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
@@ -510,8 +540,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return output
 
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        return type(source_layer) is MergedColumnParallelLinear and len(
+            packed_modules_list) == 2
+
 
 class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
+    """
+    ColumnParallelLinear layer that is specifically designed for  
+    qkv_proj. Certain models, such as chtglm3 and baichuan-7b,  
+    only contains a single LoRA within their qkv_proj layer. 
+
+    During inference with Tensor Parallel, the weights of lora_b 
+    must be accurately partitioned according to the respective ranks.
+    
+    Q slice may have different shape than K and V slices (which both have
+    the same shape).
+    """
+
+    def __init__(self, base_layer: QKVParallelLinear) -> None:
+        super().__init__(base_layer)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.q_proj_total_size = (self.base_layer.total_num_heads *
+                                  self.base_layer.head_size)
+        self.q_proj_shard_size = (self.base_layer.num_heads *
+                                  self.base_layer.head_size)
+        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
+                                   self.base_layer.head_size)
+        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
+                                   self.base_layer.head_size)
+
+    def set_lora(
+        self,
+        index: int,
+        lora_a: torch.Tensor,
+        lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
+    ):
+        self.reset_lora(index)
+        if self.tp_size > 1:
+            tp_rank = get_tensor_model_parallel_rank()
+            self.q_shard_id = tp_rank
+            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+            lora_b_q = lora_b[:, self.q_proj_shard_size *
+                              self.q_shard_id:self.q_proj_shard_size *
+                              (self.q_shard_id + 1)]
+            k_offset = self.q_proj_total_size
+            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
+                              self.kv_shard_id:k_offset +
+                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+            v_offset = k_offset + self.kv_proj_total_size
+            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
+                              self.kv_shard_id:v_offset +
+                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+
+        self.lora_a_stacked[index,
+                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+                                lora_a.T, non_blocking=True)
+        self.lora_b_stacked[index,
+                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
+                                lora_b.T, non_blocking=True)
+
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        return type(source_layer) is QKVParallelLinear and len(
+            packed_modules_list) == 1
+
+
+class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
     """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
     packed together in qkv proj fashion
     (q_proj + k_proj + v_proj -> qkv_proj).
@@ -680,6 +782,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         )
         return output
 
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        return type(source_layer) is QKVParallelLinear and len(
+            packed_modules_list) == 3
+
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
 
@@ -692,8 +801,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             max_loras: int,
             lora_config: LoRAConfig,
             model_config: Optional[PretrainedConfig] = None) -> None:
-        device = _get_lora_device(self.base_layer)
 
+        device = _get_lora_device(self.base_layer)
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -808,12 +917,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def weight(self):
         return self.base_layer.weight
 
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        return type(source_layer) is RowParallelLinear
 
-class SamplerWithLoRA(BaseLayerWithLoRA):
+
+class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
 
     def __init__(
         self,
-        base_layer: Sampler,
+        base_layer: LogitsProcessor,
         hidden_size: int,
         dtype: torch.dtype,
         device: torch.device,
@@ -825,13 +940,17 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         self.device = device
 
     @property
-    def logits_as_hidden_states(self):
-        return self.base_layer.logits_as_hidden_states
+    def logits_as_input(self):
+        return self.base_layer.logits_as_input
 
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size
 
+    @property
+    def scale(self):
+        return self.base_layer.scale
+
     @property
     def org_vocab_size(self):
         return self.base_layer.org_vocab_size
@@ -847,10 +966,9 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
-        if 32000 < self.base_layer.vocab_size > 33024:
-            raise ValueError(
-                "When using LoRA, vocab size must be 32000 >= vocab_size "
-                "<= 33024")
+        if 32000 < self.base_layer.vocab_size > 128512:
+            raise ValueError("When using LoRA, vocab size must be "
+                             "32000 >= vocab_size <= 128512")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -923,11 +1041,11 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
     def _get_logits(
         self,
         hidden_states: torch.Tensor,
-        embedding: torch.Tensor,
+        lm_head: torch.Tensor,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
+        logits = lm_head(hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
         logits = tensor_model_parallel_gather(logits)
@@ -974,35 +1092,43 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
     def forward(self, *args, **kwargs):
         return type(self.base_layer).forward(self, *args, **kwargs)
 
-
-def from_layer(
-        layer: nn.Module,
-        max_loras: int,
-        lora_config: LoRAConfig,
-        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
-    supported_layer_types = {
-        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
-        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
-        QKVParallelLinear: QKVParallelLinearWithLora,
-        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
-        RowParallelLinear: RowParallelLinearWithLoRA,
-    }
-    for src_layer_type, lora_layer_type in supported_layer_types.items():
-        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer)
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        # Special handling for the LogitsProcessor.
+        return False
+
+
+_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
+    cls
+    for cls in globals().values() if inspect.isclass(cls)
+    and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
+}
+
+
+def from_layer(layer: nn.Module,
+               max_loras: int,
+               lora_config: LoRAConfig,
+               packed_modules_list: List,
+               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
+    for lora_cls in _all_lora_classes:
+        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
+                                      model_config):
+            ret = lora_cls(layer)
             ret.create_lora_weights(max_loras, lora_config, model_config)
             return ret
     return layer
 
 
-def from_layer_sampler(
-    layer: Sampler,
+def from_layer_logits_processor(
+    layer: LogitsProcessor,
     lm_head: ParallelLMHead,
     max_loras: int,
     lora_config: LoRAConfig,
     model_config: Optional[PretrainedConfig] = None,
-) -> SamplerWithLoRA:
-    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
-                          lm_head.weight.device)
+) -> LogitsProcessorWithLoRA:
+    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
+                                  lm_head.weight.dtype, lm_head.weight.device)
     ret.create_lora_weights(max_loras, lora_config, model_config)
     return ret

+ 3 - 2
aphrodite/lora/lora.py

@@ -1,7 +1,8 @@
 from typing import List, Optional
 
 import torch
-from aphrodite.common.utils import in_wsl
+
+from aphrodite.common.utils import is_pin_memory_available
 
 
 class LoRALayerWeights:
@@ -64,7 +65,7 @@ class LoRALayerWeights:
             dtype: torch.dtype,
             device: torch.device,
             embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
-        pin_memory = str(device) == "cpu" and not in_wsl()
+        pin_memory = str(device) == "cpu" and is_pin_memory_available()
         lora_a = torch.zeros([input_dim, rank],
                              dtype=dtype,
                              device=device,
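The only behavioral change here is the gate on pinned host memory: `in_wsl()` is replaced by `is_pin_memory_available()`, which presumably covers more environments where page-locked allocations are not usable (WSL being one of them). A tiny hedged sketch of the gating pattern; the predicate below is a stand-in, not the real utility.

```python
import platform

import torch


def pin_memory_ok() -> bool:
    # Stand-in predicate: the real is_pin_memory_available() knows about more
    # environments; WSL detection is shown here only as an example heuristic.
    return "microsoft" not in platform.uname().release.lower()


device = "cpu"
# Also guard on CUDA availability so this sketch runs on CPU-only machines.
pin_memory = device == "cpu" and torch.cuda.is_available() and pin_memory_ok()
lora_a = torch.zeros((4096, 16), dtype=torch.float16, device=device,
                     pin_memory=pin_memory)
```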

+ 79 - 58
aphrodite/lora/models.py

@@ -1,25 +1,22 @@
 import copy
 import json
-import logging
 import math
 import os
 import re
-from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
 
 import safetensors.torch
 import torch
+from loguru import logger
 from torch import nn
 
 from aphrodite.common.config import LoRAConfig
-from aphrodite.common.utils import LRUCache, in_wsl
-
+from aphrodite.common.utils import LRUCache, is_pin_memory_available
 from aphrodite.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
-                                   from_layer_sampler)
+                                   from_layer_logits_processor)
 from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from aphrodite.lora.utils import parse_fine_tuned_lora_name, replace_submodule
 
-logger = logging.getLogger(__name__)
-
 _GLOBAL_LORA_ID = 0
 
 
@@ -143,48 +140,53 @@ class LoRAModel:
         embedding_padding_modules: Optional[List[str]] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a dictionary of tensors."""
-        pin_memory = str(device) == "cpu" and not in_wsl()
+        pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: Dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
-            result = parse_fine_tuned_lora_name(tensor_name)
-            if result is not None:
-                module_name, is_lora_a = result
-                if module_name not in loras:
-                    lora_embeddings_tensor = None
-                    if embeddings:
-                        embeddings_module = next(
-                            (k for k in embedding_modules if k in module_name),
-                            None)
-                        if embeddings_module:
-                            lora_embeddings_tensor = embeddings[
-                                embedding_modules[embeddings_module]].to(
-                                    device=device, dtype=dtype)
-                            if pin_memory:
-                                lora_embeddings_tensor = (
-                                    lora_embeddings_tensor.pin_memory())
-                    loras[module_name] = LoRALayerWeights(
-                        module_name, rank, lora_alpha, None, None,
-                        lora_embeddings_tensor)
-                if is_lora_a:
-                    loras[module_name].lora_a = tensor.to(device=device,
-                                                          dtype=dtype).t()
-                    if pin_memory:
-                        loras[module_name].lora_a = loras[
-                            module_name].lora_a.pin_memory()
-                else:
-                    loras[module_name].lora_b = tensor.to(device=device,
-                                                          dtype=dtype).t()
-                    if any(name in module_name
-                           for name in embedding_padding_modules
-                           ) and target_embedding_padding is not None:
-                        lora_b = loras[module_name].lora_b
-                        assert target_embedding_padding >= lora_b.shape[1]
-                        addition = target_embedding_padding - lora_b.shape[1]
-                        loras[module_name].lora_b = torch.nn.functional.pad(
-                            lora_b, (0, addition))
-                    if pin_memory:
-                        loras[module_name].lora_b = loras[
-                            module_name].lora_b.pin_memory()
+            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
+            if module_name not in loras:
+                lora_embeddings_tensor = None
+                if embeddings:
+                    embeddings_module = next(
+                        (k for k in embedding_modules if k in module_name),
+                        None)
+                    if embeddings_module:
+                        lora_embeddings_tensor = embeddings[
+                            embedding_modules[embeddings_module]].to(
+                                device=device, dtype=dtype)
+                        if pin_memory:
+                            lora_embeddings_tensor = (
+                                lora_embeddings_tensor.pin_memory())
+                loras[module_name] = LoRALayerWeights(module_name, rank,
+                                                      lora_alpha, None, None,
+                                                      lora_embeddings_tensor)
+            if is_lora_a:
+                loras[module_name].lora_a = tensor.to(device=device,
+                                                      dtype=dtype).t()
+                if pin_memory:
+                    loras[module_name].lora_a = loras[
+                        module_name].lora_a.pin_memory()
+            else:
+                loras[module_name].lora_b = tensor.to(device=device,
+                                                      dtype=dtype).t()
+                if any(name in module_name
+                       for name in embedding_padding_modules
+                       ) and target_embedding_padding is not None:
+                    lora_b = loras[module_name].lora_b
+                    assert target_embedding_padding >= lora_b.shape[1]
+                    addition = target_embedding_padding - lora_b.shape[1]
+                    loras[module_name].lora_b = torch.nn.functional.pad(
+                        lora_b, (0, addition))
+                if pin_memory:
+                    loras[module_name].lora_b = loras[
+                        module_name].lora_b.pin_memory()
+
+        # Filter out LoRALayerWeights instances where lora_a or lora_b is None
+        loras = {
+            k: v
+            for k, v in loras.items()
+            if v.lora_a is not None and v.lora_b is not None
+        }
 
         for lora in loras.values():
             lora.optimize()
@@ -194,6 +196,7 @@ class LoRAModel:
     def from_local_checkpoint(
         cls,
         lora_dir: str,
+        expected_lora_modules: List[str],
         lora_model_id: Optional[int] = None,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
@@ -209,6 +212,20 @@ class LoRAModel:
             lora_dir, "new_embeddings.safetensors")
         new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                     "new_embeddings.bin")
+        with open(lora_config_path) as f:
+            config = json.load(f)
+        target_modules = config["target_modules"]
+        unexpected_modules = []
+        for module in target_modules:
+            if module not in expected_lora_modules:
+                unexpected_modules.append(module)
+        # loaded lora's target modules must be a subset of expected_lora_modules
+        if unexpected_modules:
+            raise ValueError(
+                f"While loading {lora_dir}, expected"
+                f" target modules in {expected_lora_modules}"
+                f" but received {unexpected_modules}."
+                f" Please verify that the loaded LoRA module is correct")
         if os.path.isfile(lora_tensor_path):
             tensors = safetensors.torch.load_file(lora_tensor_path)
         elif os.path.isfile(lora_bin_file_path):
@@ -223,8 +240,6 @@ class LoRAModel:
         elif os.path.isfile(new_embeddings_bin_file_path):
             embeddings = torch.load(new_embeddings_bin_file_path)
 
-        with open(lora_config_path) as f:
-            config = json.load(f)
         rank = config["r"]
         lora_alpha = config["lora_alpha"]
         return cls.from_lora_tensors(
@@ -284,7 +299,7 @@ class LoRAModelManager:
                                               dtype=torch.long,
                                               device="cuda")
         self.offsets = []
-        # 4 is the number of indices tensors defined above
+        # 4 is the number of indicies tensors defined above
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices
         self.indices_len = [None] * 4
@@ -416,18 +431,22 @@ class LoRAModelManager:
         for module_name, module in self.model.named_modules():
             if not self._match_target_modules(module_name):
                 continue
-
+            parts = module_name.split(".")[-1]
+            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
             new_module = replace_submodule(
                 self.model, module_name,
                 from_layer(module, self.lora_slots, self.lora_config,
-                           self.model.config))
+                           packed_moduled_lst, self.model.config))
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
-                sampler_module = self.model.get_submodule("sampler")
+                logits_processor_module = self.model.get_submodule(
+                    "logits_processor")
                 new_module = replace_submodule(
-                    self.model, "sampler",
-                    from_layer_sampler(sampler_module, module, self.lora_slots,
-                                       self.lora_config, self.model.config))
+                    self.model, "logits_processor",
+                    from_layer_logits_processor(logits_processor_module,
+                                                module, self.lora_slots,
+                                                self.lora_config,
+                                                self.model.config))
             self.register_module(module_name, new_module)
             self._register_packed_modules(module_name)
             new_module.set_mapping(self.base_indices, self.sampler_indices,
@@ -510,8 +529,10 @@ class LoRAModelManager:
     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
         module_name = parts[-1]
-        replacements = self.packed_modules_mapping.get(module_name)
-        if not replacements:
+        replacements = self.packed_modules_mapping.get(module_name, [])
+        # When replacements is less than or equal to 1, it indicates that this
+        # module is not a packed module.
+        if len(replacements) <= 1:
             return
         prefix = ".".join(parts[:-1])
         self.packed_modules[module_full_name] = [

+ 4 - 0
aphrodite/lora/punica.py

@@ -31,6 +31,7 @@ def bgmv(
           @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
           * scale
         ).squeeze(0)
+
     Args:
       y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
       x: Shape: `[B, H1]`. Input vectors.
@@ -65,6 +66,7 @@ def add_lora(y: torch.Tensor,
           @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
           * scale
         ).squeeze(0)
+
     Args:
       y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
       x: Shape: `[B, H1]`. Input vectors.
@@ -109,6 +111,7 @@ def add_lora_slice(y: torch.Tensor,
     """
     Same as `add_lora` but you can operate on slices of y.
     Pass whole y, define y_offset and y_slice_size.
+
     Semantics:
       y[i] += (
           x[i].unsqueeze(0)
@@ -116,6 +119,7 @@ def add_lora_slice(y: torch.Tensor,
           @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
           * scale
         ).squeeze(0)
+
     Args:
       y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
       x: Shape: `[B, H1]`. Input vectors.

+ 3 - 4
aphrodite/lora/utils.py

@@ -1,10 +1,7 @@
-import logging
 from typing import Tuple
 
 from torch import nn
 
-logger = logging.getLogger(__name__)
-
 
 def replace_submodule(model: nn.Module, module_name: str,
                       new_module: nn.Module) -> nn.Module:
@@ -33,7 +30,9 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         if parts[-2] == "lora_A" or parts[-2] == "lora_B":
             return ".".join(parts[2:-2]), parts[-2] == "lora_A"
         else:
-            return None
+            # Handle the case where the tensor name is
+            # "base_model.model.lm_head.weight"
+            return ".".join(parts[2:]), True
 
     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
         return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

+ 16 - 11
aphrodite/lora/worker_manager.py

@@ -1,17 +1,14 @@
-import logging
 from abc import ABC, abstractmethod, abstractproperty
 from typing import Any, Dict, List, Optional, Set, Type
 
 import torch
 
+from aphrodite.common.config import LoRAConfig
+from aphrodite.lora.layers import LoRAMapping
 from aphrodite.lora.models import (LoRAModel, LoRAModelManager,
                                    LRUCacheLoRAModelManager,
                                    create_lora_manager)
 from aphrodite.lora.request import LoRARequest
-from aphrodite.lora.layers import LoRAMapping
-from aphrodite.common.config import LoRAConfig
-
-logger = logging.getLogger(__name__)
 
 
 class AbstractWorkerLoRAManager(ABC):
@@ -63,7 +60,6 @@ class AbstractWorkerLoRAManager(ABC):
         ...
 
 
-# pylint: disable=function-redefined
 class WorkerLoRAManager(AbstractWorkerLoRAManager):
     """WorkerLoRAManager that manages LoRA models on the worker side.
 
@@ -91,7 +87,6 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
                          lora_config, device)
 
     @property
-    # pylint: disable=invalid-overridden-method
     def is_enabled(self) -> bool:
         return True
 
@@ -139,8 +134,19 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
 
     def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
         try:
+            model = self._lora_manager.model
+            supported_lora_modules = model.supported_lora_modules
+            packed_modules_mapping = model.packed_modules_mapping
+            expected_lora_modules = []
+            for module in supported_lora_modules:
+                if module in packed_modules_mapping:
+                    expected_lora_modules.extend(
+                        packed_modules_mapping[module])
+                else:
+                    expected_lora_modules.append(module)
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_request.lora_local_path,
+                expected_lora_modules,
                 lora_model_id=lora_request.lora_int_id,
                 device="cpu",
                 dtype=self.lora_config.lora_dtype,
@@ -157,10 +163,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
                 f"LoRA rank {lora.rank} is greater than max_lora_rank "
                 f"{self.lora_config.max_lora_rank}.")
         if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
-            raise ValueError(
-                f"LoRA added vocab size {lora.extra_vocab_size} is "
-                "greater than lora_extra_vocab_size "
-                f"{self.lora_config.lora_extra_vocab_size}.")
+            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
+                             f"is greater than lora_extra_vocab_size "
+                             f"{self.lora_config.lora_extra_vocab_size}.")
         return lora
 
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
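The expected-module expansion added to `_load_lora` is easy to illustrate in isolation. The mapping below is a hypothetical Llama-style example, not copied from any model file in this PR.

```python
# Hypothetical packed_modules_mapping showing how expected_lora_modules is
# assembled before from_local_checkpoint is called.
supported_lora_modules = ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

expected_lora_modules = []
for module in supported_lora_modules:
    # Packed modules expand to the per-projection names found in checkpoints;
    # everything else is kept as-is.
    expected_lora_modules.extend(packed_modules_mapping.get(module, [module]))

print(expected_lora_modules)
# ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
```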

+ 0 - 4
aphrodite/modeling/__init__.py

@@ -1,11 +1,7 @@
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_random_seed
 
 __all__ = [
-    "InputMetadata",
-    "get_model",
     "SamplingMetadata",
     "set_random_seed",
 ]

+ 25 - 0
aphrodite/modeling/guided_decoding/__init__.py

@@ -0,0 +1,25 @@
+from typing import Optional, Union
+
+from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
+                                                 CompletionRequest)
+from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (
+    get_lm_format_enforcer_guided_decoding_logits_processor)
+from aphrodite.modeling.guided_decoding.outlines_decoding import (
+    get_outlines_guided_decoding_logits_processor)
+from aphrodite.common.sampling_params import LogitsProcessorFunc
+
+
+async def get_guided_decoding_logits_processor(
+        guided_decoding_backend: str, request: Union[CompletionRequest,
+                                                     ChatCompletionRequest],
+        tokenizer) -> Optional[LogitsProcessorFunc]:
+    if guided_decoding_backend == 'outlines':
+        return await get_outlines_guided_decoding_logits_processor(
+            request, tokenizer)
+    if guided_decoding_backend == 'lm-format-enforcer':
+        return await get_lm_format_enforcer_guided_decoding_logits_processor(
+            request, tokenizer)
+
+    raise ValueError(
+        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+        "Must be one of 'outlines, 'lm-format-enforcer'")

+ 70 - 0
aphrodite/modeling/guided_decoding/lm_format_enforcer_decoding.py

@@ -0,0 +1,70 @@
+from functools import lru_cache
+from json import loads as json_loads
+from typing import Optional, Union
+
+from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
+                              RegexParser, StringParser,
+                              TokenEnforcerTokenizerData, UnionParser)
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase
+
+from aphrodite.common.sampling_params import LogitsProcessorFunc
+from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
+                                                 CompletionRequest)
+from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import (  # noqa: E501
+    build_aphrodite_logits_processor,
+    build_aphrodite_token_enforcer_tokenizer_data)
+from aphrodite.modeling.guided_decoding.outlines_decoding import \
+    get_outlines_guided_decoding_logits_processor
+
+
+async def get_lm_format_enforcer_guided_decoding_logits_processor(
+        request: Union[CompletionRequest, ChatCompletionRequest],
+        tokenizer) -> Optional[LogitsProcessorFunc]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    The token-enforcer tokenizer data is cached per tokenizer, so repeated
+    requests do not rebuild it.
+    """
+
+    tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
+        tokenizer)
+    character_level_parser: CharacterLevelParser
+    if request.guided_json:
+        schema = _normalize_json_schema_object(request.guided_json)
+        character_level_parser = JsonSchemaParser(schema)
+    elif request.guided_choice:
+        character_level_parser = UnionParser(
+            [StringParser(choice) for choice in request.guided_choice])
+    elif request.guided_regex:
+        character_level_parser = RegexParser(request.guided_regex)
+    elif request.guided_grammar:
+        # CFG grammar not supported by LMFE, revert to outlines
+        return await get_outlines_guided_decoding_logits_processor(
+            request, tokenizer)
+    elif (request.response_format is not None
+          and request.response_format.type == "json_object"):
+        character_level_parser = JsonSchemaParser(
+            None)  # None means any json object
+    else:
+        return None
+
+    logits_processor = build_aphrodite_logits_processor(
+        tokenizer_data, character_level_parser)
+    return logits_processor
+
+
+def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
+    if isinstance(schema, str):
+        return json_loads(schema)
+    if isinstance(schema, dict):
+        return schema
+    if isinstance(schema, BaseModel):
+        return schema.model_json_schema()
+
+
+@lru_cache
+def _cached_build_aphrodite_token_enforcer_tokenizer_data(
+        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
+    return build_aphrodite_token_enforcer_tokenizer_data(tokenizer)
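The guided_* request fields map one-to-one onto lm-format-enforcer parsers; a small self-contained illustration of the three simplest cases (the choices and regex are made up):

```python
# Hypothetical guided-decoding inputs and the parsers they map to above.
from lmformatenforcer import (JsonSchemaParser, RegexParser, StringParser,
                              UnionParser)

# guided_choice -> UnionParser over string literals
choice_parser = UnionParser([StringParser(c) for c in ["yes", "no", "maybe"]])

# guided_regex -> RegexParser
regex_parser = RegexParser(r"\d{4}-\d{2}-\d{2}")

# response_format.type == "json_object" -> JsonSchemaParser(None),
# i.e. any JSON object
json_object_parser = JsonSchemaParser(None)
```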

+ 70 - 0
aphrodite/modeling/guided_decoding/lm_format_enforcer_logits_processors.py

@@ -0,0 +1,70 @@
+"""The LM Format Enforcer logits processor."""
+import math
+from typing import List, Optional, Union
+
+import torch
+from lmformatenforcer import (CharacterLevelParser, FormatEnforcerAnalyzer,
+                              TokenEnforcer, TokenEnforcerTokenizerData)
+from lmformatenforcer.integrations.transformers import \
+    build_token_enforcer_tokenizer_data
+from transformers import PreTrainedTokenizerBase
+
+import aphrodite
+
+
+class aphroditeLogitsProcessor:
+
+    def __init__(self, token_enforcer: TokenEnforcer, analyze):
+        self.token_enforcer = token_enforcer
+        self.analyzer = FormatEnforcerAnalyzer(
+            token_enforcer) if analyze else None
+        self.mask: Optional[torch.Tensor] = None
+
+    def __call__(self, input_ids: List[int],
+                 scores: torch.Tensor) -> torch.Tensor:
+        token_sequence = input_ids
+        if self.analyzer:
+            self.analyzer.report_raw_logits(token_sequence, scores.tolist())
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
+        if self.mask is not None:
+            self.mask.fill_(-math.inf)
+        else:
+            # We create it here because full_like() also copies the device
+            # and dtype
+            self.mask = torch.full_like(scores, -math.inf)
+        self.mask[allowed_tokens] = 0
+        scores = scores + self.mask
+        return scores
+
+
+def build_aphrodite_token_enforcer_tokenizer_data(
+    tokenizer: Union[aphrodite.LLM, PreTrainedTokenizerBase]
+) -> TokenEnforcerTokenizerData:
+    # Many different classes can be passed here; this logic should work
+    # for all of them.
+    if hasattr(tokenizer, 'get_tokenizer'):
+        tokenizer = tokenizer.get_tokenizer()
+    if hasattr(tokenizer, 'tokenizer'):
+        tokenizer = tokenizer.tokenizer
+    return build_token_enforcer_tokenizer_data(tokenizer)
+
+
+def build_aphrodite_logits_processor(
+        llm: Union[aphrodite.LLM, PreTrainedTokenizerBase,
+                   TokenEnforcerTokenizerData],
+        character_level_parser: CharacterLevelParser,
+        analyze: bool = False) -> aphroditeLogitsProcessor:
+    """Build the logits processor function that aphrodite will use to filter
+    the tokens generated by the model. The result can be passed in the
+    logits_processor list that is sent to the call or generate() method of
+    models."""
+    if not isinstance(llm, TokenEnforcerTokenizerData):
+        llm = build_aphrodite_token_enforcer_tokenizer_data(llm)
+    token_enforcer = TokenEnforcer(llm, character_level_parser)
+    return aphroditeLogitsProcessor(token_enforcer, analyze)
+
+
+__all__ = [
+    'build_aphrodite_logits_processor',
+    'build_aphrodite_token_enforcer_tokenizer_data'
+]
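The heart of `aphroditeLogitsProcessor.__call__` is a reusable mask of `-inf` with zeros at the allowed token ids. Here is a toy, stand-alone version of that step; the allow list is made up, the real one comes from `TokenEnforcer.get_allowed_tokens`.

```python
# Toy illustration of the masking step with a hypothetical allow list.
import math

import torch

scores = torch.tensor([1.2, -0.3, 0.7, 2.5])   # logits over a 4-token vocab
allowed_tokens = [0, 2]                        # hypothetical allowed ids

mask = torch.full_like(scores, -math.inf)
mask[allowed_tokens] = 0
print(scores + mask)  # tensor([1.2000,   -inf, 0.7000,   -inf])
```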

+ 3 - 3
aphrodite/modeling/outlines_decoding.py → aphrodite/modeling/guided_decoding/outlines_decoding.py

@@ -14,7 +14,7 @@ from aphrodite.endpoints.openai.protocol import (
     ChatCompletionRequest,
     CompletionRequest,
 )
-from aphrodite.modeling.outlines_logits_processors import (
+from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor,
     JSONLogitsProcessor,
     RegexLogitsProcessor,
@@ -58,7 +58,7 @@ pair   : UNESCAPED_STRING ":" value
 global_thread_pool = None  # used for generating logits processor fsm
 
 
-async def get_guided_decoding_logits_processor(
+async def get_outlines_guided_decoding_logits_processor(
         request: Union[CompletionRequest, ChatCompletionRequest],
         tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
     """
@@ -94,7 +94,7 @@ def _get_guide_and_mode(
         json = request.guided_json
         if isinstance(json, dict):
             # turn dict into hashable string
-            json = json_dumps(json, sort_keys=True)
+            json = json_dumps(json)
         elif isinstance(json, BaseModel):
             # use pydantic signature so that different model classes
             # with the same fields will get hashed the same

+ 58 - 45
aphrodite/modeling/outlines_logits_processors.py → aphrodite/modeling/guided_decoding/outlines_logits_processors.py

@@ -13,9 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import json
 import math
 from collections import defaultdict
+from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Optional, Union
 
 import torch
@@ -27,49 +29,9 @@ from transformers import PreTrainedTokenizerBase
 
 class BaseLogitsProcessor:
 
-    def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
-        """Adapt Aphrodite's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. The decoder of outlines, returns a list whereas
-        the decode of Aphrodite returns a str. To sync the Aphrodite decoder
-        with outline's internal api, the decoder should be adapted. In addition
-        we need to handle the missing spaces to Llama's tokenizer to be
-        able to compile FSMs for this model.
-
-        """
-        if getattr(tokenizer, "_outlines_adapted", False):
-            return tokenizer
-
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        def change_decoder(
-            decoder: Callable[[List[int]], str]
-        ) -> Callable[[List[int]], List[str]]:
-            """Sync Aphrodite's decoder with the outlines by returning list."""
-
-            def new_decoder(inp_tokens: List[int]) -> List[str]:
-                return [decoder(inp_tokens)]
-
-            return new_decoder
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-        tokenizer.decode = change_decoder(tokenizer.decode)
-        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
-
-        return tokenizer
+    def __init__(self):
+        # Child classes should initialize self.fsm in their __init__.
+        self.fsm: Union[RegexFSM, CFGFSM]
 
     def init_state(self):
         """Initialize the FSM states."""
@@ -113,7 +75,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer
 
         """
-        tokenizer = self.adapt_tokenizer(tokenizer)
+        tokenizer = _adapt_tokenizer(tokenizer)
         fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
 
@@ -167,6 +129,57 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer
 
         """
-        tokenizer = self.adapt_tokenizer(tokenizer)
+        tokenizer = _adapt_tokenizer(tokenizer)
         fsm = CFGFSM(cfg, tokenizer)
         self.fsm = fsm
+
+    def init_state(self):
+        """Initialize state with a CFGFSM copy."""
+        super().init_state()
+        self.fsm = self.fsm.copy()
+
+
+@lru_cache
+def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
+    """Adapt Aphrodite's tokenizer to use to compile the FSM.
+    The API of Outlines tokenizers is slightly different to that of
+    `transformers`. The decoder of outlines, returns a list whereas
+    the decode of Aphrodite returns an str. To sync the Aphrodite decoder with
+    outlines internal api, the decoder should be adapted. In addition
+    we need to handle the missing spaces to Llama's tokenizer to be
+    able to compile FSMs for this model.
+    """
+    if getattr(tokenizer, "_outlines_adapted", False):
+        return tokenizer
+
+    tokenizer = copy.deepcopy(tokenizer)
+
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces to HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    def change_decoder(
+        decoder: Callable[[List[int]],
+                          str]) -> Callable[[List[int]], List[str]]:
+        """Sync vLLM's decoder with the outlines by returning list."""
+
+        def new_decoder(inp_tokens: List[int]) -> List[str]:
+            return [decoder(inp_tokens)]
+
+        return new_decoder
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+    tokenizer.decode = change_decoder(tokenizer.decode)
+    setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
+
+    return tokenizer
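Because `_adapt_tokenizer` is now module-level, deep-copies its input, and is wrapped in `lru_cache`, its observable behaviour can be sketched as below. This assumes a local `transformers` tokenizer; "gpt2" is only an example.

```python
# Sketch of the adapted tokenizer's behaviour under the changes above.
from transformers import AutoTokenizer

from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
    _adapt_tokenizer)

tok = AutoTokenizer.from_pretrained("gpt2")
adapted = _adapt_tokenizer(tok)

assert adapted is not tok                      # deepcopy: original untouched
assert isinstance(adapted.decode([0]), list)   # outlines-style decode
assert _adapt_tokenizer(tok) is adapted        # lru_cache hit on the 2nd call
```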

+ 153 - 92
aphrodite/modeling/hf_downloader.py

@@ -1,28 +1,46 @@
 """Utilities for downloading and initializing model weights."""
-import filelock
-import glob
 import fnmatch
+import glob
 import json
 import os
 from collections import defaultdict
-from typing import Any, Iterator, List, Optional, Tuple
-from loguru import logger
+from typing import Any, Iterable, Iterator, List, Optional, Tuple
 
-from huggingface_hub import snapshot_download, HfFileSystem
+import filelock
+import huggingface_hub.constants
 import numpy as np
-from safetensors.torch import load_file, save_file, safe_open
 import torch
-from transformers import PretrainedConfig
+from huggingface_hub import HfFileSystem, snapshot_download
+from loguru import logger
+from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
+from transformers import PretrainedConfig, AutoModelForCausalLM
 
 from aphrodite.common.config import ModelConfig
+from aphrodite.quantization.gguf_utils import (GGUFReader, get_tensor_name_map,
+                                               MODEL_ARCH_NAMES)
 from aphrodite.common.logger import get_loading_progress_bar
-from aphrodite.common.gguf import GGUFReader
-from aphrodite.modeling.layers.quantization import (get_quantization_config,
-                                                    QuantizationConfig)
+from aphrodite.quantization import (QuantizationConfig,
+                                    get_quantization_config)
+from aphrodite.quantization.schema import QuantParamSchema
+
+_xdg_cache_home = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+_aphrodite_filelocks_path = os.path.join(_xdg_cache_home, "aphrodite/locks/")
+
+
+def enable_hf_transfer():
+    """automatically activates hf_transfer
+    """
+    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
+        try:
+            # enable hf hub transfer if available
+            import hf_transfer  # type: ignore # noqa
+            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
+        except ImportError:
+            pass
+
 
-_xdg_cache_home = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
-_aphrodite_filelocks_path = os.path.join(_xdg_cache_home, 'aphrodite/locks/')
+enable_hf_transfer()
 
 
 class Disabledtqdm(tqdm):
@@ -35,7 +53,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
     lock_dir = cache_dir if cache_dir is not None else _aphrodite_filelocks_path
     os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
     lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
-    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
+    lock = filelock.SoftFileLock(os.path.join(lock_dir, lock_file_name))
     return lock
 
 
@@ -103,11 +121,13 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
     if not is_local:
         # Download the config files.
         with get_lock(model_name_or_path, model_config.download_dir):
-            hf_folder = snapshot_download(model_name_or_path,
-                                          revision=model_config.revision,
-                                          allow_patterns="*.json",
-                                          cache_dir=model_config.download_dir,
-                                          tqdm_class=Disabledtqdm)
+            hf_folder = snapshot_download(
+                model_name_or_path,
+                revision=model_config.revision,
+                allow_patterns="*.json",
+                cache_dir=model_config.download_dir,
+                tqdm_class=Disabledtqdm,
+            )
     else:
         hf_folder = model_name_or_path
     config_files = glob.glob(os.path.join(hf_folder, "*.json"))
@@ -172,11 +192,13 @@ def prepare_hf_model_weights(
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
-            hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns=allow_patterns,
-                                          cache_dir=cache_dir,
-                                          tqdm_class=Disabledtqdm,
-                                          revision=revision)
+            hf_folder = snapshot_download(
+                model_name_or_path,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                tqdm_class=Disabledtqdm,
+                revision=revision,
+            )
     else:
         hf_folder = model_name_or_path
     hf_weights_files: List[str] = []
@@ -211,78 +233,76 @@ def prepare_hf_model_weights(
 
 
 def convert_gguf_to_state_dict(checkpoint, config):
-    if not os.path.isfile(checkpoint):
+    model_type = config.model_type
+    # hack: GGUF uses different architecture names than transformers
+    if model_type == "cohere":
+        model_type = "command-r"
+    arch = None
+    for key, value in MODEL_ARCH_NAMES.items():
+        if value == model_type:
+            arch = key
+            break
+    if arch is None:
+        raise RuntimeError(f"Unknown model_type: {model_type}")
+    num_layers = config.num_hidden_layers
+    name_map = get_tensor_name_map(arch, num_layers)
+    with torch.device("meta"):
+        dummy_model = AutoModelForCausalLM.from_config(config)
+    state_dict = dummy_model.state_dict()
+
+    gguf_to_hf_name_map = {}
+    keys_to_remove = []
+    for hf_name in state_dict:
+        name, suffix = hf_name.rsplit(".", 1)
+        gguf_name = name_map.get_name(name)
+        if gguf_name:
+            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
+        elif name == "lm_head":
+            keys_to_remove.append(hf_name)
+            logger.warning(
+                f"GGUF tensor name for {hf_name} not found, "
+                "this is normal if the model uses tie word embeddings.")
+        else:
+            logger.warning(
+                f"GGUF tensor name for {hf_name} in hf state_dict not found.")
+    for key in keys_to_remove:
+        state_dict.pop(key)
+
+    if os.path.isfile(checkpoint):
+        results = [GGUFReader(checkpoint)]
+    elif os.path.isdir(checkpoint):
+        results = [
+            GGUFReader(os.path.join(checkpoint, file))
+            for file in os.listdir(checkpoint)
+            if os.path.splitext(file)[-1].lower() == ".gguf"
+        ]
+    else:
         raise RuntimeError(
             f"Cannot find any model weights with `{checkpoint}`")
 
-    result = GGUFReader(checkpoint)
-    # write tensor
-    kv_dim = (config.hidden_size // config.num_attention_heads *
-              config.num_key_value_heads)
-    tensor_mapping = {
-        "token_embd": ("model.embed_tokens", config.vocab_size),
-        "output": ("lm_head", config.vocab_size),
-        "output_norm": ("model.norm", -1),
-        "blk.{bid}.attn_norm": ("model.layers.{bid}.input_layernorm", -1),
-        "blk.{bid}.attn_q": ("model.layers.{bid}.self_attn.q_proj",
-                             config.hidden_size),
-        "blk.{bid}.attn_k": ("model.layers.{bid}.self_attn.k_proj", kv_dim),
-        "blk.{bid}.attn_v": ("model.layers.{bid}.self_attn.v_proj", kv_dim),
-        "blk.{bid}.attn_output": ("model.layers.{bid}.self_attn.o_proj",
-                                  config.hidden_size),
-        "blk.{bid}.attn_rot_embd":
-        ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
-        "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm",
-                               -1),
-        "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj",
-                             config.intermediate_size),
-        "blk.{bid}.ffn_down": ("model.layers.{bid}.mlp.down_proj",
-                               config.hidden_size),
-        "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj",
-                               config.intermediate_size),
-        "blk.{bid}.ffn_up.{xid}":
-        ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3",
-         config.intermediate_size),
-        "blk.{bid}.ffn_down.{xid}":
-        ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2",
-         config.hidden_size),
-        "blk.{bid}.ffn_gate.{xid}":
-        ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1",
-         config.intermediate_size),
-        "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate",
-                                   config.num_local_experts if hasattr(
-                                       config, "num_local_experts") else -1),
-    }
-    mapping = {}
-    # This is how llama.cpp handles name mapping,
-    # it's better to use regex match instead doe
-    max_block_num = 200
-    max_expert_num = 8
-    for k, v in tensor_mapping.items():
-        for i in range(max_block_num):
-            for j in range(max_expert_num):
-                fk = k.format(bid=i, xid=j)
-                fv = v[0].format(bid=i, xid=j)
-                if k not in mapping:
-                    mapping[fk] = (fv, v[1])
-
-    state_dict = {}
     with get_loading_progress_bar() as progress:
-        task = progress.add_task("[cyan]Converting GGUF tensors to PyTorch...",
-                                 total=len(result.tensors))
-        for ts in result.tensors:
-            weight_type = torch.tensor(int(ts.tensor_type), dtype=torch.int)
-            layer, suffix = ts.name.rsplit(".", 1)
-            new_key, output_dim = mapping[layer]
-            new_key += f".{suffix}"
-            data = torch.tensor(ts.data)
-            if output_dim != -1:
-                data = data.view(output_dim, -1)
-            if weight_type > 1:
-                state_dict[new_key.replace("weight",
-                                           "weight_type")] = weight_type
-            state_dict[new_key] = data
-            progress.update(task, advance=1)
+        task = progress.add_task(
+            "[cyan]Converting GGUF tensors to PyTorch...",
+            total=sum([len(result.tensors) for result in results]),
+        )
+        for result in results:
+            for ts in result.tensors:
+                try:
+                    hf_name = gguf_to_hf_name_map[ts.name]
+                except KeyError:
+                    logger.warning(
+                        f"hf tensor name for {ts.name} in GGUF not found.")
+                    continue
+                data = torch.tensor(ts.data)
+                if state_dict[hf_name].dim() == 2:
+                    data = data.view(state_dict[hf_name].shape[0], -1)
+                state_dict[hf_name] = data
+                weight_type = torch.tensor(int(ts.tensor_type),
+                                           dtype=torch.int)
+                if weight_type > 1:
+                    state_dict[hf_name.replace("weight",
+                                               "weight_type")] = weight_type
+                progress.update(task, advance=1)
     return state_dict
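The new mapping-driven conversion replaces the hand-written template table above. Below is a hedged sketch of how `gguf_to_hf_name_map` comes together, with a plain dict standing in for `name_map.get_name()`; the GGUF names reuse ones visible in the removed table.

```python
# Stand-in sketch of the gguf -> hf name-map construction; fake_name_map
# imitates name_map.get_name() and the tensor names are illustrative.
state_dict_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "lm_head.weight",
]
fake_name_map = {
    "model.embed_tokens": "token_embd",
    "model.layers.0.self_attn.q_proj": "blk.0.attn_q",
    # "lm_head" intentionally absent: the tied-word-embeddings case above
}

gguf_to_hf_name_map = {}
for hf_name in state_dict_keys:
    name, suffix = hf_name.rsplit(".", 1)
    gguf_name = fake_name_map.get(name)
    if gguf_name:
        gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name

print(gguf_to_hf_name_map)
# {'token_embd.weight': 'model.embed_tokens.weight',
#  'blk.0.attn_q.weight': 'model.layers.0.self_attn.q_proj.weight'}
```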
 
 
@@ -305,7 +325,8 @@ def hf_model_weights_iterator(
         cache_dir=cache_dir,
         load_format=load_format,
         fall_back_to_pt=fall_back_to_pt,
-        revision=revision)
+        revision=revision,
+    )
 
     if load_format == "npcache":
         # Currently np_cache only support *.bin checkpoints
@@ -354,6 +375,46 @@ def hf_model_weights_iterator(
             torch.cuda.empty_cache()
 
 
+def kv_cache_scales_loader(
+        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
+        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    Keep this function in sync with the output of examples/fp8/extract_scales.py
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct,
+                                                     context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+
+    except FileNotFoundError:
+        logger.error(f"File or directory '{filename}' not found.")
+    except json.JSONDecodeError:
+        logger.error(f"Error decoding JSON in file '{filename}'.")
+    except Exception as e:
+        logger.error(f"An error occurred while reading '{filename}': {e}")
+    # This point is reached only if one of the except blocks above was hit.
+    # Return an empty iterable (list) => no KV cache scales are loaded,
+    # which ultimately defaults all scaling factors to 1.0.
+    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
+                   f"for all layers in TP rank {tp_rank} "
+                   "as an error occurred during loading.")
+    return []
+
+
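For reference, the file consumed by `kv_cache_scales_loader` is produced by examples/fp8/extract_scales.py. Its rough shape, inferred only from the accessor `schema.kv_cache.scaling_factor[tp_rank]` above (field names and numbers are guesses; the authoritative definition is `QuantParamSchema`):

```python
# Inferred shape of a KV-cache scaling-factor file; values are invented and
# in the actual JSON the integer keys appear as strings.
example_scales = {
    "model_type": "llama",
    "kv_cache": {
        "scaling_factor": {
            # tp_rank -> {layer_index: scaling factor}
            0: {0: 0.021, 1: 0.023},
            1: {0: 0.020, 1: 0.022},
        },
    },
}
```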
 def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
     """convert PySafeSlice object from safetensors to torch.Tensor
 

+ 6 - 4
aphrodite/modeling/layers/activation.py

@@ -7,10 +7,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from aphrodite._C import ops
-from aphrodite.modeling.layers.quantization import QuantizationConfig
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from aphrodite.modeling.megatron.utils import divide
+from aphrodite.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from aphrodite.quantization import QuantizationConfig
 from aphrodite.modeling.utils import set_weight_attrs
 
 

+ 0 - 93
aphrodite/modeling/layers/attention/__init__.py

@@ -1,93 +0,0 @@
-"""Attention layer."""
-from functools import lru_cache
-from typing import List, Optional
-
-import torch
-import torch.nn as nn
-from loguru import logger
-
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.common.utils import is_hip
-
-
-class Attention(nn.Module):
-    """Attention layer.
-
-    This class takes query, key, and value tensors as input. The input tensors
-    can either contain prompt tokens or generation tokens.
-    The class does the following:
-
-    1. Store the input key and value tensors in the KV cache.
-    2. Perform (multi-head/multi-query/grouped-query) attention.
-    3. Return the output tensor.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-    ) -> None:
-        super().__init__()
-        if _use_flash_attn():
-            from aphrodite.modeling.layers.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
-
-            self.backend = FlashAttentionBackend(
-                num_heads,
-                head_size,
-                scale,
-                num_kv_heads,
-                alibi_slopes,
-                sliding_window,
-            )
-        else:
-            from aphrodite.modeling.layers.attention.backends.xformers import XFormersBackend  # noqa: E501
-
-            self.backend = XFormersBackend(
-                num_heads,
-                head_size,
-                scale,
-                num_kv_heads,
-                alibi_slopes,
-                sliding_window,
-            )
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: Optional[torch.Tensor],
-        value_cache: Optional[torch.Tensor],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        return self.backend.forward(query, key, value, key_cache, value_cache,
-                                    input_metadata)
-
-
-@lru_cache(maxsize=1)
-def _use_flash_attn() -> bool:
-    try:
-        import flash_attn  # noqa: F401
-    except ImportError:
-        logger.info("flash_attn is not found. Using xformers backend.")
-        return False
-
-    if is_hip():
-        # AMD GPUs.
-        return False
-    if torch.cuda.get_device_capability()[0] < 8:
-        logger.info("flash_attn is not supported on Turing or older GPUs. "
-                    "Using xformers backend.")
-        return False
-    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
-        logger.info(
-            "flash_attn only supports torch.float16 or torch.bfloat16. "
-            "Using xformers backend.")
-        return False
-
-    logger.info("Using Flash Attention backend.")
-    return True

+ 0 - 138
aphrodite/modeling/layers/attention/backends/flash_attn.py

@@ -1,138 +0,0 @@
-"""Attention layer with Flash and PagedAttention."""
-from typing import List, Optional
-
-from flash_attn import flash_attn_varlen_func
-import torch
-
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention.ops.paged_attn import (
-    PagedAttentionImpl)
-
-
-class FlashAttentionBackend:
-    """
-    If the input tensors contain prompt tokens, the layout is as follows:
-    |<--------------- num_prompt_tokens -------------->|	
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
-
-    Otherwise, the layout is as follows:	
-    |<------------------ num_generation_tokens (M) ----------------->|	
-    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
-
-    Generation tokens can contain padding when cuda-graph is used.
-    Currently, prompt tokens don't contain any padding.
-    The prompts might have different lengths, while the generation tokens
-    always have length 1.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-    ) -> None:
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
-        if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
-        self.alibi_slopes = alibi_slopes
-
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
-
-        self.sliding_window = ((self.sliding_window, self.sliding_window) if
-                               self.sliding_window is not None else (-1, -1))
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: Optional[torch.Tensor],
-        value_cache: Optional[torch.Tensor],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        """Forward pass with FlashAttention and PagedAttention.
-
-        Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for the inputs.
-        Returns:
-            shape = [num_tokens, num_heads * head_size]
-        """
-        num_tokens, hidden_size = query.shape
-        # Reshape the query, key, and value tensors.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-
-        # Reshape the keys and values and store them in the cache.
-        # If key_cache and value_cache are not provided, the new key and value
-        # vectors will not be cached. This happens during the initial memory
-        # profiling run.
-        if key_cache is not None and value_cache is not None:
-            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
-                                                 value_cache, input_metadata)
-
-        if input_metadata.is_prompt:
-            # Prompt run.
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
-                # normal attention
-                # When block_tables are not filled, it means q and k are the
-                # prompt, and they have the same length.
-                output = flash_attn_varlen_func(
-                    q=query,
-                    k=key,
-                    v=value,
-                    cu_seqlens_q=input_metadata.seq_start_loc,
-                    cu_seqlens_k=input_metadata.seq_start_loc,
-                    max_seqlen_q=input_metadata.max_seq_len,
-                    max_seqlen_k=input_metadata.max_seq_len,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    window_size=self.sliding_window,
-                    alibi_slopes=self.alibi_slopes,
-                )
-            else:
-                # prefix-enabled attention
-                output = PagedAttentionImpl.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    input_metadata,
-                    self.alibi_slopes,
-                )
-        else:
-            # Decoding run.
-            output = PagedAttentionImpl.forward_decode(
-                query,
-                key_cache,
-                value_cache,
-                input_metadata,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-            )
-
-        # Reshape the output tensor.
-        return output.view(num_tokens, hidden_size)

+ 0 - 315
aphrodite/modeling/layers/attention/backends/xformers.py

@@ -1,315 +0,0 @@
-"""Attention layer with xFormers and PagedAttention."""
-import importlib
-from typing import List, Optional
-
-from loguru import logger
-import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
-                                         LowerTriangularMaskWithTensorBias)
-
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention.ops.paged_attn import (
-    PagedAttentionImpl)
-from aphrodite.common.utils import is_hip
-
-
-class XFormersBackend:
-    """
-    If the input tensors contain prompt tokens, the layout is as follows:
-    |<--------------- num_prompt_tokens --------------->|	
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
-
-    Otherwise, the layout is as follows:	
-    |<------------------ num_generation_tokens (M) ----------------->|	
-    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
-
-    Generation tokens can contain padding when cuda-graph is used.
-    Currently, prompt tokens don't contain any padding.
-    The prompts might have different lengths, while the generation tokens
-    always have length 1.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-    ) -> None:
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
-        if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
-        self.alibi_slopes = alibi_slopes
-
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
-
-        self.use_ref_attention = _check_use_ref_attention()
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: Optional[torch.Tensor],
-        value_cache: Optional[torch.Tensor],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        """Forward pass with xFormers and PagedAttention.
-
-        Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for the inputs.
-        Returns:
-            shape = [num_tokens, num_heads * head_size]
-        """
-        num_tokens, hidden_size = query.shape
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-
-        # Reshape the keys and values and store them in the cache.
-        # If key_cache and value_cache are not provided, the new key and value
-        # vectors will not be cached. This happens during the initial memory
-        # profiling run.
-        if key_cache is not None and value_cache is not None:
-            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
-                                                 value_cache, input_metadata)
-
-        if input_metadata.is_prompt:
-            # Prompt run.
-            # key_cache and value_cache are None when it is a profiling run.
-            # block tables are empty if the prompt has never been computed.
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
-                if self.num_kv_heads != self.num_heads:
-                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                    # project the key and value tensors to the desired number of
-                    # heads.
-                    # TODO: Use MQA/GQA kernels for higher performance.
-                    query = query.view(query.shape[0], self.num_kv_heads,
-                                       self.num_queries_per_kv,
-                                       query.shape[-1])
-                    key = key[:, :,
-                              None, :].expand(key.shape[0], self.num_kv_heads,
-                                              self.num_queries_per_kv,
-                                              key.shape[-1])
-                    value = value[:, :,
-                                  None, :].expand(value.shape[0],
-                                                  self.num_kv_heads,
-                                                  self.num_queries_per_kv,
-                                                  value.shape[-1])
-
-                if self.use_ref_attention:
-                    logger.debug("ref attention used.")
-                    output = torch.empty_like(query)
-                    start = 0
-                    for _, prompt_len in enumerate(input_metadata.prompt_lens):
-                        end = start + prompt_len
-                        out = _ref_masked_attention(
-                            query[None, start:end],
-                            key[None, start:end],
-                            value[None, start:end],
-                            self.num_heads,
-                            self.num_kv_heads,
-                            self.head_size,
-                            self.scale,
-                        )
-                        # TODO: Unnecessary copy. Optimize.
-                        output[start:end].copy_(out)
-                        start += prompt_len
-                    # Using view got RuntimeError: view size is not compatible
-                    # with input tensor's size and stride (at least one
-                    # dimension spans across two contiguous subspaces).
-                    # Use reshape instead.
-                    return output.reshape(num_tokens, hidden_size)
-                output = self._run_memory_efficient_xformer_forward(
-                    query, key, value, input_metadata)
-            else:
-                # prefix-enabled attention
-                output = PagedAttentionImpl.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    input_metadata,
-                    self.alibi_slopes,
-                )
-        else:
-            # Decoding run.
-            output = PagedAttentionImpl.forward_decode(
-                query,
-                key_cache,
-                value_cache,
-                input_metadata,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-            )
-
-        # Reshape the output tensor.
-        return output.view(-1, self.num_heads * self.head_size)
-
-    def _run_memory_efficient_xformer_forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        """Attention for 1D query of multiple prompts. Multiple prompt
-        tokens are flattened in to `query` input.
-        Args:
-            output: shape = [num_prompt_tokens, num_heads, head_size]
-            query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
-            input_metadata: metadata for paged attention.
-        """
-        # Set attention bias if not provided. This typically happens at
-        # the very attention layer of every iteration.
-        # FIXME: This is a hack.
-        if input_metadata.attn_bias is None:
-            if self.alibi_slopes is None:
-                attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    input_metadata.prompt_lens)
-                if self.sliding_window is not None:
-                    attn_bias = attn_bias.make_local_attention(
-                        self.sliding_window)
-                input_metadata.attn_bias = [attn_bias]
-            else:
-                input_metadata.attn_bias = _make_alibi_bias(
-                    self.alibi_slopes, self.num_kv_heads, query.dtype,
-                    input_metadata)
-
-        op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
-            is_hip()) else None
-        # No alibi slopes.
-        # TODO: Too many view operations. Let's try to reduce
-        # them in the future for code readability.
-        if self.alibi_slopes is None:
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-            value = value.unsqueeze(0)
-            out = xops.memory_efficient_attention_forward(
-                query,
-                key,
-                value,
-                attn_bias=input_metadata.attn_bias[0],
-                p=0.0,
-                scale=self.scale,
-                op=op)
-
-            return out.view_as(query)
-
-        # Attention with alibi slopes.
-        # FIXME: Because xformers does not support dynamic sequence
-        # lengths with custom attention bias, we process each prompt one by
-        # one. This is inefficient, especially when we have many short prompts.
-        output = torch.empty_like(query)
-        start = 0
-        for i, prompt_len in enumerate(input_metadata.prompt_lens):
-            end = start + prompt_len
-            out = xops.memory_efficient_attention_forward(
-                query[None, start:end],
-                key[None, start:end],
-                value[None, start:end],
-                attn_bias=input_metadata.attn_bias[i],
-                p=0.0,
-                scale=self.scale,
-                op=op)
-            # TODO(woosuk): Unnecessary copy. Optimize.
-            output[start:end].copy_(out.squeeze(0))
-            start += prompt_len
-        return output
-
-
-def _make_alibi_bias(
-    alibi_slopes: torch.Tensor,
-    num_kv_heads: int,
-    dtype: torch.dtype,
-    input_metadata: InputMetadata,
-) -> LowerTriangularMaskWithTensorBias:
-    attn_biases = []
-    for prompt_len in input_metadata.prompt_lens:
-        bias = torch.arange(prompt_len, dtype=dtype)
-        # NOTE: HF uses
-        #     `bias = bias[None, :].repeat(prompt_len, 1)`
-        # here. We find that both biases give the same results, but
-        # the bias below more accurately follows the original ALiBi
-        # paper.
-        # Calculate a matrix where each element represents ith element- jth
-        # element.
-        bias = bias[None, :] - bias[:, None]
-
-        padded_len = (prompt_len + 7) // 8 * 8
-        num_heads = alibi_slopes.shape[0]
-        bias = torch.empty(
-            1,  # batch size
-            num_heads,
-            prompt_len,
-            padded_len,
-            device=alibi_slopes.device,
-            dtype=dtype,
-        )[:, :, :, :prompt_len].copy_(bias)
-        bias.mul_(alibi_slopes[:, None, None])
-        if num_heads != num_kv_heads:
-            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
-        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
-
-    return attn_biases
-
-
-def _check_use_ref_attention() -> bool:
-    if not is_hip():
-        return False
-    # For ROCm, check whether flash attention is installed or not.
-    # if not, use_ref_attention needs to be True
-    return importlib.util.find_spec("flash_attn") is None
-
-
-def _ref_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    num_heads: int,
-    num_kv_heads: int,
-    head_size: int,
-    scale: float,
-) -> torch.Tensor:
-    query = query.view(-1, num_heads, head_size)
-    key = key.view(-1, num_kv_heads, head_size)
-    value = value.view(-1, num_kv_heads, head_size)
-
-    seq_len, _, _ = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out

+ 0 - 139
aphrodite/modeling/layers/attention/ops/paged_attn.py

@@ -1,139 +0,0 @@
-from typing import List, Optional
-
-import torch
-
-from aphrodite._C import cache_ops
-from aphrodite._C import ops
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention.ops.prefix_prefill import (
-    context_attention_fwd)
-
-# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
-_PARTITION_SIZE = 512
-
-
-class PagedAttentionImpl:
-
-    @staticmethod
-    def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
-
-    @staticmethod
-    def reshape_and_cache(
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-    ) -> None:
-        cache_ops.reshape_and_cache(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            input_metadata.slot_mapping.flatten(),
-            input_metadata.kv_cache_dtype,
-        )
-
-    @staticmethod
-    def forward_decode(
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-        num_kv_heads: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        output = torch.empty_like(query)
-
-        block_size = value_cache.shape[3]
-        num_seqs, num_heads, head_size = query.shape
-        max_num_partitions = (
-            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
-            _PARTITION_SIZE)
-        # NOTE: We use a simple heuristic to decide whether to use
-        # PagedAttention V1 or V2. If the number of partitions is 1, we use
-        # V1 to avoid the overhead of reduction. Also, if the number of
-        # sequences or heads is large, we use V1 since there is enough work
-        # to parallelize.
-        # TODO: Tune this heuristic.
-        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
-        use_v1 = input_metadata.max_context_len <= 8192 and (
-            max_num_partitions == 1 or num_seqs * num_heads > 512)
-        if use_v1:
-            # Run PagedAttention V1.
-            ops.paged_attention_v1(
-                output,
-                query,
-                key_cache,
-                value_cache,
-                num_kv_heads,
-                scale,
-                input_metadata.block_tables,
-                input_metadata.context_lens,
-                block_size,
-                input_metadata.max_context_len,
-                alibi_slopes,
-                input_metadata.kv_cache_dtype,
-            )
-        else:
-            # Run PagedAttention V2.
-            assert _PARTITION_SIZE % block_size == 0
-            tmp_output = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions, head_size),
-                dtype=output.dtype,
-                device=output.device,
-            )
-            exp_sums = torch.empty(
-                size=(num_seqs, num_heads, max_num_partitions),
-                dtype=torch.float32,
-                device=output.device,
-            )
-            max_logits = torch.empty_like(exp_sums)
-            ops.paged_attention_v2(
-                output,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                num_kv_heads,
-                scale,
-                input_metadata.block_tables,
-                input_metadata.context_lens,
-                block_size,
-                input_metadata.max_context_len,
-                alibi_slopes,
-                input_metadata.kv_cache_dtype,
-            )
-        return output
-
-    @staticmethod
-    def forward_prefix(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        input_metadata: InputMetadata,
-        alibi_slopes: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        output = torch.empty_like(query)
-        context_attention_fwd(
-            query,
-            key,
-            value,
-            output,
-            key_cache,
-            value_cache,
-            input_metadata.block_tables,
-            # subquery_start_loc is (batch_size + 1,)
-            input_metadata.subquery_start_loc[:-1],
-            input_metadata.prompt_lens_tensor,
-            input_metadata.context_lens,
-            input_metadata.max_subquery_len,
-            alibi_slopes,
-        )
-        return output

+ 3 - 5
aphrodite/modeling/layers/fused_moe/__init__.py

@@ -1,8 +1,6 @@
-from aphrodite.modeling.layers.fused_moe.fused_moe import (fused_moe,
-                                                           get_config_file_name
-                                                           )
+from aphrodite.modeling.layers.fused_moe.fused_moe import (
+    fused_moe, moe_align_block_size, fused_topk, get_config_file_name)
 
 __all__ = [
-    "fused_moe",
-    "get_config_file_name",
+    "fused_moe", "moe_align_block_size", "fused_topk", "get_config_file_name"
 ]
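
With moe_align_block_size and fused_topk re-exported here, quantized linear methods and MoE models can import the routing helpers from the package root:

from aphrodite.modeling.layers.fused_moe import (fused_moe, fused_topk,
                                                 moe_align_block_size)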

+ 49 - 33
aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -243,6 +243,53 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
 
 
+def fused_topk(
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    """Compute top-k indice and weights from gating logits
+
+    Args:
+        gating_output (torch.Tensor): The output of the gating operation
+          (before softmax).
+        topk (int): The number of top-k experts to select.
+        renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    """
+    M = gating_output.shape[0]
+    if is_hip():
+        # The MoE kernels are not yet supported on ROCm.
+        routing_weights = torch.softmax(gating_output,
+                                        dim=-1,
+                                        dtype=torch.float32)
+        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
+    else:
+        import aphrodite._moe_C as moe_kernels
+
+        topk_weights = torch.empty(M,
+                                   topk,
+                                   dtype=torch.float32,
+                                   device=gating_output.device)
+        topk_ids = torch.empty(M,
+                               topk,
+                               dtype=torch.int32,
+                               device=gating_output.device)
+        token_expert_indicies = torch.empty(M,
+                                            topk,
+                                            dtype=torch.int32,
+                                            device=gating_output.device)
+        moe_kernels.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),  # TODO(woosuk): Optimize this.
+        )
+        del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
 def get_config_file_name(E: int, N: int) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
     return f"E={E},N={N},device_name={device_name}.json"
@@ -284,7 +331,7 @@ def fused_moe(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    inplace: bool = False,
+    inplace: bool = True,
     override_config: Optional[Dict[str, Any]] = None,
 ) -> torch.Tensor:
     """
@@ -313,44 +360,13 @@ def fused_moe(
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import aphrodite._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO: Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
 
     if override_config:
         config = override_config
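
Two things worth noting in this file: the routing step is now shared through fused_topk, and fused_moe defaults to inplace=True, so callers that still need the original hidden_states must pass inplace=False explicitly. As a reference for the routing math, a pure-PyTorch sketch that mirrors the ROCm fallback path above (illustrative only, not the fused CUDA kernel):

import torch

def reference_topk(gating_output: torch.Tensor, topk: int, renormalize: bool):
    """Reference version of the routing step implemented by fused_topk."""
    routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    if renormalize:
        # Make each token's selected expert weights sum to 1 again.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

num_tokens, num_experts = 4, 8
gating = torch.randn(num_tokens, num_experts)
weights, ids = reference_topk(gating, topk=2, renormalize=True)
assert weights.shape == ids.shape == (num_tokens, 2)
assert torch.allclose(weights.sum(dim=-1), torch.ones(num_tokens))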

+ 190 - 27
aphrodite/modeling/layers/linear.py

@@ -1,17 +1,20 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional
-from loguru import logger
 
 import torch
 import torch.nn.functional as F
+from loguru import logger
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
-from aphrodite.modeling.megatron.utils import (divide,
-                                               split_tensor_along_last_dim)
+from aphrodite.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.utils import set_weight_attrs
 
 
@@ -23,6 +26,70 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
+# Split the exl2 weight at group boundary
+def post_init_exl2(layer):
+    q_groups = layer.q_groups
+    num_groups = q_groups.shape[0] // 2
+    input_size = layer.q_invperm.shape[0]
+    tp_rank = get_tensor_model_parallel_rank()
+    tp_size = get_tensor_model_parallel_world_size()
+    rows = 0
+    # row_number, qrow_number, group_number
+    splits = [(0, 0, 0)]
+    index = 1
+    for i in range(num_groups - 1):
+        bits = q_groups[i * 2].item()
+        qrows = (q_groups[i * 2 + 3] - q_groups[i * 2 + 1]).item()
+        rows += qrows * 32 // bits
+
+        q_rows_end = q_groups[i * 2 + 3].item()
+        if q_rows_end >= layer.q_weight.shape[0] // tp_size * index:
+            splits.append((rows, q_rows_end, i + 1))
+            index += 1
+    splits.append((input_size, layer.q_weight.shape[0], num_groups))
+
+    shard_qweight = torch.nn.Parameter(
+        layer.q_weight[splits[tp_rank][1]:splits[tp_rank + 1][1]].clone(),
+        requires_grad=False,
+    )
+    # the outer model.named_parameters() during loading still keeps a reference
+    # so we have to force release memory
+    layer.q_weight.data = torch.empty(0, device="cpu")
+    del layer.q_weight
+    layer.linear_weights["q_weight"] = shard_qweight.to(layer.device)
+
+    shard_qgroups = torch.nn.Parameter(
+        layer.q_groups[splits[tp_rank][2] * 2:splits[tp_rank + 1][2] *
+                       2].clone(),
+        requires_grad=False,
+    )
+    shard_qgroups[1::2] -= int(shard_qgroups[1])
+    layer.q_groups.data = torch.empty(0, device="cpu")
+    del layer.q_groups
+    layer.linear_weights["q_groups"] = shard_qgroups.to(layer.device)
+
+    shard_qscale = torch.nn.Parameter(
+        layer.q_scale[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
+        requires_grad=False,
+    )
+    layer.q_scale.data = torch.empty(0, device="cpu")
+    del layer.q_scale
+    layer.linear_weights["q_scale"] = shard_qscale.to(layer.device)
+
+    shard_qscale_max = torch.nn.Parameter(
+        layer.q_scale_max[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
+        requires_grad=False,
+    )
+    layer.q_scale_max.data = torch.empty(0, device="cpu")
+    del layer.q_scale_max
+    layer.linear_weights["q_scale_max"] = shard_qscale_max.to(layer.device)
+
+    q_perm = torch.argsort(layer.q_invperm).to(torch.short)
+    # q_invperm is not used later, probably we should remove it too
+    layer.linear_weights["q_perm"] = q_perm[
+        splits[tp_rank][0]:splits[tp_rank + 1][0]].to(layer.device)
+
+
 class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
@@ -42,6 +109,36 @@ class LinearMethodBase(ABC):
         """Apply the weights to the input tensor."""
         raise NotImplementedError
 
+    def create_moe_weights(self, num_experts: int,
+                           input_size_per_partition: int,
+                           output_partition_sizes: List[int], input_size: int,
+                           output_size: int,
+                           params_dtype: torch.dtype) -> Dict[str, Any]:
+        """Creating moe weights"""
+        linear_weights = self.create_weights(input_size_per_partition,
+                                             output_partition_sizes,
+                                             input_size, output_size,
+                                             params_dtype)
+        if num_experts == 1:
+            return linear_weights
+        for name, param in tuple(linear_weights.items()):
+            if isinstance(param, Parameter):
+                repeat_size = (num_experts, ) + (1, ) * param.dim()
+                new_param = Parameter(param.unsqueeze(0).repeat(*repeat_size),
+                                      requires_grad=False)
+                set_weight_attrs(new_param, param.__dict__)
+                linear_weights[name] = new_param
+        return linear_weights
+
+    @abstractmethod
+    def apply_moe_weights(self, w1: Dict[str,
+                                         torch.Tensor], w2: Dict[str,
+                                                                 torch.Tensor],
+                          x: torch.Tensor, gating_output: torch.Tensor,
+                          topk: int, renormalize: bool) -> torch.Tensor:
+        """Apply the weights to the input tensor."""
+        raise NotImplementedError
+
 
 class UnquantizedLinearMethod(LinearMethodBase):
     """Linear method without quantization.
@@ -72,7 +169,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = weights["weight"]
         if self.separate_bias_add:
-            if bias:
+            if bias is not None:
                 return F.linear(x, weight) + bias
             return F.linear(x, weight)
         return F.linear(x, weight, bias)
@@ -82,6 +179,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
         weight = weights["weight"]
         return F.embedding(x, weight)
 
+    def apply_moe_weights(self, w1: Dict[str,
+                                         torch.Tensor], w2: Dict[str,
+                                                                 torch.Tensor],
+                          x: torch.Tensor, gating_output: torch.Tensor,
+                          topk: int, renormalize: bool) -> torch.Tensor:
+        return fused_moe(x, w1["weight"], w2["weight"], gating_output, topk,
+                         renormalize)
+
 
 class ReplicatedLinear(torch.nn.Module):
     """Replicated linear layer.
@@ -156,6 +261,7 @@ class ColumnParallelLinear(torch.nn.Module):
         linear_method: (Maybe quantized) linear method.
         output_sizes: list of output sizes packed into one output, like for
                       QKV the list would be size 3.
+        num_experts: number of experts for sparse moe models.
     """
 
     def __init__(
@@ -168,6 +274,7 @@ class ColumnParallelLinear(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         linear_method: Optional[LinearMethodBase] = None,
         output_sizes: Optional[List[int]] = None,
+        num_experts: int = 1,
     ):
         super().__init__()
 
@@ -187,10 +294,10 @@ class ColumnParallelLinear(torch.nn.Module):
         if output_sizes is None:
             output_sizes = [output_size]
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, [x // tp_size for x in output_sizes],
+        self.num_experts = num_experts
+        self.linear_weights = self.linear_method.create_moe_weights(
+            num_experts, self.input_size, [x // tp_size for x in output_sizes],
             self.input_size, self.output_size, self.params_dtype)
-
         for name, weight in self.linear_weights.items():
             if isinstance(weight, torch.nn.parameter.Parameter):
                 self.register_parameter(name, weight)
@@ -206,11 +313,21 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter("bias", None)
 
-    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+    def weight_loader(self,
+                      param: Parameter,
+                      loaded_weight: torch.Tensor,
+                      expert_id: int = -1):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
+        if self.num_experts > 1:
+            if expert_id >= 0:
+                param_data = param_data[expert_id]
+            # Loaded weight is packed at expert dim
+            elif output_dim is not None:
+                output_dim = output_dim + 1
+
         if output_dim is not None:
             if loaded_weight.shape[output_dim] % tp_size != 0:
                 raise ValueError(
@@ -270,22 +387,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         linear_method: Optional[LinearMethodBase] = None,
+        num_experts: int = 1,
     ):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias, gather_output,
                          skip_bias_add, params_dtype, linear_method,
-                         self.output_sizes)
+                         self.output_sizes, num_experts)
 
     def weight_loader(self,
                       param: Parameter,
                       loaded_weight: torch.Tensor,
-                      loaded_shard_id: Optional[int] = None):
-
+                      loaded_shard_id: Optional[int] = None,
+                      expert_id: int = -1):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         is_metadata = getattr(param, "is_metadata", False)
+        if self.num_experts > 1:
+            if expert_id >= 0:
+                param_data = param_data[expert_id]
+            # Loaded weight is packed at expert dim
+            elif output_dim is not None:
+                output_dim = output_dim + 1
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -328,8 +452,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
 
-                # If marlin, we need to adjust the offset and size to account
-                # for the tiling.
+                # If marlin, we need to adjust the offset and size to
+                # account for the tiling.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
@@ -408,7 +532,6 @@ class QKVParallelLinear(ColumnParallelLinear):
         input_size = self.hidden_size
         output_size = (self.num_heads +
                        2 * self.num_kv_heads) * tp_size * self.head_size
-
         super().__init__(input_size, output_size, bias, False, skip_bias_add,
                          params_dtype, linear_method, [
                              self.num_heads * tp_size * self.head_size,
@@ -476,8 +599,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
 
-                # If marlin, we need to adjust the offset and size to account
-                # for the tiling.
+                # If marlin, we need to adjust the offset and size to
+                # account for the tiling.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
@@ -543,6 +666,7 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = True,
         linear_method: Optional[LinearMethodBase] = None,
+        num_experts: int = 1,
     ):
         super().__init__()
         # Keep input parameters
@@ -561,9 +685,10 @@ class RowParallelLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, [self.output_size], self.input_size,
-            self.output_size, self.params_dtype)
+        self.num_experts = num_experts
+        self.linear_weights = self.linear_method.create_moe_weights(
+            num_experts, self.input_size_per_partition, [self.output_size],
+            self.input_size, self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
             if isinstance(weight, torch.nn.parameter.Parameter):
                 self.register_parameter(name, weight)
@@ -583,11 +708,30 @@ class RowParallelLinear(torch.nn.Module):
         else:
             self.register_parameter("bias", None)
 
-    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        if self.is_exl2():
+            self.uninitialized = [
+                "q_invperm", "q_weight", "q_scale", "q_scale_max", "q_groups"
+            ]
+
+    def is_exl2(self) -> bool:
+        return hasattr(
+            self.linear_method, "quant_config"
+        ) and self.linear_method.quant_config.get_name() == "exl2"
+
+    def weight_loader(self,
+                      param: Parameter,
+                      loaded_weight: torch.Tensor,
+                      expert_id: int = -1):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
+        if self.num_experts > 1:
+            if expert_id >= 0:
+                param_data = param_data[expert_id]
+            # Loaded weight is packed at expert dim
+            elif input_dim is not None:
+                input_dim = input_dim + 1
         if input_dim is not None:
             if loaded_weight.shape[input_dim] % tp_size != 0:
                 raise ValueError(
@@ -598,14 +742,33 @@ class RowParallelLinear(torch.nn.Module):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
         if isinstance(param, torch.nn.parameter.UninitializedParameter):
-            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+            if self.is_exl2():
+                self.device = param.device
+                param.materialize(loaded_weight.shape,
+                                  dtype=loaded_weight.dtype,
+                                  device="cpu")
+            else:
+                param.materialize(loaded_weight.shape,
+                                  dtype=loaded_weight.dtype)
             param_data = param.data
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
+        if self.is_exl2():
+            weights = self.linear_weights
+            self.uninitialized = [
+                key for key in self.uninitialized if isinstance(
+                    weights[key], torch.nn.parameter.UninitializedParameter)
+            ]
+            if not self.uninitialized:
+                # shard and release memory ASAP to reduce peak memory usage
+                post_init_exl2(self)
+
     def forward(self, input_):
-        # Set up backprop all-reduce.
-        if self.input_is_parallel:
+        is_exl2 = self.is_exl2()
+        if self.input_is_parallel and is_exl2:
+            input_parallel = tensor_model_parallel_all_gather(input_)
+        elif self.input_is_parallel or is_exl2:
             input_parallel = input_
         else:
             tp_rank = get_tensor_model_parallel_rank()
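
The MoE plumbing added to these layers follows one pattern: create_moe_weights replicates every created weight along a new leading expert dimension, and weight_loader(..., expert_id=i) later writes a checkpoint shard into slice i (or shifts the shard dimension by one when the loaded weight is packed across experts). A minimal sketch of the replication step with plain torch; the stack_for_experts helper is hypothetical and only illustrates the pattern:

import torch
from torch.nn import Parameter

def stack_for_experts(weight: Parameter, num_experts: int) -> Parameter:
    """Stack a per-expert copy of a linear weight along a new leading dim."""
    if num_experts == 1:
        return weight
    repeat_size = (num_experts, ) + (1, ) * weight.dim()
    return Parameter(weight.unsqueeze(0).repeat(*repeat_size),
                     requires_grad=False)

w = Parameter(torch.randn(16, 32), requires_grad=False)  # [out, in]
w_moe = stack_for_experts(w, num_experts=8)              # [experts, out, in]
assert w_moe.shape == (8, 16, 32)
# A loader given expert_id=3 would then copy its shard into w_moe.data[3].
w_moe.data[3].copy_(torch.zeros(16, 32))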

+ 113 - 0
aphrodite/modeling/layers/logits_processor.py

@@ -0,0 +1,113 @@
+"""A layer that compute logits from hidden_stats."""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from aphrodite.distributed import tensor_model_parallel_gather
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+
+
+class LogitsProcessor(nn.Module):
+    """Process logits and apply logits processors from sampling metadata.
+
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None,
+                 scale: Optional[float] = 1.0,
+                 logits_as_input: bool = False) -> None:
+        """
+        Args:
+            scale: A scaling factor to apply to the logits.
+        """
+        super().__init__()
+        self.scale = scale
+        self.vocab_size = vocab_size
+        # Whether the input is logits (default is hidden states).
+        self.logits_as_input = logits_as_input
+        # original vocabulary size (without LoRA).
+        if org_vocab_size is not None:
+            self.org_vocab_size = min(org_vocab_size, vocab_size)
+        else:
+            self.org_vocab_size = vocab_size
+
+    def forward(
+        self,
+        lm_head: nn.Module,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.logits_as_input:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
+
+        if logits is not None:
+            logits *= self.scale
+
+            # Apply logits processors (if any).
+            logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def _get_logits(self, hidden_states: torch.Tensor, lm_head: nn.Module,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = lm_head(hidden_states)
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
+
+
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    logits_row_idx = 0
+    found_logits_processors = False
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        logits_processors = sampling_params.logits_processors
+        # handle prompt_logprobs by skipping rows in logits added for
+        # the prompt tokens (prompt logprobs are not processed)
+        if (i < sampling_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            assert len(seq_ids) == 1
+            logits_row_idx += sampling_metadata.prompt_lens[i] - 1
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id in seq_ids:
+                logits_row = logits[logits_row_idx]
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+                logits_row_idx += 1
+        else:
+            logits_row_idx += len(seq_ids)
+    if found_logits_processors:
+        # Ensure that no rows in logits were unexpectedly skipped.
+        assert logits_row_idx == logits.shape[0]
+    return logits
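
Per-request logits_processors from SamplingParams are now applied here, row by row, instead of inside the sampler. A processor is any callable that takes the sequence's generated token ids and a 1-D logits row and returns the modified row. A small illustrative sketch (the BanTokens class and its token ids are made up for the example):

import torch

class BanTokens:
    """Toy per-row logits processor: force the given token ids to -inf."""

    def __init__(self, banned_token_ids):
        self.banned_token_ids = list(banned_token_ids)

    def __call__(self, token_ids, logits_row: torch.Tensor) -> torch.Tensor:
        # token_ids (the tokens generated so far) are unused by this processor.
        logits_row[self.banned_token_ids] = -float("inf")
        return logits_row

row = torch.zeros(10)
row = BanTokens([2, 5])(token_ids=[], logits_row=row)
assert torch.isinf(row[[2, 5]]).all()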

+ 0 - 36
aphrodite/modeling/layers/quantization/__init__.py

@@ -1,36 +0,0 @@
-from typing import Type
-# ruff: noqa: E501
-from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
-from aphrodite.modeling.layers.quantization.aqlm import AQLMConfig
-from aphrodite.modeling.layers.quantization.awq import AWQConfig
-from aphrodite.modeling.layers.quantization.bitsandbytes import BitsandBytesConfig
-from aphrodite.modeling.layers.quantization.exl2 import Exl2Config
-from aphrodite.modeling.layers.quantization.gguf import GGUFConfig
-from aphrodite.modeling.layers.quantization.gptq import GPTQConfig
-from aphrodite.modeling.layers.quantization.quip import QuipConfig
-from aphrodite.modeling.layers.quantization.squeezellm import SqueezeLLMConfig
-from aphrodite.modeling.layers.quantization.marlin import MarlinConfig
-
-_QUANTIZATION_CONFIG_REGISTRY = {
-    "aqlm": AQLMConfig,
-    "awq": AWQConfig,
-    "bnb": BitsandBytesConfig,
-    "exl2": Exl2Config,
-    "gguf": GGUFConfig,
-    "gptq": GPTQConfig,
-    "quip": QuipConfig,
-    "squeezellm": SqueezeLLMConfig,
-    "marlin": MarlinConfig,
-}
-
-
-def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
-    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
-        raise ValueError(f"Invalid quantization method: {quantization}")
-    return _QUANTIZATION_CONFIG_REGISTRY[quantization]
-
-
-__all__ = [
-    "QuantizationConfig",
-    "get_quantization_config",
-]

+ 19 - 22
aphrodite/modeling/layers/rejection.py

@@ -1,15 +1,15 @@
-from typing import Tuple, Optional
 from functools import cached_property
+from typing import Optional, Tuple
 
 import torch
-import torch.nn as nn
 import torch.jit
+import torch.nn as nn
 
 
 class RejectionSampler(nn.Module):
     """Apply modified rejection sampling as described in "Accelerating Large
-    Language Model Decoding with Speculative Sampling"
-    https://arxiv.org/pdf/2302.01318.pdf.
+        Language Model Decoding with Speculative Sampling"
+        https://arxiv.org/pdf/2302.01318.pdf.
     """
 
     def __init__(self, strict_mode: bool = False):
@@ -144,6 +144,7 @@ class RejectionSampler(nn.Module):
         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)
 
+        # NOTE: the recovered_probs are overwritten by this method.
         recovered_token_ids = _multinomial(recovered_probs,
                                            num_samples=1).reshape(
                                                batch_size, k)
@@ -194,8 +195,7 @@ class RejectionSampler(nn.Module):
                                   device=target_probs.device)
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
-            torch.full((1, ), 1, device=target_probs.device),
-        )
+            torch.full((1, ), 1, device=target_probs.device))
         accepted = uniform_rand < capped_ratio
 
         return accepted
@@ -291,8 +291,7 @@ class RejectionSampler(nn.Module):
         output_with_bonus_tokens = -torch.ones(
             (batch_size, k + self._num_bonus_tokens),
             dtype=self.token_id_dtype,
-            device=accepted.device,
-        )
+            device=accepted.device)
         output = output_with_bonus_tokens[:, :k]
 
         # Fill in the first k columns of the output tensor using masks and data
@@ -306,6 +305,11 @@ class RejectionSampler(nn.Module):
         output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                       bonus_token_ids, -1)
 
+        # We disable bonus tokens because it causes corrupt KV cache for
+        # proposal methods that require KV cache. We can fix it by "prefilling"
+        # the bonus token in the proposer.
+        output_with_bonus_tokens[:, -1] = -1
+
         # Fill the recovered token ids.
         output.mul_(~after_false_mask).add_(
             recovered_token_ids.mul(after_false_mask))
@@ -323,11 +327,8 @@ class RejectionSampler(nn.Module):
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
     ) -> None:
-        (
-            target_batch_size,
-            num_target_probs,
-            target_vocab_size,
-        ) = target_probs.shape
+        (target_batch_size, num_target_probs,
+         target_vocab_size) = target_probs.shape
         bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
         draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
         draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
@@ -363,12 +364,8 @@ class RejectionSampler(nn.Module):
         draft_token_ids: torch.Tensor,
     ) -> None:
         devices = [
-            t.device for t in [
-                target_probs,
-                bonus_token_ids,
-                draft_probs,
-                draft_token_ids,
-            ]
+            t.device for t in
+            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
         ]
         assert all([devices[0] == device for device in devices])
 
@@ -397,8 +394,8 @@ def _multinomial(
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
-        probs = (probs[:, None, :].expand(probs.shape[0], num_samples,
-                                          probs.shape[1]).contiguous().view(
-                                              -1, probs.shape[1]))
+        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
+                                         probs.shape[1]).contiguous().view(
+                                             -1, probs.shape[1])
     q = torch.empty_like(probs).exponential_(1.0)
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
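
The _multinomial helper above draws categorical samples by dividing each row of probabilities by i.i.d. Exponential(1) noise and taking the argmax; index i wins that race with probability p_i, which is what makes it a drop-in replacement for torch.multinomial here. A quick empirical sanity check of the equivalence (a sketch, not part of the change):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
num_draws = 200_000

# Same trick as _multinomial: divide by Exponential(1) noise, take the argmax.
expanded = probs.expand(num_draws, -1).contiguous()
q = torch.empty_like(expanded).exponential_(1.0)
samples = expanded.div(q).argmax(dim=1)

empirical = torch.bincount(samples, minlength=4).float() / num_draws
assert torch.allclose(empirical, probs, atol=0.01)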

+ 2 - 2
aphrodite/modeling/layers/rotary_embedding.py

@@ -111,7 +111,7 @@ class RotaryEmbedding(nn.Module):
             query_pass = query[..., self.rotary_dim:]
             key_pass = key[..., self.rotary_dim:]
 
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -147,7 +147,7 @@ class RotaryEmbedding(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # ops.rotary_embedding() is an in-place operation that
         # updates the query and key tensors.
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
         # ops.rotary_embedding()/batched_rotary_embedding() are in-place ops
         # that update the qk tensors
         if offsets is not None:
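
Switching the cache move from positions.get_device() to positions.device matters once non-CUDA backends are in play: .device is a torch.device that works as a .to() target for CPU and CUDA tensors alike, while get_device() only returns a CUDA ordinal and is not useful for CPU tensors. A tiny illustration:

import torch

positions = torch.arange(8)          # a CPU tensor, e.g. on the CPU backend
cos_sin_cache = torch.randn(8, 4)

# Works the same whether positions lives on CPU or on a CUDA device.
cos_sin_cache = cos_sin_cache.to(positions.device)
assert cos_sin_cache.device == positions.device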

+ 318 - 247
aphrodite/modeling/layers/sampler.py

@@ -1,225 +1,155 @@
 """A layer that samples the next tokens from the model's outputs."""
+import itertools
+import math
 from typing import Dict, List, Tuple, Optional
 
 import torch
 import torch.nn as nn
-import math
 
 from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   OutputMetadata,
                                                   SamplingTensors)
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_gather)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                                        SamplerOutput, SequenceData,
                                        SequenceGroupOutput, SequenceOutput)
 from aphrodite.modeling.layers.ops.sample import sample as sample_triton
-from aphrodite.common.utils import is_neuron
 
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
-
     This layer does the following:
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
+    3. Apply all the different sampler functions in the specified order.
+    4. Sample the next tokens.
     Here, each sequence group within the batch can have different sampling
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+    The structure of the logits tensor is coupled with the seq_groups in
+    sampling_metadata. Typically, each sequence in each seq_group has one row in
+    logits for the next token to be sampled; however, for a seq_group with a
+    prompt request with the prompt_logprobs sampling parameter, there are rows
+    in logits for each token in the input prompt.
     """
 
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
+    def __init__(self):
         super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generates outputs as logits directly
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
-        return _perform_sampling(logits, sampling_metadata)
-
-
-# FIXME: This is a hack for the missing GPU blocks. This should be removed
-# once a proper fix is implemented.
-class QuantSampler(nn.Module):
-    """Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
+        # Whether or not the SamplerOutput should have on-device tensors
+        # containing the sampled token ids and probabilities. This is used by
+        # speculative decoding.
+        self.include_gpu_probs_tensor = False
 
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        logits = _prune_hidden_states(logits, sampling_metadata)
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.vocab_size]
-
-        return _perform_sampling(logits, sampling_metadata)
-
-
-def _perform_sampling(
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-    # Only perform sampling in the driver worker.
-    # Note: `_get_logits` is still distributed across TP workers because
-    # the `embedding` weight is distributed across TP workers.
-    # TODO: Change the get_logits part to a separate stage.
-    if not sampling_metadata.perform_sampling:
-        return None
-
-    assert logits is not None
-    _, vocab_size = logits.shape
-
-    output_metadata = OutputMetadata()
-
-    # Apply logits processors (if any)
-    logits = _apply_logits_processors(logits, sampling_metadata)
-
-    # Prepare sampling tensors with pinned memory to avoid blocking.
-    sampling_tensors = SamplingTensors.from_sampling_metadata(
-        sampling_metadata, vocab_size, logits.device, logits.dtype)
-
-    if sampling_tensors.do_penalties:
-        logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
-                                  sampling_tensors.output_tokens,
-                                  sampling_tensors.pres_penalties,
-                                  sampling_tensors.freq_penalties,
-                                  sampling_tensors.rep_penalties)
-
-    if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
-        logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                    sampling_tensors.dynatemp_mins,
-                                    sampling_tensors.dynatemp_maxs,
-                                    sampling_tensors.dynatemp_exps)
-
-    if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
-            or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
-        logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
-                                      sampling_tensors.top_ks,
-                                      sampling_tensors.top_as,
-                                      sampling_tensors.min_ps)
-
-    if sampling_tensors.do_tfss:
-        logits = _apply_tfs(logits, sampling_tensors.tfss)
-    if sampling_tensors.do_eta_cutoffs:
-        logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
-    if sampling_tensors.do_epsilon_cutoffs:
-        logits = _apply_epsilon_cutoff(logits,
-                                       sampling_tensors.epsilon_cutoffs)
-    if sampling_tensors.do_typical_ps:
-        logits = _apply_typical_sampling(logits, sampling_tensors.typical_ps)
-
-    if sampling_tensors.do_quadratic:
-        logits = _apply_quadratic_sampling(logits,
-                                           sampling_tensors.smoothing_indices,
-                                           sampling_tensors.smoothing_factors,
-                                           sampling_tensors.smoothing_curves)
-
-    banned_tokens = _get_custom_token_bans(sampling_metadata)
-    assert len(banned_tokens) == logits.shape[0]
-    logits = _apply_token_bans(logits, banned_tokens)
-    if sampling_tensors.do_mirostat:
-        logits = _apply_mirostat_v2(logits, sampling_tensors)
-
-    # We use float32 for probabilities and log probabilities.
-    # Compute the probabilities.
-    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-    # Compute the log probabilities.
-    # Use log_softmax to ensure numerical stability.
-    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-
-    # Sample the next tokens.
-    sample_results = _sample(probs, logprobs, sampling_metadata,
-                             sampling_tensors)
-
-    if sampling_tensors.do_mirostat:
-        _mirostat_store_args(logits, sampling_tensors, sample_results,
-                             sampling_metadata, output_metadata)
-    # Get the logprobs query results.
-    prompt_logprobs, sample_logprobs = _get_logprobs(logprobs,
-                                                     sampling_metadata,
-                                                     sample_results)
-    return _build_sampler_output(sample_results, sampling_metadata,
-                                 prompt_logprobs, sample_logprobs,
-                                 output_metadata)
-
-
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
+        assert logits is not None
+        _, vocab_size = logits.shape
+        output_metadata = OutputMetadata()
+        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
+        # have not been generated yet
+        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
+
+        # Prepare sampling tensors with pinned memory to avoid blocking.
+        sampling_tensors = SamplingTensors.from_sampling_metadata(
+            sampling_metadata, vocab_size, logits.device, logits.dtype)
+
+        if sampling_tensors.do_penalties:
+            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                      sampling_tensors.output_tokens,
+                                      sampling_tensors.pres_penalties,
+                                      sampling_tensors.freq_penalties,
+                                      sampling_tensors.rep_penalties)
+
+        if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
+            logits = _apply_temperature(logits, sampling_tensors.temperatures,
+                                        sampling_tensors.dynatemp_mins,
+                                        sampling_tensors.dynatemp_maxs,
+                                        sampling_tensors.dynatemp_exps)
+
+        if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
+                or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
+            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
+                                          sampling_tensors.top_ks,
+                                          sampling_tensors.top_as,
+                                          sampling_tensors.min_ps)
+        if sampling_tensors.do_tfss:
+            logits = _apply_tfs(logits, sampling_tensors.tfss)
+        if sampling_tensors.do_eta_cutoffs:
+            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
+        if sampling_tensors.do_epsilon_cutoffs:
+            logits = _apply_epsilon_cutoff(logits,
+                                           sampling_tensors.epsilon_cutoffs)
+        if sampling_tensors.do_typical_ps:
+            logits = _apply_typical_sampling(logits,
+                                             sampling_tensors.typical_ps)
+
+        if sampling_tensors.do_quadratic:
+            logits = _apply_quadratic_sampling(
+                logits, sampling_tensors.smoothing_indices,
+                sampling_tensors.smoothing_factors,
+                sampling_tensors.smoothing_curves)
+
+        banned_tokens = _get_custom_token_bans(sampling_metadata)
+        assert len(banned_tokens) == logits.shape[0]
+        logits = _apply_token_bans(logits, banned_tokens)
+        if sampling_tensors.do_mirostat:
+            logits = _apply_mirostat_v2(logits, sampling_tensors)
+
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities.
+        # Use log_softmax to ensure numerical stability.
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+        # Sample the next tokens.
+        # sample_results = _sample(probs, logprobs, sampling_metadata,
+        #                          sampling_tensors)
+        sample_results, maybe_sampled_tokens_tensor = _sample(
+            probs,
+            logprobs,
+            sampling_metadata,
+            sampling_tensors,
+            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
+            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
+        )
+
+        if self.include_gpu_probs_tensor:
+            assert maybe_sampled_tokens_tensor is not None
+            sampled_tokens_tensor = maybe_sampled_tokens_tensor
+            on_device_tensors = (probs, sampled_tokens_tensor)
+        else:
+            on_device_tensors = None
+
+        if sampling_tensors.do_mirostat:
+            _mirostat_store_args(logits, sampling_tensors, sample_results,
+                                 sampling_metadata, output_metadata)
+        # Get the logprobs query results.
+        prompt_logprobs, sample_logprobs = _get_logprobs(
+            logprobs, sampling_metadata, sample_results)
+        # return _build_sampler_output(sample_results, sampling_metadata,
+        #                              prompt_logprobs, sample_logprobs,
+        #                              output_metadata)
+        return _build_sampler_output(sample_results, sampling_metadata,
+                                     prompt_logprobs, sample_logprobs,
+                                     output_metadata, on_device_tensors)
+
+    @property
+    def _should_modify_greedy_probs_inplace(self) -> bool:
+        """Whether or not the sampler should modify the probability distribution
+        of greedily-sampled tokens such that multinomial sampling would sample
+        the greedily-sampled token.
+        In other words, if True then we set the probability of the greedily-
+        sampled token to 1.
+        This is used by speculative decoding, which requires that the sampling
+        method be encoded into the probability distribution.
+        """
+        # Modify greedy probs if include_gpu_probs_tensor is set.
+        return self.include_gpu_probs_tensor
 
 
 def _get_bin_counts_and_mask(
@@ -255,35 +185,6 @@ def _get_custom_token_bans(
     return banned_tokens
 
 
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    metadata: SamplingMetadata,
-) -> torch.Tensor:
-    assert metadata.seq_groups is not None
-    assert metadata.prompt_lens is not None
-    assert metadata.seq_data is not None
-    seq_offset = 0
-    for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
-        seq_size = len(seq_ids)
-        output_tokens = []
-        if (i < metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_seqs = metadata.prompt_lens[i] - 1
-            seq_size += prompt_seqs
-            output_tokens.extend([[]] * prompt_seqs)
-        seq_end = seq_offset + seq_size
-
-        if sampling_params.logits_processors:
-            output_tokens.extend(metadata.seq_data[sid].output_token_ids
-                                 for sid in seq_ids)
-            for proc in sampling_params.logits_processors:
-                proc(logits[seq_offset:seq_end], output_tokens)
-
-        seq_offset = seq_end
-
-    return logits
-
-
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
@@ -316,6 +217,44 @@ def _apply_token_bans(logits: torch.Tensor,
     return logits
 
 
+def _apply_min_tokens_penalty(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    assert sampling_metadata.seq_groups is not None
+    assert sampling_metadata.seq_data is not None
+    # list of indices in logits that will be set to -inf
+    logits_to_penalize = []
+    start_idx = 0
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        min_tokens = sampling_params.min_tokens
+        if min_tokens > 0:
+            seqs_to_penalize = []
+            for i, seq_id in enumerate(seq_ids):
+                seq_data = sampling_metadata.seq_data[seq_id]
+                if len(seq_data.output_token_ids) < min_tokens:
+                    seqs_to_penalize.append(i)
+
+            if seqs_to_penalize:
+                # convert to the index into logits
+                seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
+                # use set() to remove any duplicates
+                token_ids_to_penalize = set(sampling_params.stop_token_ids +
+                                            [sampling_params.eos_token_id])
+                # itertools.product pairs each seq index with every token id
+                logits_to_penalize.extend(
+                    itertools.product(seqs_to_penalize, token_ids_to_penalize))
+
+        start_idx += len(seq_ids)
+
+    if logits_to_penalize:
+        # use zip and * to group indices along each dimension
+        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
+        logits[tuple(zip(*logits_to_penalize))] = -float("inf")
+
+    return logits
+
+
 def _apply_alphabet_soup(
     logits: torch.Tensor,
     p: torch.Tensor,
@@ -645,6 +584,7 @@ def _multinomial(
     if seq_groups is None:
         q.exponential_()
     else:
+        assert generators is not None
         sample_idx = 0
         for (seq_ids, _), generator in zip(seq_groups, generators):
             next_sample_idx = sample_idx + len(seq_ids) * num_samples
@@ -657,7 +597,9 @@ def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-) -> List[Tuple[List[int], List[int]]]:
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
+) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
     """Returns list of (selected_tokens, parent_seq_ids) tuples
     corresponding to sampling_metadata.seq_groups."""
     assert sampling_metadata.seq_groups is not None
@@ -674,6 +616,15 @@ def _sample_with_torch(
     sample_metadata = {}
     multinomial_samples = {}
 
+    # Create output tensor for sampled token ids.
+    if include_gpu_probs_tensor:
+        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
+                                               1,
+                                               dtype=torch.long,
+                                               device=logprobs.device)
+    else:
+        sampled_token_ids_tensor = None
+
     # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type, sample_indices in categorized_sample_indices.items():
@@ -685,9 +636,23 @@ def _sample_with_torch(
         is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
+        long_sample_indices = sample_indices.long()
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                           dim=-1)
+
+            if include_gpu_probs_tensor:
+                # Store sampled tokens in output tensor.
+                sampled_token_ids_tensor[
+                    long_sample_indices] = greedy_samples.unsqueeze(-1)
+
+            if modify_greedy_probs:
+                # If required, modify the probabilities such that sampling from
+                # the modified distribution would always sample the argmax
+                # token id.
+                _modify_greedy_probs_inplace(logprobs, probs,
+                                             long_sample_indices,
+                                             greedy_samples)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of_in_batch = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
@@ -700,14 +665,20 @@ def _sample_with_torch(
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices.long()], max_best_of_in_batch,
+                probs[long_sample_indices], max_best_of_in_batch,
                 **seeded_args)
+
+            if include_gpu_probs_tensor:
+                # Store sampled tokens in output tensor.
+                sampled_token_ids_tensor[
+                    long_sample_indices] = multinomial_samples[sampling_type]
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
     # GPU<->CPU sync happens in the loop below.
+    # This also converts the sample output to Python objects.
 
     for sampling_type, metadata in sample_metadata.items():
         seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
@@ -726,7 +697,7 @@ def _sample_with_torch(
         sample_results_dict[i]
         for i in range(len(sampling_metadata.seq_groups))
     ]
-    return sample_results
+    return sample_results, sampled_token_ids_tensor
 
 
 def _sample_with_triton_kernel(
@@ -736,6 +707,7 @@ def _sample_with_triton_kernel(
     sampling_tensors: SamplingTensors,
 ) -> List[Tuple[List[int], List[int]]]:
     assert sampling_metadata.seq_groups is not None
+    assert sampling_metadata.categorized_sample_indices is not None
     assert sampling_metadata.seq_data is not None
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
@@ -750,7 +722,6 @@ def _sample_with_triton_kernel(
 
     # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
-    assert categorized_sample_indices is not None
     for sampling_type, sample_indices in categorized_sample_indices.items():
         sampled_token_indices = sample_indices[:, 1]
         sample_indices = sample_indices[:, 0]
@@ -812,18 +783,40 @@ def _sample_with_triton_kernel(
 
 
 def _sample(
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
-    return _sample_with_torch(probs, logprobs, sampling_metadata)
+    probs: torch.Tensor, logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
+) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+    return _sample_with_torch(
+        probs,
+        logprobs,
+        sampling_metadata,
+        include_gpu_probs_tensor=include_gpu_probs_tensor,
+        modify_greedy_probs=modify_greedy_probs,
+    )
 
     # TODO: Enable once Triton kernel & associated code is faster.
     # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
     #                                   sampling_tensors)
 
 
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+    """
+    This function calculates the ranks of the chosen tokens in a logprob tensor.
+    Args:
+        x (torch.Tensor): 2D logprob tensor of shape (N, M),
+                        where N is the number of tokens and M is the vocab dim.
+        indices (torch.Tensor): 1D tensor with the chosen token index per row.
+    Returns:
+        torch.Tensor: 1D tensor of shape (N,) where N is the number of tokens.
+                    Each element is the rank (1 = highest logprob) of the
+                    chosen token within its row of the input logprob tensor.
+    """
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    return (x > vals[:, None]).long().sum(1).add_(1)
+
+
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -836,7 +829,8 @@ def _get_logprobs(
     # Prepare query indices
     batched_logprobs_query_seq_indices: List[int] = []
     batched_logprobs_query_token_indices: List[int] = []
-    largest_num_logprobs = 0
+    # at least get one logprob for each token
+    largest_num_logprobs = 1
     sample_idx = 0
     for i, (seq_group, sample_result) in enumerate(
             zip(sampling_metadata.seq_groups, sample_results)):
@@ -864,12 +858,18 @@ def _get_logprobs(
         sample_idx += num_parent_seqs
     assert sample_idx == logprobs.size(0)
 
+    batched_logprobs_query_seq_indices_gpu = torch.tensor(
+        batched_logprobs_query_seq_indices, device=logprobs.device)
+    batched_logprobs_query_token_indices_gpu = torch.tensor(
+        batched_logprobs_query_token_indices, device=logprobs.device)
     # Batched query for logprobs of selected token
     batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices,
-        batched_logprobs_query_token_indices
+        batched_logprobs_query_seq_indices_gpu,
+        batched_logprobs_query_token_indices_gpu
     ]]
-
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices_gpu],
+        batched_logprobs_query_token_indices_gpu)
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
@@ -882,6 +882,8 @@ def _get_logprobs(
 
     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
 
+    batched_ranks_query_result = batched_ranks_query_result.cpu()
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
@@ -891,7 +893,6 @@ def _get_logprobs(
             zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
-
         # Prompt logprobs
         if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
@@ -902,22 +903,26 @@ def _get_logprobs(
             for token_id in prompt_tokens[1:]:
                 prompt_logprobs_dict = {
                     token_id:
-                    batched_logprobs_query_result[query_result_idx].item()
+                    (batched_logprobs_query_result[query_result_idx].item(),
+                     batched_ranks_query_result[query_result_idx].item())
                 }
                 if num_logprobs > 0:
                     prompt_logprobs_dict.update(
-                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
-                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                        zip(
+                            top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            zip(
+                                top_logprobs[
+                                    sample_idx, :num_logprobs].tolist(),
+                                range(1, num_logprobs + 1))))
                 group_prompt_logprobs.append({
-                    token_id: Logprob(logprob)
-                    for token_id, logprob in prompt_logprobs_dict.items()
+                    token_id: Logprob(*logprob_rank)
+                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                 })
                 sample_idx += 1
                 query_result_idx += 1
             result_prompt_logprobs.append(group_prompt_logprobs)
         else:
             result_prompt_logprobs.append(None)
-
         # Sample logprobs
         num_logprobs = sampling_params.logprobs
         if num_logprobs is None:
@@ -926,33 +931,89 @@ def _get_logprobs(
         for next_token_id, parent_id in zip(next_token_ids, parent_ids):
             sample_logprobs_dict = {
                 next_token_id:
-                batched_logprobs_query_result[query_result_idx].item()
+                (batched_logprobs_query_result[query_result_idx].item(),
+                 batched_ranks_query_result[query_result_idx].item())
             }
             query_result_idx += 1
-            if num_logprobs > 0:
+            if num_logprobs >= 0:
                 sample_logprobs_dict.update(
                     zip(
                         top_token_ids[sample_idx +
                                       parent_id, :num_logprobs].tolist(),
-                        top_logprobs[sample_idx +
-                                     parent_id, :num_logprobs].tolist()))
+                        zip(
+                            top_logprobs[sample_idx +
+                                         parent_id, :num_logprobs].tolist(),
+                            range(1, num_logprobs + 1))))
             group_sample_logprobs.append({
-                token_id: Logprob(logprob)
-                for token_id, logprob in sample_logprobs_dict.items()
+                token_id: Logprob(*logprob_rank)
+                for token_id, logprob_rank in sample_logprobs_dict.items()
             })
         result_sample_logprobs.append(group_sample_logprobs)
         sample_idx += len(seq_ids)
-
     return result_prompt_logprobs, result_sample_logprobs
 
 
+def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
+                                 sample_indices: torch.Tensor,
+                                 greedy_samples: torch.Tensor) -> None:
+    """Modify the probability distributions of the greedily-sampled tokens such
+    that each sampled token has a "probability" of 1.0. This is required by
+    speculative decoding, which depends on the sampling method being encoded
+    within the probability distribution for correctness.
+    # Why do we only need to do this for greedy sampling?
+    Aphrodite's sampler performs the following steps for greedy or multinomial
+    (random) sampling:
+        1. Get logits from model.
+        2. Modify logits according to per-sequence sampling parameters.
+            - Multiply by temperature, top-k and top-p masking, penalize tokens
+                according to their frequency, etc.
+        3. Sample a token.
+            - Random sampling simply samples from the modified probability
+                distribution.
+            - Greedy sampling performs `argmax` to obtain the token with the
+                highest likelihood.
+    
+    Ignoring greedy sampling for a moment, the computed probability distribution
+    has the following property: if we draw from it independently, each token
+    appears with a frequency equal to its probability under that distribution.
+    In other words, for tokens sampled with Aphrodite's random SamplingType,
+    the computed probability distribution encodes the sampling methodology
+    completely.
+    Greedy sampling does not normally have this property. Aphrodite modifies
+    logits according to sampling params, then performs `argmax`, then returns
+    the sampled token and the computed probability distribution. If we sample
+    from the distribution, we'll find the likelihood of the greedily-sampled
+    token is not always 1.0.
+    Since lossless speculative decoding requires that the sampling methodology
+    be encoded within the probability distribution, we are motivated to modify
+    the probability distribution such that the sampled token has probability 1
+    when speculative decoding is used.
+    NOTE: Alternatively, we could use an extremely low temperature to achieve
+    greedy sampling using multinomial computation and unite the codepaths. This
+    has implications on the overall design of the sampler, e.g. how to record
+    accurate logprobs for the user, so this improvement is deferred to later.
+    """
+    logprobs[sample_indices, :] = -float('inf')
+    logprobs[sample_indices, greedy_samples] = 0.0
+    probs[sample_indices, :] = 0
+    probs[sample_indices, greedy_samples] = 1.0
+
+
 def _build_sampler_output(
     sample_results: List[Tuple[List[int], List[int]]],
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],
     output_metadata: OutputMetadata,
+    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
 ) -> SamplerOutput:
+    """Construct Python objects with the output of sampling.
+    Args:
+        on_device_tensors: Tuple containing on-device tensors with the
+            probabilities used in sampling and the sampled token ids. This
+            allows post-processing without copies to CPU/serialization, e.g. in
+            speculative decoding rejection sampling.
+    """
     assert sampling_metadata.seq_groups is not None
     sampler_output = []
     for (seq_group, sample_result, group_prompt_logprobs,
@@ -969,7 +1030,17 @@ def _build_sampler_output(
 
         sampler_output.append(
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
-    return SamplerOutput(outputs=sampler_output)
+    # If not specified, store None values in SamplerOutput.
+    if on_device_tensors is not None:
+        sampled_token_probs, sampled_token_ids = on_device_tensors
+    else:
+        sampled_token_probs, sampled_token_ids = (None, None)
+
+    return SamplerOutput(
+        outputs=sampler_output,
+        sampled_token_probs=sampled_token_probs,
+        sampled_token_ids=sampled_token_ids,
+    )
 
 
 def _apply_mirostat_v2(logits: torch.Tensor,

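The logprob-rank and greedy-probability changes above are easiest to see on a toy tensor. The sketch below is a standalone illustration (not engine code) of what `_get_ranks` and `_modify_greedy_probs_inplace` compute; the tensor values are made up for the example.

    import torch

    # Toy values: 2 tokens (rows) over a vocab of 4 (columns).
    logprobs = torch.tensor([[-0.1, -2.0, -3.0, -4.0],
                             [-1.5, -0.2, -2.5, -3.5]])
    chosen = torch.tensor([1, 1])  # chosen token id for each row

    # Rank = 1 + number of vocab entries with a strictly higher logprob,
    # which is exactly what _get_ranks computes.
    vals = logprobs[torch.arange(len(logprobs)), chosen]
    ranks = (logprobs > vals[:, None]).long().sum(1).add_(1)
    print(ranks)  # tensor([2, 1]) -> 2nd-best in row 0, best in row 1

    # What _modify_greedy_probs_inplace does for greedy rows: force the
    # argmax token to probability 1.0 (logprob 0.0) so the distribution
    # itself encodes greedy sampling, as speculative decoding requires.
    probs = torch.softmax(logprobs, dim=-1)
    greedy = torch.argmax(logprobs, dim=-1)
    rows = torch.arange(len(logprobs))
    logprobs[rows, :] = -float("inf")
    logprobs[rows, greedy] = 0.0
    probs[rows, :] = 0.0
    probs[rows, greedy] = 1.0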
+ 41 - 11
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -3,14 +3,13 @@ from typing import Optional, Sequence
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
-from aphrodite.modeling.megatron.parallel_state import (
+from aphrodite.distributed import (
+    divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
 )
-from aphrodite.modeling.megatron.utils import divide
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_reduce)
+from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
 from aphrodite.modeling.utils import set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -91,16 +90,24 @@ class VocabParallelEmbedding(torch.nn.Module):
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         output_dim = getattr(param, "output_dim", None)
+        packed_dim = getattr(param, "packed_dim", None)
         if output_dim is not None:
-            assert loaded_weight.shape[output_dim] == self.org_vocab_size
-            loaded_weight = loaded_weight.narrow(
-                output_dim, self.vocab_start_index,
-                min(self.vocab_end_index - self.vocab_start_index,
-                    self.org_vocab_size - self.vocab_start_index))
+            shard_offset = self.vocab_start_index
+            shard_size = min(self.vocab_end_index,
+                             self.org_vocab_size) - shard_offset
+            if packed_dim == output_dim:
+                shard_size = shard_size // param.pack_factor
+                shard_offset = shard_offset // param.pack_factor
+            loaded_weight = loaded_weight.narrow(output_dim, shard_offset,
+                                                 shard_size)
         if isinstance(param, torch.nn.parameter.UninitializedParameter):
             vocab_shape = list(loaded_weight.shape)
             if output_dim is not None:
-                vocab_shape[output_dim] = self.num_embeddings_per_partition
+                if packed_dim == output_dim:
+                    vocab_shape[output_dim] = (
+                        self.num_embeddings_per_partition // param.pack_factor)
+                else:
+                    vocab_shape[output_dim] = self.num_embeddings_per_partition
             param.materialize(vocab_shape, dtype=loaded_weight.dtype)
         if output_dim is not None:
             param.data.narrow(
@@ -123,6 +130,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         output_parallel = self.linear_method.apply_embedding(
             self.linear_weights, masked_input)
         # output_parallel = F.embedding(masked_input, self.weight)
+        # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
@@ -132,9 +140,11 @@ class VocabParallelEmbedding(torch.nn.Module):
 
 class ParallelLMHead(VocabParallelEmbedding):
     """Parallelized LM head.
+
     Output logits weight matrices used in the Sampler. The weight and bias
     tensors are padded to make sure they are divisible by the number of
     model parallel GPUs.
+
     Args:
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
@@ -171,3 +181,23 @@ class ParallelLMHead(VocabParallelEmbedding):
         if self.bias is not None:
             logits += self.bias
         return logits
+
+
+class ParallelTWEHead(torch.nn.Module):
+    """Parallelized tie word embeddings head.
+
+    Output logits weight matrices used in the Sampler. The weight and bias
+    tensors are read from a VocabParallelEmbedding.
+
+    Args:
+        embeddings: the VocabParallelEmbedding to mirror
+    """
+
+    def __init__(self, embeddings: VocabParallelEmbedding):
+        super().__init__()
+        self.linear_method = embeddings.linear_method
+        self.linear_weights = embeddings.linear_weights
+
+    def forward(self, input_):
+        logits = self.linear_method.apply_weights(self.linear_weights, input_)
+        return logits

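For models with tied word embeddings, the new ParallelTWEHead reuses the embedding's linear_method and linear_weights to produce logits instead of allocating a separate ParallelLMHead. A minimal sketch of how a model might wire this up, assuming the tensor-parallel process group is already initialized (as it is inside a running Aphrodite worker); the class name and sizes here are illustrative, only the two imported classes come from this diff:

    import torch.nn as nn
    from aphrodite.modeling.layers.vocab_parallel_embedding import (
        ParallelTWEHead, VocabParallelEmbedding)

    class TiedEmbeddingStub(nn.Module):
        def __init__(self, vocab_size: int, hidden_size: int):
            super().__init__()
            self.embed_tokens = VocabParallelEmbedding(vocab_size, hidden_size)
            # Mirror the embedding's (possibly quantized) weights for the
            # output projection instead of creating a second weight matrix.
            self.lm_head = ParallelTWEHead(self.embed_tokens)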
+ 16 - 14
aphrodite/modeling/loader.py

@@ -10,17 +10,22 @@ import torch.nn as nn
 
 from aphrodite.common.config import DeviceConfig, ModelConfig
 from aphrodite.modeling.models import ModelRegistry
+from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
 from aphrodite.modeling.hf_downloader import (
     get_quant_config,
     initialize_dummy_weights,
 )
-from aphrodite.modeling.layers.quantization.bitsandbytes import (
+from aphrodite.quantization.bitsandbytes import (
     BNBLinearMethod,
     replace_quant_params,
 )
-from aphrodite.modeling.megatron.parallel_state import (
+from aphrodite.distributed import (
     get_tensor_model_parallel_world_size, )
 
+_VISION_MODEL_CLASSES = [
+    LlavaForConditionalGeneration,
+]
+
 
 @contextlib.contextmanager
 def _set_default_torch_dtype(dtype: torch.dtype):
@@ -33,11 +38,6 @@ def _set_default_torch_dtype(dtype: torch.dtype):
 
 def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
     architectures = getattr(model_config.hf_config, "architectures", [])
-    # Special handling for quantized Mixtral.
-    # FIXME: This is a temporary hack.
-    if (model_config.quantization is not None
-            and "MixtralForCausalLM" in architectures):
-        architectures = ["QuantMixtralForCausalLM"]
 
     for arch in architectures:
         model_cls = ModelRegistry.load_model_cls(arch)
@@ -51,6 +51,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
 def get_model(model_config: ModelConfig, device_config: DeviceConfig,
               **kwargs) -> nn.Module:
     lora_config = kwargs.get("lora_config", None)
+    vision_language_config = kwargs.get("vision_language_config", None)
     model_class = _get_model_architecture(model_config)
 
     # Get the (maybe quantized) linear method.
@@ -88,19 +89,20 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
                     "be added in the future. If this is important to you, "
                     "please open an issue on github.")
             else:
-                model = model_class(model_config.hf_config, linear_method)
+                if model_class not in _VISION_MODEL_CLASSES:
+                    model = model_class(model_config.hf_config, linear_method)
+                else:
+                    model = model_class(model_config.hf_config,
+                                        vision_language_config, linear_method)
         if model_config.load_format == "dummy":
             # NOTE: For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
         else:
             # Load the weights from the cached or downloaded files.
-            model.load_weights(
-                model_config.model,
-                model_config.download_dir,
-                model_config.load_format,
-                model_config.revision,
-            )
+            model.load_weights(model_config.model, model_config.download_dir,
+                               model_config.load_format, model_config.revision)
+
         if isinstance(linear_method, BNBLinearMethod):
             replace_quant_params(
                 model,

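The loader change above boils down to picking the constructor call based on whether the architecture is a vision-language model. A simplified restatement (the helper name is hypothetical; LoRA, dummy-weight, and BNB paths are omitted):

    # Only _VISION_MODEL_CLASSES and the argument names come from the diff.
    def _construct_model(model_class, hf_config, vision_language_config,
                         linear_method):
        if model_class not in _VISION_MODEL_CLASSES:
            # Text-only models keep the two-argument constructor.
            return model_class(hf_config, linear_method)
        # Vision-language models (currently only LLaVA) additionally receive
        # the vision_language_config passed through get_model(**kwargs).
        return model_class(hf_config, vision_language_config, linear_method)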
+ 0 - 1
aphrodite/modeling/megatron/README.md

@@ -1 +0,0 @@
-The files here are from the NVIDIA [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) repository, but only with inference-related code.

+ 0 - 127
aphrodite/modeling/megatron/cupy_utils.py

@@ -1,127 +0,0 @@
-"""CuPy utilities for all-reduce.
-We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
-CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
-CUDA graphs.
-NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
-TODO: Remove this file when torch.distributed.all_reduce is fixed.
-"""
-import contextlib
-
-import torch
-from torch.distributed import ReduceOp
-
-try:
-    import cupy
-    from cupy.cuda import nccl
-    from cupyx.distributed import NCCLBackend
-except ImportError as e:
-    cupy = e
-    nccl = None
-
-    class NCCLBackend:
-        ...
-
-
-_OP_MAPPING = {
-    ReduceOp.SUM: "sum",
-    ReduceOp.PRODUCT: "prod",
-    ReduceOp.MIN: "min",
-    ReduceOp.MAX: "max",
-}
-
-
-class NCCLBackendWithBFloat16(NCCLBackend):
-    # This is enough to add bfloat16 support for most operations,
-    # but broadcast will fail (will require changes in compiled
-    # cupy code).
-    def _get_nccl_dtype_and_count(self, array, count=None):
-        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
-        torch_dtype = getattr(array, "_torch_dtype", None)
-        if torch_dtype is torch.bfloat16:
-            nccl_dtype = nccl.NCCL_BFLOAT16
-        return nccl_dtype, count
-
-    def barrier(self) -> None:
-        raise RuntimeError(
-            "Currently, CuPy NCCL barrier is not supported since the TCP "
-            "store is immediately stopped after the initialization.")
-
-
-_NCCL_BACKEND = None
-_WORLD_SIZE = 0
-
-
-def is_initialized() -> bool:
-    """Returns whether the NCCL backend is initialized."""
-    return _NCCL_BACKEND is not None
-
-
-@contextlib.contextmanager
-def set_cupy_stream(stream: torch.cuda.Stream):
-    """Set the cuda stream for communication"""
-    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
-                                           stream.device_index)
-    with cupy_stream:
-        yield
-
-
-def init_process_group(world_size: int, rank: int, host: str,
-                       port: int) -> None:
-    """Initializes the CuPy NCCL backend.
-    # TODO: handle NCCL timeouts.
-    """
-    assert not is_initialized()
-
-    if isinstance(cupy, Exception):
-        raise ImportError(
-            "NCCLBackend is not available. Please install cupy.") from cupy
-
-    # TODO: Create TP and PP process groups for CuPy.
-    global _NCCL_BACKEND
-    global _WORLD_SIZE
-    assert world_size > 0, f"{world_size=} should be a positive integer"
-    assert 0 <= rank < world_size, (
-        f"{rank=} should be a integer between [0, {world_size})")
-
-    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
-    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
-    _WORLD_SIZE = world_size
-
-    # Stop the TCP store to prevent the deadlock issues at termination time.
-    # FIXME: This is hacky. Find a more robust solution.
-    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
-        _NCCL_BACKEND._store.stop()
-
-
-def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
-    """All-reduces the input tensor across the process group."""
-    assert input_.is_cuda, f"{input_} should be a cuda tensor"
-    # Hack to support bfloat16
-    torch_dtype = input_.dtype
-    if torch_dtype is torch.bfloat16:
-        # We need to view as float16, otherwise
-        # cupy will fail. This will not change
-        # the underlying data.
-        input_ = input_.view(torch.float16)
-    cupy_input = cupy.asarray(input_)
-    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
-    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
-                             out_array=cupy_input,
-                             op=_OP_MAPPING[op])
-
-
-def destroy_process_group() -> None:
-    """Destroys the NCCL backend."""
-    global _NCCL_BACKEND
-    global _WORLD_SIZE
-    _NCCL_BACKEND = None
-    _WORLD_SIZE = 0
-
-
-def get_world_size() -> int:
-    """Returns the world size."""
-    return _WORLD_SIZE
-
-
-def get_nccl_backend():
-    return _NCCL_BACKEND

+ 0 - 49
aphrodite/modeling/megatron/utils.py

@@ -1,49 +0,0 @@
-# Copyright 2023 The PygmalionAI team.
-# Copyright 2023 The vLLM team.
-# Adapted from
-# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import Sequence
-
-import torch
-
-
-def ensure_divisibility(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator."""
-    assert numerator % denominator == 0, "{} is not divisible by {}".format(
-        numerator, denominator)
-
-
-def divide(numerator, denominator):
-    """Ensure that numerator is divisible by the denominator and return
-    the division value."""
-    ensure_divisibility(numerator, denominator)
-    return numerator // denominator
-
-
-def split_tensor_along_last_dim(
-    tensor: torch.Tensor,
-    num_partitions: int,
-    contiguous_split_chunks: bool = False,
-) -> Sequence[torch.Tensor]:
-    """ Split a tensor along its last dimension.
-
-        Arguments:
-            tensor: input tensor.
-            num_partitions: number of partitions to split the tensor
-            contiguous_split_chunks: If True, make each chunk contiguous
-                                     in memory.
-
-        Returns:
-            A list of Tensors
-    """
-    # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
-    # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # NOTE: torch.split does not create contiguous tensors by default.
-    if contiguous_split_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-
-    return tensor_list

+ 0 - 97
aphrodite/modeling/metadata.py

@@ -1,97 +0,0 @@
-from dataclasses import dataclass, fields
-from typing import Optional, List, Any, Dict
-
-import torch
-from xformers.ops.fmha.attn_bias import AttentionBias
-
-
-@dataclass
-class InputMetadata:
-    """Metadata for input sequences. Used in PagedAttention.
-
-    NOTE: Any python object stored here is not updated when it is
-    cuda-graph replayed. If you have values that need to be changed
-    dynamically, it should be stored in tensor. The tensor has to be
-    updated from `CUDAGraphRunner.forward` API.
-    """
-    # Currently, input sequences can only contain all prompts
-    # or all decoding. True if all sequences are prompts.
-    is_prompt: bool
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor
-    # (batch_size,). The prompt length per sequence. None if it is a decoding.
-    prompt_lens: Optional[List[int]]
-    # prompt_lens stored as a tensor.
-    prompt_lens_tensor: Optional[torch.Tensor]
-    # The number of prompt tokens. Doesn't include padding.
-    num_prompt_tokens: int
-    # The number of generation tokens. Doesn't include padding.
-    num_generation_tokens: int
-    """
-    Definition of context_len, subquery_len, and seqlen.
-    |---------- N-1 iteration --------|
-    |---------------- N iteration ---------------------|
-    |- tokenA -|......................|-- newTokens ---|
-    |---------- context_len ----------|
-    |-------------------- seqlen ----------------------|
-                                      |- subquery_len -|
-    WARNING: context_len has different definition depending on if it is
-    prefill vs decoding. When it is prefill, it doesn't include new
-    tokens. When it is for decoding, it includes a new token.
-    """
-
-    # Maximum subquery length in the batch.
-    max_subquery_len: Optional[int]
-    # Maximum context length in the batch.
-    max_context_len: Optional[int]
-    # FIXME: It is for flash attn.
-    # Maximum sequence length in the batch.
-    max_seq_len: Optional[int]
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
-    # FIXME: It is for flash attn.
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
-    # (batch_size,). The length of context (tokens stored in KV cache) per
-    # sequence. WARNING: When it is a prefill request, it doesn't include new
-    # tokens. When it is for decoding, it includes a new token.
-    context_lens: Optional[torch.Tensor]
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
-    # in the kv cache. Each block can contain up to block_size tokens.
-    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
-    # captured.
-    block_tables: Optional[torch.Tensor]
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    use_cuda_graph: bool
-    kv_cache_dtype: str
-
-    def __post_init__(self):
-        # Set during the execution of the first attention op.
-        # It is a list because it is needed to set per prompt
-        # when alibi slopes is used. It is because of the limitation
-        # from xformer API.
-        # will not appear in the __repr__ and __init__
-        self.attn_bias: Optional[List[AttentionBias]] = None
-
-        # Cuda graph is only used for decoding now.
-        if self.use_cuda_graph:
-            assert self.num_prompt_tokens == 0
-
-    def asdict_zerocopy(self) -> Dict[str, Any]:
-        """Similar to dataclasses.asdict, but avoids deepcopying."""
-        # Note that if we add dataclasses as fields, they will need
-        # similar handling.
-        return {
-            field.name: getattr(self, field.name)
-            for field in fields(self)
-        }

+ 22 - 15
aphrodite/modeling/models/__init__.py

@@ -1,10 +1,10 @@
 import importlib
-from typing import List, Optional, Type
+from typing import Dict, List, Optional, Type
 from loguru import logger
 
 import torch.nn as nn
 
-from aphrodite.common.utils import is_hip, is_neuron
+from aphrodite.common.utils import is_hip
 
 # Architecture -> (module, class).
 _MODELS = {
@@ -15,6 +15,7 @@ _MODELS = {
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "CohereForCausalLM": ("cohere", "CohereForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
@@ -29,9 +30,10 @@ _MODELS = {
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
+    "LlavaForConditionalGeneration":
+    ("llava", "LlavaForConditionalGeneration"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
-    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     "YiForCausalLM": ("llama", "LlamaForCausalLM"),
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
@@ -41,11 +43,16 @@ _MODELS = {
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
+    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
 }
 
+# Architecture -> type.
+# Out-of-tree models registered at runtime via ModelRegistry.register_model().
+_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
+
 # Models not supported by ROCm.
 _ROCM_UNSUPPORTED_MODELS = []
 
@@ -60,16 +67,13 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
-_NEURON_SUPPORTED_MODELS = {
-    "LlamaForCausalLM": "neuron.llama",
-    "MistralForCausalLM": "neuron.mistral",
-}
-
 
 class ModelRegistry:
 
     @staticmethod
     def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch in _OOT_MODELS:
+            return _OOT_MODELS[model_arch]
         if model_arch not in _MODELS:
             return None
         if is_hip():
@@ -81,15 +85,8 @@ class ModelRegistry:
                 logger.warning(
                     f"Model architecture {model_arch} is partially supported "
                     "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
-        elif is_neuron():
-            if model_arch not in _NEURON_SUPPORTED_MODELS:
-                raise ValueError(
-                    f"Model architecture {model_arch} is not supported by "
-                    "AWS Neuron for now.")
 
         module_name, model_cls_name = _MODELS[model_arch]
-        if is_neuron():
-            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"aphrodite.modeling.models.{module_name}")
         return getattr(module, model_cls_name, None)
@@ -98,6 +95,16 @@ class ModelRegistry:
     def get_supported_archs() -> List[str]:
         return list(_MODELS.keys())
 
+    @staticmethod
+    def register_model(model_arch: str, model_cls: Type[nn.Module]):
+        if model_arch in _MODELS:
+            logger.warning(
+                f"Model architecture {model_arch} is already registered, "
+                "and will be overwritten by the new model "
+                f"class {model_cls.__name__}.")
+        global _OOT_MODELS
+        _OOT_MODELS[model_arch] = model_cls
+
 
 __all__ = [
     "ModelRegistry",

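The new _OOT_MODELS table and ModelRegistry.register_model make it possible to plug in an out-of-tree architecture without editing _MODELS. A minimal usage sketch; the model class here is a placeholder, not a real Aphrodite model:

    import torch.nn as nn
    from aphrodite.modeling.models import ModelRegistry

    class MyTinyForCausalLM(nn.Module):
        """Placeholder out-of-tree architecture (illustrative only)."""

    # Registered classes take priority in load_model_cls(), so this can also
    # override a built-in architecture (a warning is logged in that case).
    ModelRegistry.register_model("MyTinyForCausalLM", MyTinyForCausalLM)
    assert ModelRegistry.load_model_cls("MyTinyForCausalLM") is MyTinyForCausalLM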
+ 27 - 29
aphrodite/modeling/models/baichuan.py

@@ -25,9 +25,8 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
 
-from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -37,12 +36,13 @@ from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
 )
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
     ParallelLMHead,
 )
-from aphrodite.modeling.megatron.parallel_state import (
+from aphrodite.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -54,8 +54,6 @@ from aphrodite.modeling.hf_downloader import (
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@@ -211,15 +209,14 @@ class BaiChuanAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -260,8 +257,8 @@ class BaiChuanDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -275,7 +272,7 @@ class BaiChuanDecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
         )
 
         # Fully Connected
@@ -311,8 +308,8 @@ class BaiChuanModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -322,7 +319,7 @@ class BaiChuanModel(nn.Module):
                 positions,
                 hidden_states,
                 kv_caches[i],
-                input_metadata,
+                attn_metadata,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -344,32 +341,33 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       linear_method=linear_method)
-        self.sampler = Sampler(config.vocab_size)
-        self.quant_sampler = QuantSampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size,
+                                                config.tokenizer_vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+                                   attn_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
-                                             sampling_metadata)
-        else:
-            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                       sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(

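The Baichuan changes follow the pattern applied to every model in this release: forward returns hidden states, the new compute_logits runs the LogitsProcessor against the LM head, and sample consumes only the resulting logits. A schematic of the per-step call order under this interface (hypothetical helper name; not the actual model-runner code):

    def run_decode_step(model, input_ids, positions, kv_caches, attn_metadata,
                        sampling_metadata):
        # 1) Model forward now returns hidden states only.
        hidden_states = model(input_ids, positions, kv_caches, attn_metadata)
        # 2) Logits come from the LogitsProcessor-backed compute_logits().
        logits = model.compute_logits(hidden_states, sampling_metadata)
        # 3) The shared Sampler consumes logits directly.
        return model.sample(logits=logits, sampling_metadata=sampling_metadata)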
+ 30 - 32
aphrodite/modeling/models/bloom.py

@@ -18,31 +18,29 @@
 # limitations under the License.
 """Inference-only BLOOM model compatible with HuggingFace weights."""
 import math
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 from torch import nn
 from transformers import BloomConfig
 
-from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
-from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                               hf_model_weights_iterator)
 from aphrodite.common.sequence import SamplerOutput
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@@ -117,14 +115,13 @@ class BloomAttention(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         del position_ids  # Unused.
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.dense(attn_output)
         return output
 
@@ -181,8 +178,8 @@ class BloomBlock(nn.Module):
         self,
         position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
@@ -198,7 +195,7 @@ class BloomBlock(nn.Module):
             position_ids=position_ids,
             hidden_states=layernorm_output,
             kv_cache=kv_cache,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
         )
         attention_output = attention_output + residual
         layernorm_output = self.post_attention_layernorm(attention_output)
@@ -243,8 +240,8 @@ class BloomModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.word_embeddings(input_ids)
         hidden_states = self.word_embeddings_layernorm(hidden_states)
@@ -254,7 +251,7 @@ class BloomModel(nn.Module):
                 position_ids,
                 hidden_states,
                 kv_caches[i],
-                input_metadata,
+                attn_metadata,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -271,36 +268,37 @@ class BloomForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.transformer = BloomModel(config, linear_method)
-        # self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.lm_head_weight = self.transformer.word_embeddings.weight
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       linear_method=linear_method)
-        self.sampler = Sampler(config.vocab_size)
-        self.quant_sampler = QuantSampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size,
+                                                config.tokenizer_vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata)
+                                         attn_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
-                                             sampling_metadata)
-        else:
-            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                       sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

+ 47 - 31
aphrodite/modeling/models/chatglm.py

@@ -2,34 +2,32 @@
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 from torch import nn
 from torch.nn import LayerNorm
 
-from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.distributed import (get_tensor_model_parallel_world_size)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                               hf_model_weights_iterator)
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.transformers_utils.configs import ChatGLMConfig
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
 
 class GLMAttention(nn.Module):
 
@@ -98,20 +96,18 @@ class GLMAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
-        key_cache, value_cache = kv_cache
         context_layer = self.attn(
             q,
             k,
             v,
-            key_cache,
-            value_cache,
-            input_metadata,
+            kv_cache,
+            attn_metadata,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
@@ -199,8 +195,8 @@ class GLMBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # hidden_states: [num_tokens, h]
         # Layer norm at the beginning of the transformer layer.
@@ -210,7 +206,7 @@ class GLMBlock(nn.Module):
             hidden_states=layernorm_output,
             position_ids=position_ids,
             kv_cache=kv_cache,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
         )
 
         # Residual connection.
@@ -263,8 +259,8 @@ class GLMTransformer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         for i in range(self.num_layers):
             layer = self.layers[i]
@@ -272,7 +268,7 @@ class GLMTransformer(nn.Module):
                 hidden_states=hidden_states,
                 position_ids=position_ids,
                 kv_cache=kv_caches[i],
-                input_metadata=input_metadata,
+                attn_metadata=attn_metadata,
             )
         # Final layer norm.
         if self.post_layer_norm:
@@ -307,8 +303,8 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         inputs_embeds = self.embedding(input_ids)
 
@@ -317,43 +313,63 @@ class ChatGLMModel(nn.Module):
             hidden_states=inputs_embeds,
             position_ids=position_ids,
             kv_caches=kv_caches,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
         )
         return hidden_states
 
 
 class ChatGLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config: ChatGLMConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.linear_method = linear_method
         self.transformer = ChatGLMModel(config, linear_method)
-        # self.lm_head_weight = self.transformer.output_layer.weight
-        self.sampler = Sampler(config.padded_vocab_size)
+        self.logits_processor = LogitsProcessor(config.padded_vocab_size,
+                                                config.tokenizer_vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata)
+                                         attn_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.transformer.output_layer,
+                                       hidden_states, sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(
-            self.transformer.output_layer(hidden_states), sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(self,

+ 181 - 71
aphrodite/modeling/models/cohere.py

@@ -25,35 +25,49 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
-from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import Attention as Attention
 from aphrodite.modeling.layers.linear import (
+    ColumnParallelLinear,
     LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size, )
+    ParallelLMHead,
+    ParallelTWEHead,
+    VocabParallelEmbedding,
+)
+from aphrodite.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.modeling.hf_downloader import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 from aphrodite.common.sequence import SamplerOutput
 
-# limitations under the License.
-""" Cohere model configuration"""
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
+@torch.compile
+def layer_norm_func(hidden_states, weight, variance_epsilon):
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    mean = hidden_states.mean(-1, keepdim=True)
+    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
+                                                         variance_epsilon)
+    hidden_states = weight.to(torch.float32) * hidden_states
+    return hidden_states.to(input_dtype)
 
 
 class LayerNorm(nn.Module):
@@ -61,23 +75,25 @@ class LayerNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-5, bias=False):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
         self.variance_epsilon = eps
+        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
     def forward(self, hidden_states, residuals=None):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        mean = hidden_states.mean(-1, keepdim=True)
-        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-        hidden_states = (hidden_states -
-                         mean) * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
-        return hidden_states.to(input_dtype), residuals
-
-
-ALL_LAYERNORM_LAYERS.append(LayerNorm)
+        hidden_states = layer_norm_func(hidden_states, self.weight,
+                                        self.variance_epsilon)
+        return hidden_states, residuals
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_dim = 0 if param.dim() != 1 else None
+        param_data = param.data
+        if shard_dim is not None:
+            shard_size = param_data.shape[shard_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
+                                                 shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
@@ -92,12 +108,29 @@ class CohereMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = MergedColumnParallelLinear(
-            self.hidden_size,
-            [self.intermediate_size] * 2,
-            bias=False,
-            linear_method=linear_method,
-        )
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
+            self.merge_weight = False
+            self.gate_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.up_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.intermediate_size,
+                bias=False,
+                linear_method=linear_method,
+            )
+        else:
+            self.merge_weight = True
+            self.gate_up_proj = MergedColumnParallelLinear(
+                self.hidden_size,
+                [self.intermediate_size] * 2,
+                bias=False,
+                linear_method=linear_method,
+            )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
@@ -107,7 +140,12 @@ class CohereMLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
+        if self.merge_weight:
+            gate_up, _ = self.gate_up_proj(x)
+        else:
+            up, _ = self.up_proj(x)
+            gate, _ = self.gate_proj(x)
+            gate_up = torch.cat([gate, up], dim=-1)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -141,18 +179,43 @@ class CohereAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.max_position_embeddings = config.max_position_embeddings
+        self.max_position_embeddings = getattr(
+            config, "model_max_length", None) or getattr(
+                config, "max_position_embeddings", 8192)
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
-        self.is_causal = True
-        self.qkv_proj = QKVParallelLinear(
-            self.hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=False,
-            linear_method=linear_method,
-        )
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
+        if (linear_method is not None
+                and not linear_method.quant_config.merge_weight()):
+            self.merge_weight = False
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.total_num_heads * self.head_dim,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.k_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.total_num_kv_heads * self.head_dim,
+                bias=False,
+                linear_method=linear_method,
+            )
+            self.v_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.total_num_kv_heads * self.head_dim,
+                bias=False,
+                linear_method=linear_method,
+            )
+        else:
+            self.merge_weight = True
+            self.qkv_proj = QKVParallelLinear(
+                self.hidden_size,
+                self.head_dim,
+                self.total_num_heads,
+                self.total_num_kv_heads,
+                bias=False,
+                linear_method=linear_method,
+            )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
@@ -173,28 +236,53 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
         )
+        if self.use_qk_norm:
+            self.q_norm = LayerNorm(hidden_size=(self.num_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+            self.k_norm = LayerNorm(hidden_size=(self.num_kv_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+
+    def _apply_qk_norm(self, q, k):
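+        # Reshape the flat q/k projections to (..., num_heads, head_dim) so
+        # each head is normalized independently, then flatten back for RoPE.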
+        q = q.view(*q.shape[:-1], -1, self.head_dim)
+        k = k.view(*k.shape[:-1], -1, self.head_dim)
+        q, _ = self.q_norm(q)
+        k, _ = self.k_norm(k)
+        q = q.view(*q.shape[:-2], -1)
+        k = k.view(*k.shape[:-2], -1)
+        return q, k
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.merge_weight:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        else:
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
+        if self.use_qk_norm:
+            q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 
 
 class CohereDecoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: CohereConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: CohereConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -208,8 +296,8 @@ class CohereDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -219,7 +307,7 @@ class CohereDecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
         )
         hidden_states_mlp = self.mlp(hidden_states)
         # Add everything together
@@ -229,12 +317,6 @@ class CohereDecoderLayer(nn.Module):
 
 
 class CohereModel(nn.Module):
-    """
-    Transformer decoder consisting of *config.num_hidden_layers* layers.
-    Each layer is a [`CohereDecoderLayer`]
-    Args:
-        config: CohereConfig
-    """
 
     def __init__(
         self,
@@ -245,7 +327,8 @@ class CohereModel(nn.Module):
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size)
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
         self.layers = nn.ModuleList([
             CohereDecoderLayer(config, linear_method=linear_method)
             for _ in range(config.num_hidden_layers)
@@ -256,8 +339,8 @@ class CohereModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -267,7 +350,7 @@ class CohereModel(nn.Module):
                 positions,
                 hidden_states,
                 kv_caches[i],
-                input_metadata,
+                attn_metadata,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -283,30 +366,43 @@ class CohereForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.unpadded_vocab_size = config.vocab_size
         self.linear_method = linear_method
         self.model = CohereModel(config, linear_method)
-        self.sampler = Sampler(config.vocab_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = ParallelTWEHead(self.model.embed_tokens)
+        else:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          linear_method=linear_method)
+        self.logits_processor = LogitsProcessor(config.vocab_size,
+                                                config.tokenizer_vocab_size,
+                                                scale=config.logit_scale)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+                                   attn_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.model.embed_tokens.weight,
-                                   hidden_states, sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
     def load_weights(
@@ -324,19 +420,33 @@ class CohereForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if (self.linear_method is not None
+                and not self.linear_method.quant_config.merge_weight()):
+            stacked_params_mapping = []
         params_dict = dict(self.named_parameters())
         loaded_params = set()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # lm_head is not used as it is tied with embed_tokens.
+                # To prevent errors, skip loading lm_head.
+                if self.config.tie_word_embeddings and "lm_head" in name:
+                    continue
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
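
The Cohere changes above follow the logits-processor refactor applied across this release: forward() now returns hidden states, compute_logits() projects them through the (possibly tied) LM head, and sample() consumes logits instead of hidden states. A minimal sketch of the resulting call order, illustrative only; the surrounding engine plumbing (metadata construction, batching) is assumed rather than taken from this diff:

    # Sketch of the refactored model contract (not part of this diff).
    hidden_states = model(input_ids, positions, kv_caches, attn_metadata)
    logits = model.compute_logits(hidden_states, sampling_metadata)
    sampler_output = model.sample(logits, sampling_metadata)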

aphrodite/modeling/models/dbrx.py (+423, -0)

@@ -0,0 +1,423 @@
+# coding=utf-8
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.modeling.layers.fused_moe import fused_moe
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.transformers_utils.configs.dbrx import DbrxConfig
+
+
+class DbrxRouter(nn.Module):
+    """A Router implementation for DBRX that returns logits for each expert
+    per token.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.ffn_config.moe_num_experts
+        self.d_model = config.d_model
+        self.layer = ReplicatedLinear(
+            self.d_model,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=params_dtype,
+            linear_method=None,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        router_logits, _ = self.layer(hidden_states)
+        return router_logits
+
+
+class DbrxExperts(nn.Module):
+    """A tensor-parallel MoE implementation for DBRX.
+
+    Each expert's weights are sharded across all ranks, a fused MoE
+    kernel is used for the forward pass, and the outputs are reduced
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.ffn_config.moe_num_experts
+        self.top_k = config.ffn_config.moe_top_k
+        self.d_model = config.d_model
+        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
+                                  self.tp_size)
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = DbrxRouter(config, self.params_dtype)
+        self.ws = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.d_model,
+                device="cuda",
+                dtype=self.params_dtype,
+            ))
+        self.w2s = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                self.d_model,
+                self.intermediate_size,
+                device="cuda",
+                dtype=self.params_dtype,
+            ))
+
+        set_weight_attrs(
+            self.ws,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2s,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        # DBRX uses a GLU for each expert.
+        # GLU has 3 linear layers: w1, v1 and w2.
+        if weight_name.endswith("w1"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            )
+            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
+        if weight_name.endswith("v1"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            )
+            param_data[:,
+                       shard_size:2 * shard_size, :] = loaded_weight[:,
+                                                                     shard, :]
+        if weight_name.endswith("w2"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            ).transpose(1, 2)
+            param_data[:] = loaded_weight[:, :, shard]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.d_model)
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.router(hidden_states)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.ws,
+            self.w2s,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_size)
+
+
+class DbrxAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
+        self.total_num_kv_heads = config.attn_config.kv_n_heads
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.rope_theta = config.attn_config.rope_theta
+        self.max_position = config.max_seq_len
+
+        # pylint: disable=invalid-name
+        self.Wqkv = QKVParallelLinear(
+            self.d_model,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.out_proj = RowParallelLinear(
+            self.d_model,
+            self.d_model,
+            bias=False,
+            linear_method=linear_method,
+        )
+        is_neox_style = (True if linear_method is None
+                         or linear_method.quant_config.rope_style() is None
+                         else linear_method.quant_config.rope_style())
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position,
+            base=int(self.rope_theta),
+            is_neox_style=is_neox_style,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_world_size
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than or equal to TP size, so we
+            # partition the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+        )
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.Wqkv(hidden_states)
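+        # DBRX optionally clamps the fused QKV activations to
+        # [-clip_qkv, clip_qkv] before splitting them into q, k, v.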
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        hidden_states, _ = self.out_proj(attn_output)
+        return hidden_states
+
+
+class DbrxFusedNormAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.attn = DbrxAttention(config, linear_method)
+        self.norm_1 = nn.LayerNorm(self.d_model)
+        self.norm_2 = nn.LayerNorm(self.d_model)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.norm_1(hidden_states)
+        x = self.attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = residual + x
+        residual = hidden_states
+        hidden_states = self.norm_2(hidden_states)
+        return hidden_states, residual
+
+
+class DbrxBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
+        self.ffn = DbrxExperts(config, linear_method)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states, residual = self.norm_attn_norm(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = self.ffn(hidden_states)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class DbrxModel(nn.Module):
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          config.d_model,
+                                          linear_method=linear_method)
+        self.blocks = nn.ModuleList(
+            [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
+        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
+        for module in self.modules():
+            if hasattr(module, "bias") and isinstance(module.bias,
+                                                      nn.Parameter):
+                # Remove the bias term in Linear and LayerNorm.
+                module.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        for i in range(len(self.blocks)):
+            block = self.blocks[i]
+            hidden_states = block(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                attn_metadata,
+            )
+        hidden_states = self.norm_f(hidden_states)
+        return hidden_states
+
+
+class DbrxForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.unpadded_vocab_size = config.vocab_size
+        self.transformer = DbrxModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.d_model,
+                                      org_num_embeddings=config.vocab_size,
+                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                                      linear_method=linear_method)
+        self.logits_processor = LogitsProcessor(
+            self.unpadded_vocab_size,
+            min(config.vocab_size, config.tokenizer_vocab_size))
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        expert_params_mapping = [(
+            "ws" if weight_name in ["w1", "v1"] else "w2s",
+            f"experts.mlp.{weight_name}",
+        ) for weight_name in ["w1", "v1", "w2"]]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision,
+                self.config):
+            for param_name, weight_name in expert_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, weight_name)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
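
DbrxExperts keeps one fused weight tensor per projection (ws for w1/v1, w2s for w2) and shards every expert along the FFN dimension, so each rank loads only its slice of each expert. A rough sketch of that slicing, mirroring the weight_loader above; the TP settings and shapes here are hypothetical:

    import torch

    tp_size, tp_rank = 4, 1                     # assumed tensor-parallel setup
    n_experts, d_model, ffn_hidden = 16, 1024, 4096
    shard_size = ffn_hidden // tp_size          # per-rank intermediate size

    w1_full = torch.empty(n_experts, ffn_hidden, d_model)  # checkpoint layout
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    w1_local = w1_full[:, shard, :]             # this rank's slice of every expert
    assert w1_local.shape == (n_experts, shard_size, d_model)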

aphrodite/modeling/models/deepseek.py (+216, -149)

@@ -22,46 +22,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Deepseek model."""
+from typing import Any, Dict, List, Optional
 
-from typing import Any, Dict, List, Optional, Tuple
-
+import numpy as np
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import Attention
-from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
+from aphrodite.modeling.layers.fused_moe import fused_topk
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
-    LinearMethodBase,
-    MergedColumnParallelLinear,
-    ReplicatedLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
+    LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear,
+    ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
+from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    ParallelLMHead,
-)
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_reduce, )
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_all_reduce)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from aphrodite.modeling.hf_downloader import (default_weight_loader,
+                                              hf_model_weights_iterator)
 from aphrodite.common.sequence import SamplerOutput
 
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
 
 class DeepseekMLP(nn.Module):
 
@@ -75,18 +62,14 @@ class DeepseekMLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
-            [intermediate_size] * 2,
-            bias=False,
-            linear_method=linear_method,
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
+            hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
-            reduce_results=reduce_results,
-        )
+            linear_method=linear_method)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           linear_method=linear_method,
+                                           reduce_results=reduce_results)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -99,6 +82,39 @@ class DeepseekMLP(nn.Module):
         return x
 
 
+class DeepseekExpertMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_proj = ReplicatedLinear(hidden_size,
+                                          intermediate_size,
+                                          bias=False,
+                                          linear_method=linear_method)
+        self.up_proj = ReplicatedLinear(hidden_size,
+                                        intermediate_size,
+                                        bias=False,
+                                        linear_method=linear_method)
+        self.down_proj = ReplicatedLinear(intermediate_size,
+                                          hidden_size,
+                                          bias=False,
+                                          linear_method=linear_method)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states):
+        gate_out, _ = self.gate_proj(hidden_states)
+        gate_out = self.act_fn(gate_out)
+        up_out, _ = self.up_proj(hidden_states)
+        current_hidden_states = gate_out * up_out
+        current_hidden_states, _ = self.down_proj(current_hidden_states)
+        return current_hidden_states
+
+
 class DeepseekMoE(nn.Module):
 
     def __init__(
@@ -112,28 +128,49 @@ class DeepseekMoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.n_routed_experts = config.n_routed_experts
         self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
-
-        self.experts = nn.ModuleList([
-            DeepseekMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.moe_intermediate_size,
-                hidden_act=config.hidden_act,
+        self.linear_method = linear_method
+        if self.linear_method is None:
+            self.linear_method = UnquantizedLinearMethod()
+
+        if not isinstance(
+                self.linear_method, UnquantizedLinearMethod
+        ) and not self.linear_method.quant_config.support_fused_moe():
+            if self.tp_size > self.n_routed_experts:
+                raise ValueError(
+                    f"Tensor parallel size {self.tp_size} is greater than "
+                    f"the number of experts {self.n_routed_experts}.")
+            # Split experts equally between ranks
+            self.expert_indicies = np.array_split(range(
+                self.n_routed_experts), self.tp_size)[self.rank].tolist()
+            if not self.expert_indicies:
+                raise ValueError(
+                    f"Rank {self.rank} has no experts assigned to it.")
+
+            self.experts = nn.ModuleList([
+                DeepseekExpertMLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.moe_intermediate_size,
+                    hidden_act=config.hidden_act,
+                    linear_method=linear_method,
+                ) if idx in self.expert_indicies else None
+                for idx in range(self.n_routed_experts)
+            ])
+        else:
+            self.w1 = MergedColumnParallelLinear(
+                config.hidden_size, [config.moe_intermediate_size] * 2,
+                bias=False,
                 linear_method=linear_method,
-                reduce_results=False,
-            ) for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            self.n_routed_experts,
-            bias=False,
-            linear_method=None,
-        )
+                num_experts=self.n_routed_experts)
+            self.w2 = RowParallelLinear(config.moe_intermediate_size,
+                                        config.hidden_size,
+                                        bias=False,
+                                        linear_method=linear_method,
+                                        num_experts=self.n_routed_experts)
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     self.n_routed_experts,
+                                     bias=False,
+                                     linear_method=None)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -146,49 +183,50 @@ class DeepseekMoE(nn.Module):
                 reduce_results=False,
             )
 
-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         if self.config.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.w1,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=self.config.norm_topk_prob,
-            inplace=True,
-        )
+
+        if not isinstance(
+                self.linear_method, UnquantizedLinearMethod
+        ) and not self.linear_method.quant_config.support_fused_moe():
+            routing_weights, selected_experts = fused_topk(
+                router_logits,
+                self.top_k,
+                renormalize=self.config.norm_topk_prob)
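+            # Only the experts assigned to this rank are evaluated; the
+            # partial sums are combined by the all-reduce further below.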
+            final_hidden_states = None
+            for expert_idx in self.expert_indicies:
+                expert_layer = self.experts[expert_idx]
+                expert_mask = (selected_experts == expert_idx)
+                expert_weights = (routing_weights * expert_mask).sum(
+                    dim=-1, keepdim=True)
+
+                current_hidden_states = expert_layer(hidden_states).mul_(
+                    expert_weights)
+                if final_hidden_states is None:
+                    final_hidden_states = current_hidden_states
+                else:
+                    final_hidden_states.add_(current_hidden_states)
+        else:
+            final_hidden_states = self.linear_method.apply_moe_weights(
+                self.w1.linear_weights,
+                self.w2.linear_weights,
+                hidden_states,
+                router_logits,
+                self.top_k,
+                renormalize=self.config.norm_topk_prob,
+            )
 
         if self.config.n_shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
         final_hidden_states = tensor_model_parallel_all_reduce(
             final_hidden_states)
 
-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_dim)
+        return final_hidden_states.view(num_tokens, hidden_dim)
 
 
 class DeepseekAttention(nn.Module):
@@ -249,25 +287,22 @@ class DeepseekAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -315,8 +350,8 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -330,7 +365,7 @@ class DeepseekDecoderLayer(nn.Module):
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
-            input_metadata=input_metadata,
+            attn_metadata=attn_metadata,
         )
 
         # Fully Connected
@@ -351,10 +386,9 @@ class DeepseekModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size,
+                                                   linear_method=linear_method)
         self.layers = nn.ModuleList([
             DeepseekDecoderLayer(config,
                                  layer_idx,
@@ -367,15 +401,15 @@ class DeepseekModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], input_metadata,
+                                            kv_caches[i], attn_metadata,
                                             residual)
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -393,61 +427,73 @@ class DeepseekForCausalLM(nn.Module):
         self.linear_method = linear_method
         self.model = DeepseekModel(config, linear_method)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
-        self.quant_sampler = QuantSampler(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size,
+                                                config.tokenizer_vocab_size)
+        self.sampler = Sampler()
 
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+                                   attn_metadata)
         return hidden_states
 
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
     def sample(
         self,
-        hidden_states: Optional[torch.Tensor],
+        logits: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        if (self.linear_method is not None
-                and not self.linear_method.quant_config.merge_weight()):
-            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
-                                             sampling_metadata)
-        else:
-            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                       sampling_metadata)
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            ("mlp.gate_up_proj", "mlp.gate_proj", 0),
+            ("mlp.gate_up_proj", "mlp.up_proj", 1),
+            ("shared_experts.gate_up_proj", "shared_experts.gate_proj", 0),
+            ("shared_experts.gate_up_proj", "shared_experts.up_proj", 1),
         ]
 
+        expert_params_mapping = [
+            # (param_name, weight_name, shard_id, expert_id)
+            ("w1" if weight_name in ["gate_proj", "up_proj"] else "w2",
+             f"experts.{expert_id}.{weight_name}", shard_id, expert_id)
+            for expert_id in range(self.config.n_routed_experts)
+            for weight_name, shard_id in [("gate_proj",
+                                           0), ("up_proj",
+                                                1), ("down_proj", None)]
+        ] if self.linear_method is None or (
+            self.linear_method.quant_config.support_fused_moe()) else []
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
                 cache_dir,
                 load_format,
                 revision,
-                fall_back_to_pt=False,
-        ):
+                self.config,
+                fall_back_to_pt=False):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -455,22 +501,43 @@ class DeepseekForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 # Skip experts that are not assigned to this worker.
-                if ("mlp.experts." in name or "mlp.shared_experts."
-                        in name) and name not in params_dict:
+                if (("mlp.experts." in name or "mlp.shared_experts." in name)
+                        and name not in params_dict):
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if ("mlp.experts." in name or "mlp.shared_experts."
-                        in name) and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for (param_name, weight_name, shard_id,
+                     expert_id) in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    if shard_id is None:
+                        weight_loader(param,
+                                      loaded_weight,
+                                      expert_id=expert_id)
+                    else:
+                        weight_loader(param,
+                                      loaded_weight,
+                                      shard_id,
+                                      expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Skip experts that are not assigned to this worker.
+                    if (("mlp.experts." in name
+                         or "mlp.shared_experts." in name)
+                            and name not in params_dict):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
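
When the quantization method cannot fuse MoE weights, DeepseekMoE falls back to a per-expert loop over the experts owned by the current rank, weighting each expert's output by its routing weight and letting the later all-reduce combine the per-rank partial sums. A simplified, standalone version of that routing; the argument names are hypothetical (topk_weights/topk_ids stand in for fused_topk's outputs, and experts holds None for experts owned by other ranks):

    import torch

    def route_unfused(hidden_states, topk_weights, topk_ids, experts, my_expert_ids):
        out = torch.zeros_like(hidden_states)
        for expert_id in my_expert_ids:
            mask = (topk_ids == expert_id)                    # (num_tokens, top_k)
            weight = (topk_weights * mask).sum(dim=-1, keepdim=True)
            # weight is 0 for tokens not routed to this expert
            out += experts[expert_id](hidden_states) * weight
        return out  # partial sum; combined across ranks by an all-reduce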
