
[0.6.0] Release Candidate (#481)

* chore: skip the driver worker

* chore: bump lmfe version to 0.10.3

* chore: some more marlin cleanups

* chore: deprecation warning for beam search

* feat: support FP8 for DeepSeekV2 MoE

* feat: add fuyu vision model and persimmon language model support

* fix: turn off cutlass scaled_mm for ada lovelace cards

* chore: allow quantizing all layers of deepseek-v2

* fix: build with py-limited API in the Dockerfile

* OpenAI API Refactor (#591)

* feat: massive api server refactoring

* fix: tokenizer endpoint issues

* fix: BatchResponseData body should be optional

* chore: simplify pipeline parallel code in llama

* fix: convert image to RGB by default

* fix: allow getting the chat template from a url

* chore: avoid loading the unused layers and init the VLM up to the required feature space

* chore: enable bias w/ FP8 layers in CUTLASS kernels

* chore: upgrade flashinfer to 0.0.9

* feat: add custom triton cache manager

* chore: add CustomOp interface to UnquantizedFusedMoEMethod

* chore: handle aborted requests for jamba

* fix: minor fix for prompt adapter config

* feat: chat completions tokenization endpoint (#592)

* feat: optimize throughput by 1.4x using numpy for token padding

* feat: MoE support with Pallas GMM kernel for TPUs

* chore: log spec decoding metrics

* chore: separate kv_scale into k_scale and v_scale

* feat: Asymmetric Tensor Parallel (#594)

* add utils for getting the partition offset and size for current tp rank

* disable asymmetric TP for quants and lora, handle GQA allocation

* the actual splitting work in the linear layers

* padding size for the vocab/lm_head should be optional

* cache engine and spec decode model runner (kwargs only)

* pass the tp_rank to model runners

* llama support

* update dockerfile

* let's not build these for now

* Revert "update dockerfile"

This reverts commit 6dd64089a2da71a17f4c3b79d03855adf92dc84f.

* fix: install wheel and packaging in docker

* fix: admin key arg

* let's try this again

* fix: mamba-ssm installation stuff

* chore: shutdown method for multiproc executor

* chore: log the message queue comms handle

* Port mamba kernels to Aphrodite (#595)

* kernels

* fix interface

* clean up dockerfile

* chore: set seed for dummy weights init

* fix: only create embeddings and lm_head when necessary for PP

* cleanup rocm dockerfile

* fix: some minor typing issues in spec decode

* fix: 4-node crash with PP

* chore: remove multimodal stuff from TPU

* fix: type annotation in worker

* refactor _prepare_model_input_tensor and attn metadata builder for most backends

* move prepare_inputs to the GPU (#596)

* add kernels

Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>

* sampler changes

* refactor the spec decode model runner and worker

---------

Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>

* update all benchmarks (#597)

* feat: add fp8 dynamic per-token quant kernel

* feat: pipeline parallel support for mixtral

* feat: add fp8 channel-wise weight quantization support

* feat: add asymmetric TP support for Qwen2

* feat: add CPU offloading support (#598)

* fix: avoid secondary error in ShmRingBuffer destructor

* feat: add SPMD worker execution using Ray accelerated DAG

* `enable_gpu_advance_step` -> `allow_gpu_advance_step`

* chore: use the LoRA tokenizer in OpenAI API (#599)

* fix: use paged attention for block swapping/copying in flashinfer

* chore: refactor TPU model runner and worker

* chore: improve min_capability checking for `compressed-tensors`

* chore: implement fallback for fp8 channelwise using torch._scaled_mm

* fix: allow using mp executor for pipeline parallel

* fix: make speculative decoding work with per-request seed

* feat: non-uniform quantization via `compressed-tensors` for llama

* fix: the metrics endpoint was not mounted

* fix: raise an error for no draft token case when draft_tp>1

* chore: pass bias to quant_method.apply

* some small performance improvements

* update time since last collection for AsyncMetricsCollector

* fix shared memory bug w/ multi-node

* chore: enable dynamic per-token `fp8`

* docker: install libibverbs by default

* add scale_ub inputs to fp8 dynamic per-token quant

* chore: allow specifying custom Executor

* fix: request abort crashing pipeline parallel

* feat: fbgemm quantization support (#601)

* feat: fbgemm support

* missed this one

* register the quant

* chore: minor AMD fixes

* feat: support fbgemm_fp8 quant on ampere

* fix: input_scale for w8a8 is optional

* fix: channel-wise fp8 marlin

* move `aphrodite.endpoints.openai.chat_utils` -> `aphrodite.endpoints.chat_utils`

* feat: disable logprob serialization to CPU for spec decode

* chore: refactor and decouple phi3v image embedding

* fix: asymmetric TP changes breaking the gptq and awq quants (#602)

* feat: AWQ marlin kernels (#603)

* refactor gptq_marlin kernels to add awq support

* integrate

* fix: short commit hash import error

* feat: initial text-to-text support for Chameleon model

* clean up requirements

* chore: add a wrapper for torch.inference_mode decorator

* fix: `vocab_size` field access in llava

* fix: f-string fixes

* docs: add doc site with example content

* docs: add installation guides

* chore: add contribution guidelines + Code of Conduct (#507)

* add utils; proper project structuring

* add contributing guidelines and CoC

* remove docker

* chore: add contribution guidelines + Code of Conduct (#507)

* add utils; proper project structuring

* add contributing guidelines and CoC

* Refactor prompt processing (#605)

* wip

* finish up the refactor

* fix: use int64_t for indices in fp8 kernels

* feat: support loading lora adapters directly from HF

* chore: modularize prepare input and attn metadata builder

* fix: fbgemm_fp8 when modules_to_not_convert=None

* chore: automatically enable chunked prefill if model has large seqlen

* chore: move some verbose logs to debug

* chore: further improve logging

* feat: support FP8 KV Cache scales from compressed-tensors

* feat: script for multi-node cluster setup

* chore: manage all http connections in one place

* feat: allow image inputs with chameleon models

* fix: support ignore patterns in model loader

* fix: beta value in gelu_tanh kernel being divided by 0.5

* fix: some naming issues

* chore: add ignored layers for fp8 quant

* chore: add usage data in each chunk for serving_chat

* bump transformers

* fix: cache spec decode metrics when they get collected

* feat: support loading pre-quanted bnb checkpoints

* fix: flashinfer cuda graph capture with pipeline parallel

* fix: token padding for chameleon

* fix: `ignore_patterns` -> `ignore_file_pattern` for modelscope

* fix: miscalculated latency leading to ttft inaccuracy

* chore: split `run_server` into `build_server` and `run_server`

* chore: add fp8 support to `reshape_and_cache_flash`

* chore: tweaks to model/runner builder developer APIs

* chore: bump transformers

* fix: zmq hangs with large requests

* chore: represent tokens with identifiable strings

* feat: add support for MiniCPM-V

* fix: decode tokens w/ CUDA graphs and graphs with flashinfer

* fix: allow passing -q {gptq,awq}_marlin as the arg

* fix: encoding format for embedding example

* fix: add image placeholder for openai server for minicpmv

* feat: `fp8-marlin` channel-wise quant via `compressed-tensors`

* fix: `kv_cache_dtype=fp8` without scales for fp8 checkpoints

* fix: prevent possible data race by adding sync

* fix: nullptr channelwise scales when loading wNa16 models

* fix: define self.forward_dag before init_workers_ray

* fix: replicatedlinear weight loading

* fix: pass signal from the main thread

* chore: use array to speedup padding

* feat: add nemotron HF support (#606)

* fix: promote another index in fp8 kernel to int64_t

* chore: minor simplifications for Dockerfile.rocm

* chore: allow initializing TPU in initialize_ray_cluster

* feat: tensor parallelism for CPU backend

* chore: simplify squared relu

* fix: disable enforce_eager for bnb

* feat: support collective comms in XLA devices, e.g. TPUs

* chore: factor out the code for running uvicorn

* fix: illegal mem access for fp8 l3.1 405b

* fix: torch nightly version for rocm dockerfile

* fix: do not enable chunked prefill and prefix caching for jamba

* feat: add support for head_size of 120

* feat: tensor parallelism for TPU with ray

* fix: torch.set_num_threads() in multiproc_gpu_executor

* chore: consolidate all vision examples into one file

* feat: add blip-2 support

* chore: reduce XLA compile times

* fix: better logging for memory profiling

* chore: perform allreduce in fp32 for marlin, better logging

* fix: add nemotron to PP_SUPPORTED_MODELS

* fix: pass cutlass_fp8_supported correctly for fbgemm_fp8

* feat: add internvl support

* chore: tune fp8 kernels for ada lovelace cards

* fix: reduce unnecessary compute when logprobs=None

* fix: deprecation warnings in squeezellm quant_cuda_kernel

* chore: enable tpu tensor parallel in async engine

* fix: remove timm as a hardcoded requirement

* chore: make triton fully optional

* feat: add allowed_token_ids

* fix: unused variables in awq gemm kernel

* fix: divide-by-zero warnings in marlin kernels

* chore: tune int8 kernels for ada lovelace

* fix: greedy decoding in TPU

* fix: paligemma mmp

* fix: seeded gens with pipeline parallel

* fix: compiler warnings for _C and _moe

* chore: bump openvino toolkit to pre-release

* fix: remove scaled_fp8_quant_kernel padding footgun

* fix: massively improve throughput with a high number of prompts

* fix: remove artifact

* chore: add punica sizes for mistral nemo

* fix: conditionally import outlines.caching

* fix: wrap all outlines imports

* chore: sort args (#608)

* chore: sort args

* Update aphrodite/modeling/guided_decoding/outlines_logits_processors.py

Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>

* Update aphrodite/engine/args_tools.py

Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>

* Update aphrodite/engine/args_tools.py

Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>

* fix: Device Options category

* fix: Load Options category with load_format, dtype and ignore_patterns

* fix: API Options -> Inference Options with seed, served_model_name moved to model options

* fix: categorize based off of config.py and wrong category names

* fix: move `model` arg to the top

* fix: add missing `max_seq_lens_to_capture` arg

* chore: sort the device arg

* chore: move `model` arg to the top of the argparser

* chore: remove old comment

---------

Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: AlpinDale <alpindale@gmail.com>

* fix: formatting

* feat: add yaml config parsing (#610)

* feat: add yaml config parsing

* fix: prompt adapters

* chore: add isort and refactor formatting script and utils

* fix: broadcasting logic for multi_modal_kwargs

* fix: set readonly=True for non-root TPU devices

* fix: logit processor exceeding vocab size

* feat: support for QQQ W4A8 quantization (#612)

* feat: add qqq marlin kernels

* integrate qqq quant

* fix: cleanup minicpm-v and port na_vit model

* fix: feature size calculation for Llava-next

* feat: use FusedMoE for jamba

* feat: allow loading specific layer numbers per device

* fix: fp8 marlin and cpu offloading with fp8 marlin

* chore: enable fp8 cutlass for ada lovelace

* chore: tune cutlass int8 kernels for sm_75

* feat: Triton Kernels for Punica (#613)

* feat: replace CUDA kernels w/ triton for lora

Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

* fix: __init__ in ops

* fix: replicated linear layer support for LoRA

* cleanup and relax the conditions for vocab_size and rank

* fix: add consolidated* to ignore_patterns

---------

Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

* chore: add pipeline parallel support for Qwen

* fix: don't use torch.generator() for TPU

* fix: skip loading lm_head for tie_word_embeddings models

* chore: optimize PP comm by replacing send with partial send + allgather

* fix: set a default max_tokens for OAI requests

* chore: optimize scheduler and remove policy

* chore: bump torch to 2.4.0

* bump to torch 2.4.0, add aphrodite_flash_attn (#614)

* fix: RMSNorm forward in InternViT attention qk_layernorm

* fix: lower gemma's unloaded_params exception to warning

* feat: support logits soft capping with flash attention backend

* chore: optimize get_seqs

* fix: input shape for flashinfer prefill wrapper

* fix: remove error_on_invalid_device_count_status

* fix: remove unused code in sampler

* build: update torch to 2.4.0 for cpu

* kernels: disambiguate quantized types via a new ScalarType

Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>

* chore: pipeline parallel with Ray accelerated dag

* fix: use loopback address for single node again

* chore: add env var to enable torch.compile

* revert: incorrect nightly build

* feat: add RPC server and client via ZMQ (#615)

* feat: add RPC server and client

Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>

* add async engine protocols

* add new methods to sync and async engine

* producer utils

* migrate serving engine, embedding and tokenization to rpc

* migrate text completions

* migrate chat completions

* migrate logits processors api

* migrate api server

* forgot the arg

* minor naming issues

---------

Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>

* refactor: factor out chat message parsing

* refactor: add has_prefix_cache_hit flag to FlashAttentionMetadataBuilder

* feat: add guided decoding to LLM

* chore: simplify output processing with shortcut for non-parallel sampling and non-beam-search use case (#616)

* refactor: minicpmv and port Idefix2VisionTransformer

* refactor: factor out code for running uvicorn again

* feat: port SiglipVisionModel from transformers

* chore: add proper logging for spec decoding verification

* fix: support flashinfer for draft model runner

* fix: use ipv4 localhost form for zmq bind

* fix: use args.trust_remote_code

* chore: update cutlass to 3.5.1

* fix: specify device when loading lora and embedding tensors

* feat: non-blocking transfer in prepare_input

* feat: re-add GGUF (#600)

* refactor gguf kernels

* fix: incorrect filename for vecdotq header

* finish up the re-impl

* add requirements

* minor CI fixes

* ci: a few more ignores

* ci: remove clang-format

* ci: take one of fixing lint issues

* ci: codespell fixes

* ci: remove yapf

* ci: remove yapf from the formatting script

* ci: remove isort

* chore: minor cleanups

* fix: allow loading GGUF model without .gguf extension

* fix: cpu offloading with gptq

* docs: finalize User & Developer Documentation for Release Candidate (#618)

* add getting started page

* add debugging tips

* add openai docs

* add distributed guide

* add production metrics and model support matrix

* add guide for adding new models

* huge update

* add vlm usage docs

* ci: add action for deploying docs

* docs: fix typos

* chore: refactor wheel build script

* bump version to 0.6.0

---------

Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Co-authored-by: Ahmed <mail@ahme.dev>
Co-authored-by: ewof <elwolf6@protonmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
AlpinDale 6 months ago
parent
commit
f1d0b77c92
100 changed files with 17704 additions and 4571 deletions
  1. .clang-format (+26, -0)
  2. .dockerignore (+3, -0)
  3. .env (+0, -0)
  4. .github/workflows/deploy.yml (+64, -0)
  5. .github/workflows/publish.yml (+5, -3)
  6. .github/workflows/ruff.yml (+4, -4)
  7. .github/workflows/yapf.yml (+0, -31)
  8. .gitignore (+3, -2)
  9. CMakeLists.txt (+94, -219)
  10. Dockerfile (+136, -0)
  11. Dockerfile.cpu (+41, -0)
  12. Dockerfile.neuron (+0, -0)
  13. Dockerfile.openvino (+32, -0)
  14. Dockerfile.ppc64le (+22, -0)
  15. Dockerfile.rocm (+180, -0)
  16. Dockerfile.tpu (+29, -0)
  17. Dockerfile.xpu (+22, -0)
  18. MANIFEST.in (+7, -1)
  19. aphrodite/__init__.py (+14, -6)
  20. aphrodite/_core_ext.py (+176, -0)
  21. aphrodite/_custom_ops.py (+620, -0)
  22. aphrodite/_ipex_ops.py (+241, -0)
  23. aphrodite/adapter_commons/__init__.py (+0, -0)
  24. aphrodite/adapter_commons/layers.py (+14, -0)
  25. aphrodite/adapter_commons/models.py (+102, -0)
  26. aphrodite/adapter_commons/request.py (+25, -0)
  27. aphrodite/adapter_commons/utils.py (+90, -0)
  28. aphrodite/adapter_commons/worker_manager.py (+36, -0)
  29. aphrodite/attention/__init__.py (+4, -6)
  30. aphrodite/attention/backends/abstract.py (+87, -39)
  31. aphrodite/attention/backends/blocksparse_attn.py (+429, -0)
  32. aphrodite/attention/backends/flash_attn.py (+397, -97)
  33. aphrodite/attention/backends/flashinfer.py (+550, -0)
  34. aphrodite/attention/backends/ipex_attn.py (+380, -0)
  35. aphrodite/attention/backends/openvino.py (+100, -0)
  36. aphrodite/attention/backends/pallas.py (+248, -0)
  37. aphrodite/attention/backends/rocm_flash_attn.py (+296, -142)
  38. aphrodite/attention/backends/torch_sdpa.py (+117, -57)
  39. aphrodite/attention/backends/utils.py (+232, -0)
  40. aphrodite/attention/backends/xformers.py (+509, -111)
  41. aphrodite/attention/layer.py (+76, -10)
  42. aphrodite/attention/ops/blocksparse_attention/__init__.py (+0, -0)
  43. aphrodite/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py (+422, -0)
  44. aphrodite/attention/ops/blocksparse_attention/interface.py (+236, -0)
  45. aphrodite/attention/ops/blocksparse_attention/utils.py (+242, -0)
  46. aphrodite/attention/ops/ipex_attn.py (+122, -0)
  47. aphrodite/attention/ops/paged_attn.py (+68, -37)
  48. aphrodite/attention/ops/prefix_prefill.py (+113, -51)
  49. aphrodite/attention/ops/triton_flash_attn.py (+28, -13)
  50. aphrodite/attention/selector.py (+156, -34)
  51. aphrodite/common/block.py (+0, -41)
  52. aphrodite/common/config.py (+487, -271)
  53. aphrodite/common/connections.py (+167, -0)
  54. aphrodite/common/logger.py (+71, -9)
  55. aphrodite/common/outputs.py (+97, -36)
  56. aphrodite/common/pooling_params.py (+19, -0)
  57. aphrodite/common/sampling_params.py (+66, -68)
  58. aphrodite/common/sequence.py (+461, -182)
  59. aphrodite/common/utils.py (+636, -97)
  60. aphrodite/distributed/communication_op.py (+13, -191)
  61. aphrodite/distributed/device_communicators/cuda_wrapper.py (+163, -0)
  62. aphrodite/distributed/device_communicators/custom_all_reduce.py (+187, -170)
  63. aphrodite/distributed/device_communicators/custom_all_reduce_utils.py (+241, -0)
  64. aphrodite/distributed/device_communicators/pynccl.py (+145, -246)
  65. aphrodite/distributed/device_communicators/pynccl_utils.py (+0, -68)
  66. aphrodite/distributed/device_communicators/pynccl_wrapper.py (+275, -0)
  67. aphrodite/distributed/device_communicators/shm_broadcast.py (+478, -0)
  68. aphrodite/distributed/device_communicators/tpu_communicator.py (+31, -0)
  69. aphrodite/distributed/parallel_state.py (+993, -173)
  70. aphrodite/distributed/utils.py (+33, -79)
  71. aphrodite/endpoints/chat_utils.py (+217, -0)
  72. aphrodite/endpoints/cli.py (+190, -10)
  73. aphrodite/endpoints/kobold/klite.embd (+5, -4)
  74. aphrodite/endpoints/llm.py (+487, -79)
  75. aphrodite/endpoints/logger.py (+40, -0)
  76. aphrodite/endpoints/openai/api_server.py (+263, -164)
  77. aphrodite/endpoints/openai/args.py (+82, -42)
  78. aphrodite/endpoints/openai/embeddings.py (+0, -144)
  79. aphrodite/endpoints/openai/logits_processors.py (+83, -0)
  80. aphrodite/endpoints/openai/protocol.py (+510, -216)
  81. aphrodite/endpoints/openai/rpc/__init__.py (+40, -0)
  82. aphrodite/endpoints/openai/rpc/client.py (+236, -0)
  83. aphrodite/endpoints/openai/rpc/server.py (+205, -0)
  84. aphrodite/endpoints/openai/run_batch.py (+159, -0)
  85. aphrodite/endpoints/openai/serving_chat.py (+331, -125)
  86. aphrodite/endpoints/openai/serving_completions.py (+236, -132)
  87. aphrodite/endpoints/openai/serving_embedding.py (+171, -0)
  88. aphrodite/endpoints/openai/serving_engine.py (+328, -174)
  89. aphrodite/endpoints/openai/serving_tokenization.py (+136, -0)
  90. aphrodite/engine/aphrodite_engine.py (+664, -225)
  91. aphrodite/engine/args_tools.py (+585, -326)
  92. aphrodite/engine/async_aphrodite.py (+449, -189)
  93. aphrodite/engine/async_timeout.py (+189, -0)
  94. aphrodite/engine/metrics.py (+476, -178)
  95. aphrodite/engine/output_processor/interfaces.py (+21, -12)
  96. aphrodite/engine/output_processor/multi_step.py (+38, -17)
  97. aphrodite/engine/output_processor/single_step.py (+84, -31)
  98. aphrodite/engine/output_processor/stop_checker.py (+24, -4)
  99. aphrodite/engine/output_processor/util.py (+11, -5)
  100. aphrodite/engine/protocol.py (+83, -0)

+ 26 - 0
.clang-format

@@ -0,0 +1,26 @@
+BasedOnStyle: Google
+UseTab: Never
+IndentWidth: 2
+ColumnLimit: 80
+
+# Force pointers to the type for C++.
+DerivePointerAlignment: false
+PointerAlignment: Left
+
+# Reordering #include statements can (and currently will) introduce errors
+SortIncludes: false
+
+# Style choices
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+IndentPPDirectives: BeforeHash
+
+IncludeCategories:
+  - Regex:           '^<'
+    Priority:        4
+  - Regex:           '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
+    Priority:        3
+  - Regex:           '^"(qoda|\.\.)/'
+    Priority:        2
+  - Regex:           '.*'
+    Priority:        1
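
The new .clang-format targets the C++/CUDA sources under kernels/; for local formatting it might be invoked along the lines of the sketch below (the pathspecs are an assumption about the layout, not part of this commit):

    # Format kernel sources in place using the repository .clang-format.
    git ls-files 'kernels/*.cu' 'kernels/*.cuh' 'kernels/*.cpp' 'kernels/*.h' \
        | xargs clang-format -i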

+ 3 - 0
.dockerignore

@@ -0,0 +1,3 @@
+conda/
+build/
+venv/

+ 0 - 0
docker/.env → .env


+ 64 - 0
.github/workflows/deploy.yml

@@ -0,0 +1,64 @@
+# Sample workflow for building and deploying a VitePress site to GitHub Pages
+#
+name: Deploy VitePress site to Pages
+
+on:
+  # Runs on pushes targeting the `main` branch. Change this to `master` if you're
+  # using the `master` branch as the default branch.
+  push:
+    branches: [main]
+
+  # Allows you to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
+permissions:
+  contents: read
+  pages: write
+  id-token: write
+
+# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued.
+# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete.
+concurrency:
+  group: pages
+  cancel-in-progress: false
+
+jobs:
+  # Build job
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0 # Not needed if lastUpdated is not enabled
+      - uses: pnpm/action-setup@v3 # Uncomment this if you're using pnpm
+      # - uses: oven-sh/setup-bun@v1 # Uncomment this if you're using Bun
+      - name: Setup Node
+        uses: actions/setup-node@v4
+        with:
+          node-version: 20
+          cache: pnpm # or pnpm / yarn
+      - name: Setup Pages
+        uses: actions/configure-pages@v4
+      - name: Install dependencies
+        run: pnpm install # or pnpm install / yarn install / bun install
+      - name: Build with VitePress
+        run: pnpm docs:build # or pnpm docs:build / yarn docs:build / bun run docs:build
+      - name: Upload artifact
+        uses: actions/upload-pages-artifact@v3
+        with:
+          path: docs/.vitepress/dist
+
+  # Deployment job
+  deploy:
+    environment:
+      name: github-pages
+      url: ${{ steps.deployment.outputs.page_url }}
+    needs: build
+    runs-on: ubuntu-latest
+    name: Deploy
+    steps:
+      - name: Deploy to GitHub Pages
+        id: deployment
+        uses: actions/deploy-pages@v4

+ 5 - 3
.github/workflows/publish.yml

@@ -48,9 +48,9 @@ jobs:
       fail-fast: false
       matrix:
           os: ['ubuntu-20.04']
-          python-version: ['3.8', '3.9', '3.10', '3.11']
-          pytorch-version: ['2.3.0']  # Must be the most recent version that meets requirements-cuda.txt.
-          cuda-version: ['11.8', '12.1']
+          python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+          pytorch-version: ['2.4.0']  # Must be the most recent version that meets requirements-cuda.txt.
+          cuda-version: ['12.4', '12.1', '11.8']
 
     steps:
       - name: Checkout
@@ -76,6 +76,8 @@ jobs:
 
       - name: Build wheel
         shell: bash
+        env:
+          CMAKE_BUILD_TYPE: Release
         run: |
           bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
           wheel_name=$(ls dist/*whl | xargs -n 1 basename)

+ 4 - 4
.github/workflows/ruff.yml

@@ -15,7 +15,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10"]
+        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python ${{ matrix.python-version }}
@@ -25,10 +25,10 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1
+        pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1
     - name: Analysing the code with ruff
       run: |
-        ruff aphrodite tests
+        ruff .
     - name: Spelling check with codespell
       run: |
-         codespell --toml pyproject.toml
+        codespell --toml pyproject.toml

+ 0 - 31
.github/workflows/yapf.yml

@@ -1,31 +0,0 @@
-name: yapf
-
-on:
-  # Trigger the workflow on push or pull request,
-  # but only for the main branch
-  push:
-    branches:
-      - main
-  pull_request:
-    branches:
-      - main
-jobs:
-  yapf:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ["3.10"]
-    steps:
-    - uses: actions/checkout@v2
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v2
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        pip install yapf==0.32.0
-        pip install toml==0.10.2
-    - name: Running yapf
-      run: |
-        yapf --diff --recursive aphrodite tests

+ 3 - 2
.gitignore

@@ -6,12 +6,12 @@ repos
 *.so
 .conda
 build
-dist*
 .VSCodeCounter
 conda/
 umamba.exe
 bin/
 *.whl
+aphrodite/commit_id.py
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -198,4 +198,5 @@ _build/
 
 kv_cache_states/*
 quant_params/*
-.ruff_cache/
+.ruff_cache/
+images/

+ 94 - 219
CMakeLists.txt

@@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.21)
 
 project(aphrodite_extensions LANGUAGES CXX)
 
-option(APHRODITE_TARGET_DEVICE "Target device backend for Aphrodite" "cuda")
+# CUDA by default, can be overridden by using -DAPHRODITE_TARGET_DEVICE=... (used by setup.py)
+set(APHRODITE_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for Aphrodite")
 
 message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
 message(STATUS "Target device: ${APHRODITE_TARGET_DEVICE}")
@@ -13,7 +14,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 # Supported python versions.  These versions will be searched in order, the
 # first match will be selected.  These should be kept in sync with setup.py.
 #
-set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
+set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
 
 # Supported NVIDIA architectures.
 set(CUDA_SUPPORTED_ARCHS "6.0;6.1;7.0;7.5;8.0;8.6;8.9;9.0")
@@ -31,9 +32,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # requirements.txt files and should be kept consistent.  The ROCm torch
 # versions are derived from Dockerfile.rocm
 #
-set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
-set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
-set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
 
 #
 # Try to find python package with an executable that exactly matches
@@ -67,17 +67,37 @@ endif()
 find_package(Torch REQUIRED)
 
 #
-# Normally `torch.utils.cpp_extension.CUDAExtension` would add
-# `libtorch_python.so` for linking against an extension. Torch's cmake
-# configuration does not include this library (presumably since the cmake
-# config is used for standalone C++ binaries that link against torch).
-# The `libtorch_python.so` library defines some of the glue code between
-# torch/python via pybind and is required by APHRODITE extensions for this
-# reason. So, add it by manually using `append_torchlib_if_found` from
-# torch's cmake setup.
+# Add the `default` target which detects which extensions should be
+# built based on platform/architecture.  This is the same logic that
+# setup.py uses to select which extensions should be built and should
+# be kept in sync.
+#
+# The `default` target makes direct use of cmake easier since knowledge
+# of which extensions are supported has been factored in, e.g.
+#
+# mkdir build && cd build
+# cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
+# cmake --build . --target default
 #
-find_library(torch_python_LIBRARY torch_python PATHS
-  "${TORCH_INSTALL_PREFIX}/lib")
+add_custom_target(default)
+message(STATUS "Enabling core extension.")
+
+# Define _core_C extension
+#  built for (almost) every target platform, (excludes TPU and Neuron)
+
+set(APHRODITE_EXT_SRC
+  "kernels/core/torch_bindings.cpp")
+
+define_gpu_extension_target(
+  _core_C
+  DESTINATION aphrodite
+  LANGUAGE CXX
+  SOURCES ${APHRODITE_EXT_SRC}
+  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
+  USE_SABI 3
+  WITH_SOABI)
+
+add_dependencies(default _core_C)
 
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
@@ -87,7 +107,7 @@ if (NOT APHRODITE_TARGET_DEVICE STREQUAL "cuda" AND
     if (APHRODITE_TARGET_DEVICE STREQUAL "cpu")
         include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
     else()
-        message(FATAL_ERROR "Unsupported Aphrodite target device: ${APHRODITE_TARGET_DEVICE}")
+        return()
     endif()
     return()
 endif()
@@ -111,18 +131,11 @@ elseif(HIP_FOUND)
   # .hip extension automatically, HIP must be enabled explicitly.
   enable_language(HIP)
 
-  # ROCm 5.x
-  if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
-      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
-      "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
-  endif()
-
-  # ROCm 6.x
-  if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
-      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
-      "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
+  # ROCm 5.X and 6.X
+  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
+      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+      message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
+      "expected for ROCm build, saw ${Torch_VERSION} instead.")
   endif()
 else()
   message(FATAL_ERROR "Can't find CUDA or HIP installation.")
@@ -152,7 +165,7 @@ if(NVCC_THREADS AND APHRODITE_GPU_LANG STREQUAL "CUDA")
 endif()
 
 #
-# Define extension targets
+# Define other extension targets
 #
 
 #
@@ -165,56 +178,69 @@ set(APHRODITE_EXT_SRC
   "kernels/pos_encoding_kernels.cu"
   "kernels/activation_kernels.cu"
   "kernels/layernorm_kernels.cu"
+  "kernels/quantization/squeezellm/quant_cuda_kernel.cu"
+  "kernels/quantization/gptq/q_gemm.cu"
+  "kernels/quantization/compressed_tensors/int8_quant_kernels.cu"
+  "kernels/quantization/fp8/common.cu"
   "kernels/cuda_utils_kernels.cu"
   "kernels/moe/align_block_size_kernel.cu"
-  "kernels/pybind.cpp")
+  "kernels/prepare_inputs/advance_step.cu"
+  "kernels/torch_bindings.cpp")
 
 if(APHRODITE_GPU_LANG STREQUAL "CUDA")
-  list(APPEND APHRODITE_EXT_SRC
-    "kernels/all_reduce/custom_all_reduce.cu")
-endif()
-
-define_gpu_extension_target(
-  _C
-  DESTINATION aphrodite
-  LANGUAGE ${APHRODITE_GPU_LANG}
-  SOURCES ${APHRODITE_EXT_SRC}
-  COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
-  ARCHITECTURES ${APHRODITE_GPU_ARCHES}
-  WITH_SOABI)
+  include(FetchContent)
+  SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
+  FetchContent_Declare(
+        cutlass
+        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
+        # CUTLASS 3.5.1
+        GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 
+  )
+  FetchContent_MakeAvailable(cutlass)
 
-
-#
-# _quant_C extension
-#
-
-set (APHRODITE_QUANT_EXT_SRC
-  "kernels/quantization/gptq/q_gemm.cu"
-  "kernels/quantization/squeezellm/quant_cuda_kernel.cu"
-  "kernels/quantization/exl2/q_matrix.cu"
-  "kernels/quantization/exl2/q_gemm_exl2.cu"
-  "kernels/quantization/quant_ops.cpp")
-
-if(APHRODITE_GPU_LANG STREQUAL "CUDA")
-  list(APPEND APHRODITE_QUANT_EXT_SRC
-    "kernels/quantization/aqlm/aqlm_cuda_entry.cpp"
-    "kernels/quantization/aqlm/aqlm_cuda_kernel.cu"
+  list(APPEND APHRODITE_EXT_SRC
+    "kernels/mamba/mamba_ssm/selective_scan_fwd.cu"
+    "kernels/mamba/causal_conv1d/causal_conv1d.cu"
+    "kernels/quantization/aqlm/gemm_kernels.cu"
     "kernels/quantization/awq/gemm_kernels.cu"
-    "kernels/quantization/bitsandbytes/int4_fp16_gemm_kernels.cu"
-    "kernels/quantization/bitsandbytes/format.cu"
-    "kernels/quantization/bitsandbytes/gemm_s4_f16.cu"
+    "kernels/quantization/quip/origin_order.cu"
+    "kernels/quantization/marlin/dense/marlin_cuda_kernel.cu"
+    "kernels/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
+    "kernels/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
+    "kernels/quantization/gptq_marlin/gptq_marlin.cu"
+    "kernels/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "kernels/quantization/gguf/gguf_kernel.cu"
-    "kernels/quantization/marlin/marlin_cuda_kernel.cu"
-    "kernels/quantization/quip/origin_order.cu")
+    "kernels/quantization/gptq_marlin/awq_marlin_repack.cu"
+    "kernels/quantization/fp8/fp8_marlin.cu"
+    "kernels/all_reduce/custom_all_reduce.cu"
+    "kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
+    "kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
+
+  #
+  # The CUTLASS kernels for Hopper require sm90a to be enabled.
+  # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
+  # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
+    set_source_files_properties(
+          "kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
+          "kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
+          PROPERTIES
+          COMPILE_FLAGS
+          "-gencode arch=compute_90a,code=sm_90a -Wno-psabi")
+  endif()
+
 endif()
 
 define_gpu_extension_target(
-  _quant_C
+  _C
   DESTINATION aphrodite
   LANGUAGE ${APHRODITE_GPU_LANG}
-  SOURCES ${APHRODITE_QUANT_EXT_SRC}
+  SOURCES ${APHRODITE_EXT_SRC}
   COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
   ARCHITECTURES ${APHRODITE_GPU_ARCHES}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
+  USE_SABI 3
   WITH_SOABI)
 
 #
@@ -222,7 +248,7 @@ define_gpu_extension_target(
 #
 
 set(APHRODITE_MOE_EXT_SRC
-  "kernels/moe/moe_ops.cpp"
+  "kernels/moe/torch_bindings.cpp"
   "kernels/moe/softmax.cu")
 
 define_gpu_extension_target(
@@ -232,164 +258,13 @@ define_gpu_extension_target(
   SOURCES ${APHRODITE_MOE_EXT_SRC}
   COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
   ARCHITECTURES ${APHRODITE_GPU_ARCHES}
+  USE_SABI 3
   WITH_SOABI)
 
-#
-# _punica_C extension
-#
-
-set(APHRODITE_PUNICA_EXT_SRC
-  "kernels/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
-  "kernels/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
-  "kernels/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
-  "kernels/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
-  "kernels/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
-  "kernels/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
-  "kernels/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
-  "kernels/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
-  "kernels/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
-  "kernels/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
-  "kernels/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
-  "kernels/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
-  "kernels/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
-  "kernels/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
-  "kernels/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
-  "kernels/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
-  "kernels/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
-  "kernels/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
-  "kernels/punica/punica_ops.cc")
-
-#
-# Copy GPU compilation flags+update for punica
-#
-set(APHRODITE_PUNICA_GPU_FLAGS ${APHRODITE_GPU_FLAGS})
-list(REMOVE_ITEM APHRODITE_PUNICA_GPU_FLAGS
-  "-D__CUDA_NO_HALF_OPERATORS__"
-  "-D__CUDA_NO_HALF_CONVERSIONS__"
-  "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
-  "-D__CUDA_NO_HALF2_OPERATORS__")
-
-#
-# Filter out CUDA architectures < 8.0 for punica.
-#
-if (${APHRODITE_GPU_LANG} STREQUAL "CUDA")
-  set(APHRODITE_PUNICA_GPU_ARCHES)
-  foreach(ARCH ${APHRODITE_GPU_ARCHES})
-    string_to_ver(CODE_VER ${ARCH})
-    if (CODE_VER GREATER_EQUAL 8.0)
-      list(APPEND APHRODITE_PUNICA_GPU_ARCHES ${ARCH})
-    endif()
-  endforeach()
-  message(STATUS "Punica target arches: ${APHRODITE_PUNICA_GPU_ARCHES}")
-endif()
-
-if (APHRODITE_PUNICA_GPU_ARCHES)
-  define_gpu_extension_target(
-    _punica_C
-    DESTINATION aphrodite
-    LANGUAGE ${APHRODITE_GPU_LANG}
-    SOURCES ${APHRODITE_PUNICA_EXT_SRC}
-    COMPILE_FLAGS ${APHRODITE_PUNICA_GPU_FLAGS}
-    ARCHITECTURES ${APHRODITE_PUNICA_GPU_ARCHES}
-    WITH_SOABI)
-else()
-  message(WARNING "Unable to create _punica_C target because none of the "
-    "requested architectures (${APHRODITE_GPU_ARCHES}) are supported, i.e. >= 8.0")
-endif()
-
-#
-# _hadamard_C extension
-#
-
-set(APHRODITE_HADAMARD_EXT_SRC
-  "kernels/hadamard/fast_hadamard_transform.cpp"
-  "kernels/hadamard/fast_hadamard_transform_cuda.cu")
-
-#
-# Copy GPU compilation flags+update for hadamard
-#
-set(APHRODITE_HADAMARD_GPU_FLAGS ${APHRODITE_GPU_FLAGS})
-list(APPEND APHRODITE_HADAMARD_GPU_FLAGS
-  "-U__CUDA_NO_HALF_OPERATORS__"
-  "-U__CUDA_NO_HALF_CONVERSIONS__"
-  "-U__CUDA_NO_BFLOAT16_OPERATORS__"
-  "-U__CUDA_NO_BFLOAT16_CONVERSIONS__"
-  "-U__CUDA_NO_BFLOAT162_OPERATORS__"
-  "-U__CUDA_NO_BFLOAT162_CONVERSIONS__"
-  "--expt-relaxed-constexpr"
-  "--expt-extended-lambda"
-  "--use_fast_math"
-  "--ptxas-options=-v"
-  "-lineinfo")
-
-#
-# Filter out CUDA architectures < 7.0 for hadamard.
-#
-if (${APHRODITE_GPU_LANG} STREQUAL "CUDA")
-  set(APHRODITE_HADAMARD_GPU_ARCHES)
-  foreach(ARCH ${APHRODITE_GPU_ARCHES})
-    string_to_ver(CODE_VER ${ARCH})
-    if (CODE_VER GREATER_EQUAL 6.0)
-      list(APPEND APHRODITE_HADAMARD_GPU_ARCHES ${ARCH})
-    endif()
-  endforeach()
-  message(STATUS "Hadamard target arches: ${APHRODITE_HADAMARD_GPU_ARCHES}")
-endif()
-
-if (APHRODITE_HADAMARD_GPU_ARCHES)
-  define_gpu_extension_target(
-    _hadamard_C
-    DESTINATION aphrodite
-    LANGUAGE ${APHRODITE_GPU_LANG}
-    SOURCES ${APHRODITE_HADAMARD_EXT_SRC}
-    COMPILE_FLAGS ${APHRODITE_HADAMARD_GPU_FLAGS}
-    ARCHITECTURES ${APHRODITE_HADAMARD_GPU_ARCHES}
-    WITH_SOABI)
-else()
-  message(WARNING "Unable to create _hadamard_C target because none of the "
-    "requested architectures (${APHRODITE_GPU_ARCHES}) are supported, i.e. >= 6.0")
-endif()
-
-#
-# Add the `default` target which detects which extensions should be
-# built based on platform/architecture.  This is the same logic that
-# setup.py uses to select which extensions should be built and should
-# be kept in sync.
-#
-# The `default` target makes direct use of cmake easier since knowledge
-# of which extensions are supported has been factored in, e.g.
-#
-# mkdir build && cd build
-# cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
-# cmake --build . --target default
-#
-add_custom_target(default)
-
 if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")
   add_dependencies(default _C)
-endif()
 
-if(APHRODITE_GPU_LANG STREQUAL "CUDA")
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
-
-  # Enable punica if -DAPHRODITE_INSTALL_PUNICA_KERNELS=ON or
-  # APHRODITE_INSTALL_PUNICA_KERNELS is set in the environment and
-  # there are supported target arches.
-  if (APHRODITE_QUANT_EXT_SRC AND
-      (ENV{APHRODITE_INSTALL_QUANT_KERNELS} OR APHRODITE_INSTALL_QUANT_KERNELS))
-    message(STATUS "Enabling quant extension.")
-    add_dependencies(default _quant_C)
-  endif()
-  if (APHRODITE_PUNICA_GPU_ARCHES AND
-      (ENV{APHRODITE_INSTALL_PUNICA_KERNELS} OR APHRODITE_INSTALL_PUNICA_KERNELS))
-    message(STATUS "Enabling punica extension.")
-    add_dependencies(default _punica_C)
-  endif()
-  if (APHRODITE_HADAMARD_GPU_ARCHES AND
-      (ENV{APHRODITE_INSTALL_HADAMARD_KERNELS} OR APHRODITE_INSTALL_HADAMARD_KERNELS))
-    message(STATUS "Enabling hadamard extension.")
-    add_dependencies(default _hadamard_C)
-  endif()
-endif()
+endif()
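
For reference, the comment block carried over above documents the new `default` target; a direct out-of-tree build could look like the sketch below, which follows that comment. The explicit `-DAPHRODITE_TARGET_DEVICE=cuda` simply restates the new cache variable's default.

    # Configure and build the default extension set out of tree
    # (same selection logic as setup.py).
    mkdir build && cd build
    cmake -G Ninja \
          -DAPHRODITE_PYTHON_EXECUTABLE=$(which python3) \
          -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite \
          -DAPHRODITE_TARGET_DEVICE=cuda \
          ..
    cmake --build . --target default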

+ 136 - 0
Dockerfile

@@ -0,0 +1,136 @@
+# The Aphrodite Dockerfile is used to construct Aphrodite image that can be directly used
+# to run the OpenAI compatible server.
+
+ARG CUDA_VERSION=12.4.1
+#################### BASE BUILD IMAGE ####################
+# prepare basic build environment
+FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
+
+ARG CUDA_VERSION=12.4.1
+ARG PYTHON_VERSION=3
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
+    && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
+    && apt-get update -y \
+    && apt-get install -y ccache software-properties-common \
+    && add-apt-repository ppa:deadsnakes/ppa \
+    && apt-get update -y \
+    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv python3-pip \
+    && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
+    && python3 --version \
+    && python3 -m pip --version
+
+RUN apt-get update -y \
+    && apt-get install -y python3-pip git curl libibverbs-dev
+
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
+
+WORKDIR /workspace
+
+# install build and runtime dependencies
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-adag.txt requirements-adag.txt
+COPY requirements-cuda.txt requirements-cuda.txt
+RUN pip install packaging wheel
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install -r requirements-cuda.txt
+
+# cuda arch list used by torch
+# can be useful for both `dev` and `test`
+# explicitly set the list to avoid issues with torch 2.2
+# see https://github.com/pytorch/pytorch/pull/123243
+ARG torch_cuda_arch_list='6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX'
+ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
+#################### BASE BUILD IMAGE ####################
+
+
+#################### WHEEL BUILD IMAGE ####################
+FROM base AS build
+
+ARG PYTHON_VERSION=3
+
+# install build dependencies
+COPY requirements-build.txt requirements-build.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install -r requirements-build.txt
+
+# install compiler cache to speed up compilation leveraging local or remote caching
+RUN apt-get update -y && apt-get install -y ccache
+
+# files and directories related to build wheels
+COPY kernels kernels
+COPY setup.py setup.py
+COPY cmake cmake
+COPY CMakeLists.txt CMakeLists.txt
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-adag.txt requirements-adag.txt
+COPY requirements-cuda.txt requirements-cuda.txt
+COPY pyproject.toml pyproject.toml
+COPY aphrodite aphrodite
+
+# max jobs used by Ninja to build extensions
+ARG max_jobs=2
+ENV MAX_JOBS=${max_jobs}
+# number of threads used by nvcc
+ARG nvcc_threads=8
+ENV NVCC_THREADS=${nvcc_threads}
+
+ENV CCACHE_DIR=/root/.cache/ccache
+RUN --mount=type=cache,target=/root/.cache/ccache \
+    --mount=type=cache,target=/root/.cache/pip \
+    python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
+
+#################### EXTENSION Build IMAGE ####################
+
+#################### DEV IMAGE ####################
+FROM base as dev
+
+COPY requirements-dev.txt requirements-dev.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install -r requirements-dev.txt
+
+#################### DEV IMAGE ####################
+
+#################### Aphrodite installation IMAGE ####################
+# image with Aphrodite installed
+FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS aphrodite-base
+ARG CUDA_VERSION=12.4.1
+WORKDIR /aphrodite-workspace
+
+RUN apt-get update -y \
+    && apt-get install -y python3-pip git vim
+
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
+
+# install aphrodite wheel first, so that torch etc will be installed
+RUN --mount=type=bind,from=build,src=/workspace/dist,target=/aphrodite-workspace/dist \
+    --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install dist/*.whl --verbose
+
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
+#################### Aphrodite installation IMAGE ####################
+
+
+#################### OPENAI API SERVER ####################
+# openai api server alternative
+FROM aphrodite-base AS aphrodite-openai
+
+# install additional dependencies for openai api server
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install accelerate hf_transfer 'modelscope!=1.15.0'
+
+ENV NUMBA_CACHE_DIR=$HOME/.numba_cache
+
+ENTRYPOINT ["python3", "-m", "aphrodite.endpoints.openai.api_server"]
+#################### OPENAI API SERVER ####################
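
Since the final stage (`aphrodite-openai`) is the default build target, the image can be built and run roughly as in the sketch below; the tag, model name, port, and server flags are illustrative and not taken from this commit, while `max_jobs`/`nvcc_threads` map to the ARGs declared in the Dockerfile above.

    # Build the OpenAI-compatible server image.
    docker build -t aphrodite-openai \
        --build-arg max_jobs=8 --build-arg nvcc_threads=8 .

    # Run it; arguments after the image name are passed to
    # aphrodite.endpoints.openai.api_server (model and port are placeholders).
    docker run --gpus all -p 2242:2242 aphrodite-openai \
        --model meta-llama/Meta-Llama-3-8B-Instruct --port 2242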

+ 41 - 0
Dockerfile.cpu

@@ -0,0 +1,41 @@
+# This Aphrodite Dockerfile is used to construct image that can build and run Aphrodite on x86 CPU platform.
+
+FROM ubuntu:22.04 AS cpu-test-1
+
+RUN apt-get update -y \
+    && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
+    && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
+
+# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
+# intel-openmp provides additional performance improvement vs. openmp
+# tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects.
+RUN pip install intel-openmp
+
+ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
+
+RUN echo 'ulimit -c 0' >> ~/.bashrc
+
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
+
+RUN pip install --upgrade pip \
+    && pip install wheel packaging ninja "setuptools>=49.4.0" numpy
+
+FROM cpu-test-1 AS build
+
+COPY ./ /workspace/aphrodite-engine
+
+WORKDIR /workspace/aphrodite-engine
+
+RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+
+# Support for building with non-AVX512 Aphrodite: docker build --build-arg APHRODITE_CPU_DISABLE_AVX512="true" ...
+ARG APHRODITE_CPU_DISABLE_AVX512
+ENV APHRODITE_CPU_DISABLE_AVX512=${APHRODITE_CPU_DISABLE_AVX512}
+
+RUN APHRODITE_TARGET_DEVICE=cpu python3 setup.py install
+
+WORKDIR /workspace/
+
+RUN ln -s /workspace/aphrodite-engine/examples && ln -s /workspace/aphrodite-engine/tests/benchmarks
+
+ENTRYPOINT ["python3", "-m", "aphrodite.endpoints.openai.api_server"]
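
As the comment above notes, AVX-512 codepaths can be disabled at build time for older x86 hosts; a minimal sketch (the image tag is illustrative):

    # CPU-only build with AVX-512 disabled.
    docker build -f Dockerfile.cpu -t aphrodite-cpu \
        --build-arg APHRODITE_CPU_DISABLE_AVX512="true" .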

+ 0 - 0
docker/Dockerfile.neuron → Dockerfile.neuron


+ 32 - 0
Dockerfile.openvino

@@ -0,0 +1,32 @@
+# The Aphrodite Dockerfile is used to construct Aphrodite image that can be directly used
+# to run the OpenAI compatible server.
+
+FROM ubuntu:22.04 AS dev
+
+RUN apt-get update -y && \
+    apt-get install -y python3-pip git
+WORKDIR /workspace
+
+# copy requirements
+COPY requirements-build.txt /workspace/aphrodite-engine/
+COPY requirements-common.txt /workspace/aphrodite-engine/
+COPY requirements-openvino.txt /workspace/aphrodite-engine/
+
+COPY aphrodite/ /workspace/aphrodite-engine/aphrodite
+COPY kernels/core /workspace/aphrodite-engine/kernels/core
+COPY cmake/utils.cmake /workspace/aphrodite-engine/cmake/
+COPY CMakeLists.txt /workspace/aphrodite-engine/
+COPY setup.py /workspace/aphrodite-engine/
+
+# install build requirements
+RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/aphrodite-engine/requirements-build.txt
+# build Aphrodite with OpenVINO backend
+RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" APHRODITE_TARGET_DEVICE="openvino" python3 -m pip install -e /workspace/aphrodite-engine/
+
+# temporary hack until the dependencies are upgraded
+RUN pip install 'transformers==4.43.0'
+
+COPY examples/ /workspace/aphrodite-engine/examples
+COPY tests/benchmarks/ /workspace/aphrodite-engine/tests/benchmarks
+
+CMD ["/bin/bash"]

+ 22 - 0
Dockerfile.ppc64le

@@ -0,0 +1,22 @@
+FROM mambaorg/micromamba
+ARG MAMBA_DOCKERFILE_ACTIVATE=1
+USER root
+
+RUN apt-get update  -y     && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev     && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
+
+# Some packages in requirements-cpu are installed here
+# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
+# Currently these may not be available for venv or pip directly
+RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults     python=3.10     pytorch-cpu=2.1.2     torchvision-cpu=0.16.2    &&     micromamba clean --all --yes
+
+COPY ./ /workspace/aphrodite-engine
+
+WORKDIR /workspace/aphrodite-engine
+
+# These packages will be in rocketce eventually
+RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing
+
+RUN APHRODITE_TARGET_DEVICE=cpu python3 setup.py install
+
+WORKDIR /aphrodite-workspace
+ENTRYPOINT ["/opt/conda/bin/python3", "-m", "aphrodite.endpoints.openai.api_server"]

+ 180 - 0
Dockerfile.rocm

@@ -0,0 +1,180 @@
+# Default ROCm 6.1 base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+
+# Default ROCm ARCHes to build Aphrodite for.
+ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
+
+# Whether to install CK-based flash-attention
+# If 0, will not install flash-attention
+ARG BUILD_FA="1"
+# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
+# If this succeeds, we use the downloaded wheel and skip building flash-attention.
+# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
+# architectures specified in `FA_GFX_ARCHS`
+ARG TRY_FA_WHEEL="1"
+ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
+ARG FA_GFX_ARCHS="gfx90a;gfx942"
+ARG FA_BRANCH="23a2b1c2"
+
+# Whether to build triton on rocm
+ARG BUILD_TRITON="1"
+ARG TRITON_BRANCH="e0fc12c"
+
+### Base image build stage
+FROM $BASE_IMAGE AS base
+
+# Import arg(s) defined before this build stage
+ARG PYTORCH_ROCM_ARCH
+
+# Install some basic utilities
+RUN apt-get update && apt-get install python3 python3-pip -y
+RUN apt-get update && apt-get install -y \
+    curl \
+    ca-certificates \
+    sudo \
+    git \
+    bzip2 \
+    libx11-6 \
+    build-essential \
+    wget \
+    unzip \
+    tmux \
+    ccache \
+ && rm -rf /var/lib/apt/lists/*
+
+# When launching the container, mount the code directory to /aphrodite-workspace
+ARG APP_MOUNT=/aphrodite-workspace
+WORKDIR ${APP_MOUNT}
+
+RUN python3 -m pip install --upgrade pip
+# Remove sccache so it doesn't interfere with ccache
+# TODO: implement sccache support across components
+RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
+# Install torch == 2.5.0 on ROCm
+RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-6.1"*) \
+            python3 -m pip uninstall -y torch torchvision \
+            && python3 -m pip install --no-cache-dir --pre \
+                torch==2.5.0.dev20240726 \
+                torchvision==0.20.0.dev20240726 \
+               --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
+        *) ;; esac
+
+ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
+ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
+
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+ENV CCACHE_DIR=/root/.cache/ccache
+
+
+### AMD-SMI build stage
+FROM base AS build_amdsmi
+# Build amdsmi wheel always
+RUN cd /opt/rocm/share/amd_smi \
+    && python3 -m pip wheel . --wheel-dir=/install
+
+
+### Flash-Attention wheel build stage
+FROM base AS build_fa
+ARG BUILD_FA
+ARG TRY_FA_WHEEL
+ARG FA_WHEEL_URL
+ARG FA_GFX_ARCHS
+ARG FA_BRANCH
+# Build ROCm flash-attention wheel if `BUILD_FA = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+    if [ "$BUILD_FA" = "1" ]; then \
+        if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
+            # If a suitable wheel exists, we download it instead of building FA
+            mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
+        else \
+            mkdir -p libs \
+            && cd libs \
+            && git clone https://github.com/ROCm/flash-attention.git \
+            && cd flash-attention \
+            && git checkout "${FA_BRANCH}" \
+            && git submodule update --init \
+            && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+        fi; \
+    # Create an empty directory otherwise as later build stages expect one
+    else mkdir -p /install; \
+    fi
+
+
+### Triton wheel build stage
+FROM base AS build_triton
+ARG BUILD_TRITON
+ARG TRITON_BRANCH
+# Build triton wheel if `BUILD_TRITON = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+    if [ "$BUILD_TRITON" = "1" ]; then \
+    mkdir -p libs \
+    && cd libs \
+    && git clone https://github.com/OpenAI/triton.git \
+    && cd triton \
+    && git checkout "${TRITON_BRANCH}" \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=/install; \
+    # Create an empty directory otherwise as later build stages expect one
+    else mkdir -p /install; \
+    fi
+
+
+### Final Aphrodite build stage
+FROM base AS final
+# Import the Aphrodite development directory from the build context
+COPY . .
+
+# Package upgrades for useful functionality or to avoid dependency issues
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
+
+# Workaround for ray >= 2.10.0
+ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
+# Silences the HF Tokenizers warning
+ENV TOKENIZERS_PARALLELISM=false
+
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+    --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install -Ur requirements-rocm.txt \
+    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+        *"rocm-6.1"*) \
+            # Bring in upgrades to HIP graph earlier than ROCm 6.2 for Aphrodite
+            wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \
+            # Prevent interference if torch bundles its own HIP runtime
+            && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
+        *) ;; esac \
+    && python3 setup.py clean --all \
+    && python3 setup.py develop
+
+# Copy amdsmi wheel into final image
+RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
+    mkdir -p libs \
+    && cp /install/*.whl libs \
+    # Preemptively uninstall to avoid same-version no-installs
+    && python3 -m pip uninstall -y amdsmi;
+
+# Copy triton wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
+    mkdir -p libs \
+    && if ls /install/*.whl; then \
+        cp /install/*.whl libs \
+        # Preemptively uninstall to avoid same-version no-installs
+        && python3 -m pip uninstall -y triton; fi
+
+# Copy flash-attn wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
+    mkdir -p libs \
+    && if ls /install/*.whl; then \
+        cp /install/*.whl libs \
+        # Preemptively uninstall to avoid same-version no-installs
+        && python3 -m pip uninstall -y flash-attn; fi
+
+# Install wheels that were built to the final image
+RUN --mount=type=cache,target=/root/.cache/pip \
+    if ls libs/*.whl; then \
+    python3 -m pip install libs/*.whl; fi
+
+CMD ["/bin/bash"]

+ 29 - 0
Dockerfile.tpu

@@ -0,0 +1,29 @@
+ARG NIGHTLY_DATE="20240726"
+ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
+
+FROM $BASE_IMAGE
+
+WORKDIR /workspace
+
+# Install aiohttp separately to avoid build errors.
+RUN pip install aiohttp
+# Install NumPy 1 instead of NumPy 2.
+RUN pip install "numpy<2"
+# Install the TPU and Pallas dependencies.
+RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
+RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+# Pin starlette to fix a FastAPI dependency conflict.
+RUN pip install "starlette<0.38.0"
+
+# Build Aphrodite.
+COPY . /workspace/aphrodite-engine
+ENV APHRODITE_TARGET_DEVICE="tpu"
+RUN cd /workspace/aphrodite-engine && python setup.py develop
+
+# Re-install outlines to avoid dependency errors.
+# The outlines version must follow requirements-common.txt.
+RUN pip uninstall outlines -y
+RUN pip install "outlines>=0.0.43"
+
+CMD ["/bin/bash"]

+ 22 - 0
Dockerfile.xpu

@@ -0,0 +1,22 @@
+FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04
+
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
+    echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
+    chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
+    rm /etc/apt/sources.list.d/intel-graphics.list && \
+    wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
+    echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
+    chmod 644 /usr/share/keyrings/intel-graphics.gpg
+
+RUN apt-get update -y \
+&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip
+
+COPY ./ /workspace/aphrodite-engine
+
+WORKDIR /workspace/aphrodite-engine
+
+RUN pip install -v -r requirements-xpu.txt
+
+RUN APHRODITE_TARGET_DEVICE=xpu python3 setup.py develop
+
+CMD ["/bin/bash"]

+ 7 - 1
MANIFEST.in

@@ -1,6 +1,12 @@
 include LICENSE
+include requirements-adag.txt
 include requirements-common.txt
 include requirements-cuda.txt
+include requirements-rocm.txt
+include requirements-neuron.txt
+include requirements-cpu.txt
+include CMakeLists.txt
 include aphrodite/endpoints/kobold/klite.embd
 
-recursive-include kernels *
+recursive-include kernels *
+recursive-include cmake *

+ 14 - 6
aphrodite/__init__.py

@@ -1,23 +1,31 @@
+from aphrodite.common.outputs import (CompletionOutput, EmbeddingOutput,
+                                      EmbeddingRequestOutput, RequestOutput)
+from aphrodite.common.pooling_params import PoolingParams
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.endpoints.llm import LLM
+from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.args_tools import AsyncEngineArgs, EngineArgs
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
-from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.engine.ray_tools import initialize_ray_cluster
-from aphrodite.endpoints.llm import LLM
+from aphrodite.executor.ray_utils import initialize_ray_cluster
 from aphrodite.modeling.models import ModelRegistry
-from aphrodite.common.outputs import CompletionOutput, RequestOutput
-from aphrodite.common.sampling_params import SamplingParams
 
-__version__ = "0.5.3"
+from .version import __commit__, __short_commit__, __version__
 
 __all__ = [
+    "__commit__",
+    "__short_commit__",
+    "__version__",
     "LLM",
     "ModelRegistry",
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
+    "EmbeddingOutput",
+    "EmbeddingRequestOutput",
     "AphroditeEngine",
     "EngineArgs",
     "AsyncAphrodite",
     "AsyncEngineArgs",
     "initialize_ray_cluster",
+    "PoolingParams",
 ]
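
For reference, the re-exported names above cover the typical offline entry point. A minimal, hedged usage sketch (the model id is a placeholder; the keyword names follow aphrodite.endpoints.llm.LLM's offline API):

    from aphrodite import LLM, SamplingParams

    # Placeholder model id; any Hugging Face id or local path should work.
    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0.8, max_tokens=32)
    for output in llm.generate(["The capital of France is"], params):
        # Each RequestOutput carries one or more CompletionOutput entries.
        print(output.outputs[0].text)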

+ 176 - 0
aphrodite/_core_ext.py

@@ -0,0 +1,176 @@
+import importlib.util
+from enum import Enum
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from loguru import logger
+
+core_C_available = importlib.util.find_spec('._core_C',
+                                            'aphrodite') is not None
+
+
+# Mirrors enum in `core/scalar_type.hpp`
+class NanRepr(Enum):
+    NONE = 0  # nans are not supported
+    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
+    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
+
+
+if TYPE_CHECKING or not core_C_available:
+    # On platforms where we cannot use/build the C++ core extension (namely
+    # Neuron and TPU), we define a mock ScalarType class here that partially
+    # mimics the C++ ScalarType class.
+    #
+    # We also use this to provide type signatures to the Python LSP for the
+    # methods in the C++ ScalarType class, so these signatures should be kept
+    # in sync with csrc/core/scalar_type.hpp.
+
+    from dataclasses import dataclass
+
+    @dataclass(frozen=True)
+    class ScalarType:
+        """
+        ScalarType can represent a wide range of floating point and integer
+        types; in particular, it can represent sub-byte data types (something
+        that torch.dtype currently does not support). It can also represent
+        types with a bias, i.e.
+          `stored_value = value + bias`,
+        which is useful for quantized types (e.g. standard GPTQ 4-bit uses a
+        bias of 8). The implementation of this class can be found in
+        csrc/core/scalar_type.hpp, and these type signatures should be kept in
+        sync with that file.
+        """
+
+        exponent: int
+        """
+        Number of bits in the exponent if this is a floating point type
+        (zero if this is an integer type).
+        """
+
+        mantissa: int
+        """
+        Number of bits in the mantissa if this is a floating point type,
+        or the number of bits representing an integer (excluding the sign bit)
+        if this is an integer type.
+        """
+
+        bias: int
+        """
+        Bias used to encode the values in this scalar type
+        (value = stored_value - bias, default 0). For example, if we store the
+        type as an unsigned integer with a bias of 128, then the value 0 is
+        stored as 128, -1 is stored as 127, and 1 is stored as 129.
+        """
+
+        signed: bool
+        "If the type is signed (i.e. has a sign bit)"
+
+        _finite_values_only: bool = False
+        """
+        Private: whether only finite values are representable (no infinities);
+        use `has_infs()` instead of reading this field directly.
+        """
+
+        nan_repr: int = NanRepr.IEEE_754.value
+        """
+        How NaNs are represented in this scalar type, as a NanRepr value
+        (not applicable for integer types).
+        """
+
+        @property
+        def size_bits(self):
+            return self.exponent + self.mantissa + int(self.signed)
+
+        def min(self) -> Union[int, float]:
+            """
+            Min representable value for this scalar type. 
+            (accounting for bias if there is one)
+            """
+            raise NotImplementedError
+
+        def max(self) -> Union[int, float]:
+            """
+            Max representable value for this scalar type. 
+            (accounting for bias if there is one)
+            """
+            raise NotImplementedError
+
+        def is_signed(self) -> bool:
+            """
+            If the type is signed (i.e. has a sign bit); same as `signed`.
+            Added for consistency with:
+            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+            """
+            ...
+
+        def is_floating_point(self):
+            "If the type is a floating point type"
+            return self.exponent != 0
+
+        def is_integer(self):
+            "If the type is an integer type"
+            return self.exponent == 0
+
+        def has_bias(self):
+            "If the type has a non-zero bias"
+            return self.bias != 0
+
+        def has_infs(self):
+            "If the type is floating point and supports infinity"
+            return not self._finite_values_only
+
+        def has_nans(self):
+            return self.nan_repr != NanRepr.NONE.value
+
+        def is_ieee_754(self) -> bool:
+            """
+            If the type is a floating point type that follows IEEE 754 
+            conventions
+            """
+            return self.nan_repr == NanRepr.IEEE_754.value and \
+                not self._finite_values_only
+
+        def __str__(self) -> str:
+            raise NotImplementedError
+
+        def __repr__(self) -> str:
+            raise NotImplementedError
+
+        #
+        # Convenience Constructors
+        #
+
+        @classmethod
+        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            "Create a signed integer scalar type (size_bits includes sign-bit)."
+            return cls(0, size_bits - 1, bias if bias else 0, True)
+
+        @classmethod
+        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            """Create a unsigned integer scalar type."""
+            return cls(size_bits, size_bits, bias if bias else 0, False)
+
+        @classmethod
+        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+            """
+            Create a standard floating point type 
+            (i.e. follows IEEE 754 conventions).
+            """
+            return cls(exponent, mantissa, 0, True)
+
+        @classmethod
+        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+                   nan_repr: int):
+            """
+            Create a non-standard floating point type 
+            (i.e. does not follow IEEE 754 conventions).
+            """
+            return cls(exponent, mantissa, 0, True, finite_values_only,
+                       nan_repr)
+
+elif core_C_available:
+    try:
+        import aphrodite._core_C  # noqa: F401
+    except ImportError as e:
+        logger.warning(f"Failed to import from aphrodite._core_C with {e}")
+
+    ScalarType = torch.classes._core_C.ScalarType

+ 620 - 0
aphrodite/_custom_ops.py

@@ -0,0 +1,620 @@
+import contextlib
+import functools
+from typing import List, Optional, Tuple, Type
+
+import torch
+from loguru import logger
+
+from aphrodite._core_ext import ScalarType
+
+try:
+    import aphrodite._C
+except ImportError as e:
+    logger.warning(f"Failed to import from aphrodite._C with {e}")
+
+with contextlib.suppress(ImportError):
+    # ruff: noqa: F401
+    import aphrodite._moe_C
+
+
+def is_custom_op_supported(op_name: str) -> bool:
+    op, overloads = torch._C._jit_get_operation(op_name)
+    return op is not None
+
+
+def hint_on_error(fn):
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        try:
+            return fn(*args, **kwargs)
+        except AttributeError as e:
+            msg = (
+                f"Error in calling custom op {fn.__name__}: {e}\n"
+                f"Possibly you have built or installed an obsolete version of aphrodite.\n"
+                f"Please try a clean build and install of aphrodite,"
+                f"or remove old built files such as aphrodite/*.so and build/ ."
+            )
+            logger.error(msg)
+            raise e
+
+    return wrapper
+
+
+# activation ops
+def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.silu_and_mul(out, x)
+
+
+def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_and_mul(out, x)
+
+
+def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_tanh_and_mul(out, x)
+
+
+def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_fast(out, x)
+
+
+def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_new(out, x)
+
+
+def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_quick(out, x)
+
+
+# page attention ops
+def paged_attention_v1(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_seq_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    tp_rank: int = 0,
+    blocksparse_local_blocks: int = 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
+) -> None:
+    torch.ops._C.paged_attention_v1(
+        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
+        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
+        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
+        blocksparse_vert_stride, blocksparse_block_size,
+        blocksparse_head_sliding_step)
+
+
+def paged_attention_v2(
+    out: torch.Tensor,
+    exp_sum: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_seq_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    tp_rank: int = 0,
+    blocksparse_local_blocks: int = 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
+) -> None:
+    torch.ops._C.paged_attention_v2(
+        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
+        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
+        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
+        blocksparse_local_blocks, blocksparse_vert_stride,
+        blocksparse_block_size, blocksparse_head_sliding_step)
+
+
+# pos encoding ops
+def rotary_embedding(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    torch.ops._C.rotary_embedding(positions, query, key, head_size,
+                                  cos_sin_cache, is_neox)
+
+
+def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+                             key: torch.Tensor, head_size: int,
+                             cos_sin_cache: torch.Tensor, is_neox: bool,
+                             rot_dim: int,
+                             cos_sin_cache_offsets: torch.Tensor) -> None:
+    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
+                                          cos_sin_cache, is_neox, rot_dim,
+                                          cos_sin_cache_offsets)
+
+
+# layer norm ops
+def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+             epsilon: float) -> None:
+    torch.ops._C.rms_norm(out, input, weight, epsilon)
+
+
+def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
+                       weight: torch.Tensor, epsilon: float) -> None:
+    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
+
+
+def advance_step(num_seqs: int, num_queries: int, block_size: int,
+                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
+                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
+                 slot_mapping: torch.Tensor,
+                 block_tables: torch.Tensor) -> None:
+    """Advance a step on GPU for existing inputs for a multi-step runner"""
+    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
+                                     input_tokens, sampled_token_ids,
+                                     input_positions, seq_lens, slot_mapping,
+                                     block_tables)
+
+
+# quantization ops
+# awq
+def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
+                   zeros: torch.Tensor, split_k_iters: int, thx: int,
+                   thy: int) -> torch.Tensor:
+    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
+                                       thx, thy)
+
+
+def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
+             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
+    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+
+
+# gptq
+def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
+              b_g_idx: torch.Tensor, use_exllama: bool,
+              bit: int) -> torch.Tensor:
+    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
+                                  b_g_idx, use_exllama, bit)
+
+
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
+                 bit: int) -> None:
+    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
+
+
+# squeezellm
+def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
+                    lookup_table: torch.Tensor) -> None:
+    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
+
+
+# marlin
+def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
+                size_n: int, size_k: int) -> torch.Tensor:
+    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
+                                    size_n, size_k)
+
+
+# marlin_24
+def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                        b_meta: torch.Tensor, b_scales: torch.Tensor,
+                        workspace: torch.Tensor, b_q_type: ScalarType,
+                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
+    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
+                                            workspace, b_q_type, size_m,
+                                            size_n, size_k)
+
+
+# fp8 marlin
+def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                    b_scales: torch.Tensor, workspace: torch.Tensor,
+                    num_bits: int, size_m: int, size_n: int,
+                    size_k: int) -> torch.Tensor:
+    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
+                                        num_bits, size_m, size_n, size_k)
+
+
+# cutlass
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
+def cutlass_scaled_mm(a: torch.Tensor,
+                      b: torch.Tensor,
+                      scale_a: torch.Tensor,
+                      scale_b: torch.Tensor,
+                      out_dtype: Type[torch.dtype],
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+
+    m = a.shape[0]
+    n = b.shape[1]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+
+    return out
+
+
+# aqlm
+def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
+              codebooks: torch.Tensor, scales: torch.Tensor,
+              codebook_partition_sizes: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
+                                  codebook_partition_sizes, bias)
+
+
+def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
+                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+    return torch.ops._C.aqlm_dequant(codes, codebooks,
+                                     codebook_partition_sizes)
+
+
+# gptq_marlin
+def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
+                       size_k: int, size_n: int,
+                       num_bits: int) -> torch.Tensor:
+    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
+                                           num_bits)
+
+
+def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
+                      num_bits: int) -> torch.Tensor:
+    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
+
+
+def gptq_marlin_gemm(a: torch.Tensor,
+                     b_q_weight: torch.Tensor,
+                     b_scales: torch.Tensor,
+                     b_zeros: torch.Tensor,
+                     g_idx: torch.Tensor,
+                     perm: torch.Tensor,
+                     workspace: torch.Tensor,
+                     b_q_type: ScalarType,
+                     size_m: int,
+                     size_n: int,
+                     size_k: int,
+                     is_k_full: bool,
+                     has_zp: bool = False,
+                     use_fp32_reduce: bool = False) -> torch.Tensor:
+    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
+                                         g_idx, perm, workspace, b_q_type,
+                                         size_m, size_n, size_k, is_k_full,
+                                         has_zp, use_fp32_reduce)
+
+
+# fp8
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: If you
+    provide the scale, it will use static scaling and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensors for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        num_token_padding: If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per-tensor or per-token
+            quantization in the dynamic quantization case.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+    """
+    # This code assumes batch_dim and num_tokens are flattened
+    assert (input.ndim == 2)
+    shape = input.shape
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+
+    if scale is None:
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale, scale_ub)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        # num_token_padding not implemented for this case
+        assert (scale.numel() == 1 or num_token_padding is None)
+        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
+
+    return output, scale
+
+
+# int8
+def scaled_int8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize the input tensor to int8 and return the quantized tensor and scale.
+
+    Args:
+        input: The input tensor to be quantized to int8.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
+    """
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        torch.ops._C.static_scaled_int8_quant(output, input, scale)
+        return output, scale
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
+    return output, input_scales
+
+
+# quip#
+def quip_gemv(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    CB: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops._C.quip_gemv(A, B, CB)
+
+
+def quip_decompress(
+    YIs: torch.Tensor,
+    CB: torch.Tensor,
+    Y: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops._C.quip_decompress(YIs, CB, Y)
+
+
+# qqq ops
+def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                    s_tok: torch.Tensor, s_ch: torch.Tensor,
+                    s_group: torch.Tensor, workspace: torch.Tensor,
+                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
+    return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
+                                        workspace, size_m, size_n, size_k)
+
+
+# gguf
+def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
+    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
+
+
+def ggml_mul_mat_vec(
+    W: torch.Tensor,
+    X: torch.Tensor,
+    quant_type: int,
+    row: int,
+):
+    return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
+
+
+def ggml_mul_mat_vec_a8(
+    W: torch.Tensor,
+    X: torch.Tensor,
+    quant_type: int,
+    row: int,
+):
+    return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
+
+
+def ggml_mul_mat_a8(
+    W: torch.Tensor,
+    X: torch.Tensor,
+    quant_type: int,
+    row: int,
+):
+    return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
+
+
+# mamba
+def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
+                      bias_: Optional[torch.Tensor],
+                      seq_idx_: Optional[torch.Tensor],
+                      initial_states_: Optional[torch.Tensor],
+                      final_states_out_: Optional[torch.Tensor],
+                      silu_activation: bool) -> torch.Tensor:
+    return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, None,
+                                          initial_states_, final_states_out_,
+                                          silu_activation)
+
+
+def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
+                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
+                         silu_activation: bool) -> torch.Tensor:
+    return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
+                                             silu_activation)
+
+
+def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
+                       B: torch.Tensor, C: torch.Tensor,
+                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
+                       delta_bias_: Optional[torch.Tensor],
+                       delta_softplus: bool, index_: Optional[torch.Tensor],
+                       x: Optional[torch.Tensor]) -> List[torch.Tensor]:
+    return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
+                                           delta_bias_, delta_softplus, index_,
+                                           x)
+
+
+# moe
+def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
+                         block_size: int, sorted_token_ids: torch.Tensor,
+                         experts_ids: torch.Tensor,
+                         num_tokens_post_pad: torch.Tensor) -> None:
+    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
+                                      sorted_token_ids, experts_ids,
+                                      num_tokens_post_pad)
+
+
+def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                 token_expert_indicies: torch.Tensor,
+                 gating_output: float) -> None:
+    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
+                                  token_expert_indicies, gating_output)
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
+                                             value_cache, slot_mapping,
+                                             kv_cache_dtype, k_scale, v_scale)
+
+
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
+                                                   value_cache, slot_mapping,
+                                                   kv_cache_dtype, k_scale,
+                                                   v_scale)
+
+
+def copy_blocks(key_caches: List[torch.Tensor],
+                value_caches: List[torch.Tensor],
+                block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
+                block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(output: torch.Tensor,
+                input: torch.Tensor,
+                scale: float = 1.0,
+                kv_dtype: str = "fp8") -> None:
+    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+def get_device_attribute(attribute: int, device: int) -> int:
+    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
+
+
+def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
+    # ruff: noqa: E501
+    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
+        device)
+
+
+# custom ar
+def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
+                   handles: List[str], offsets: List[int], rank: int,
+                   full_nvlink: bool) -> int:
+    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
+                                                 offsets, rank, full_nvlink)
+
+
+def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
+                     full_nvlink: bool) -> bool:
+    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
+                                                   full_nvlink)
+
+
+def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
+
+
+def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
+                     out: torch.Tensor) -> None:
+    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
+
+
+def dispose(fa: int) -> None:
+    torch.ops._C_custom_ar.dispose(fa)
+
+
+def meta_size() -> int:
+    return torch.ops._C_custom_ar.meta_size()
+
+
+def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
+                    offsets: List[int]) -> None:
+    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
+
+
+def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
+    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+
+def register_graph_buffers(fa: int, handles: List[str],
+                           offsets: List[List[int]]) -> None:
+    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+
+# TODO: remove this later
+names_and_values = globals()
+names_and_values_to_update = {}
+# prepare variables to avoid dict size change during iteration
+k, v, arg = None, None, None
+fn_type = type(lambda x: x)
+for k, v in names_and_values.items():
+    # find functions that are defined in this file and have torch.Tensor
+    # in their annotations. `arg == "torch.Tensor"` is used to handle
+    # the case when users use `import __annotations__` to turn type
+    # hints into strings.
+    if isinstance(v, fn_type) \
+        and v.__code__.co_filename == __file__ \
+        and any(arg is torch.Tensor or arg == "torch.Tensor"
+                   for arg in v.__annotations__.values()):
+        names_and_values_to_update[k] = hint_on_error(v)
+
+names_and_values.update(names_and_values_to_update)
+del names_and_values_to_update, names_and_values, v, k, fn_type
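
For intuition, the dynamic per-tensor path of scaled_fp8_quant above amounts to scaling by the tensor's absolute maximum over the FP8 range; a rough pure-PyTorch sketch (the fused kernel may differ in rounding and saturation details):

    import torch

    def scaled_fp8_quant_reference(x: torch.Tensor):
        # Dynamic per-tensor scale: map the observed abs-max onto the FP8 max.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = x.abs().max().clamp(min=1e-12) / finfo.max
        q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return q, scale

    x = torch.randn(4, 16)
    q, scale = scaled_fp8_quant_reference(x)
    dequantized = q.to(torch.float32) * scale  # approximate round trip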

+ 241 - 0
aphrodite/_ipex_ops.py

@@ -0,0 +1,241 @@
+from typing import List, Optional, Tuple
+
+import torch
+from loguru import logger
+
+try:
+    import intel_extension_for_pytorch as ipex
+except ImportError as e:
+    logger.warning(f"Import error msg: {e.msg}")
+
+
+class ipex_ops:
+
+    @staticmethod
+    def _reshape_activation_tensor(
+            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        num = x.size(0)
+        d = x.size(1) // 2
+        x = x.reshape(num, 2, d)
+        x1, x2 = torch.chunk(x, chunks=2, dim=1)
+        x1 = x1.reshape(num, d)
+        x2 = x2.reshape(num, d)
+        return x1, x2
+
+    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+        x1, x2 = ipex_ops._reshape_activation_tensor(x)
+        ipex.llm.functional.silu_mul(x1, x2, out)
+
+    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+        x1, x2 = ipex_ops._reshape_activation_tensor(x)
+        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+
+    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+        x1, x2 = ipex_ops._reshape_activation_tensor(x)
+        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+
+    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
+        out.copy_(torch.nn.functional.gelu(x))
+
+    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
+        out.copy_(torch.nn.functional.gelu(x))
+
+    def paged_attention_v1(
+        out: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        num_kv_heads: int,
+        scale: float,
+        block_tables: torch.Tensor,
+        context_lens: torch.Tensor,
+        block_size: int,
+        max_context_len: int,
+        alibi_slopes: Optional[torch.Tensor],
+        kv_cache_dtype: str,
+        k_scale: float,
+        v_scale: float,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
+    ) -> None:
+        assert kv_cache_dtype == "auto"
+        num_heads = out.size(1)
+        num_queries_per_tokens = num_heads // num_kv_heads
+        head_mapping = torch.arange(
+            0,
+            num_kv_heads,
+            device=query.device,
+            dtype=torch.int32,
+        ).view(num_kv_heads,
+               1).repeat_interleave(num_queries_per_tokens).flatten()
+        # todo: ipex will refactor namespace
+        torch.xpu.paged_attention_v1(out, query.contiguous(),
+                                     key_cache.view_as(value_cache),
+                                     value_cache, head_mapping, scale,
+                                     block_tables, context_lens, block_size,
+                                     max_context_len, alibi_slopes)
+
+    def paged_attention_v2(
+        out: torch.Tensor,
+        exp_sum: torch.Tensor,
+        max_logits: torch.Tensor,
+        tmp_out: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        num_kv_heads: int,
+        scale: float,
+        block_tables: torch.Tensor,
+        context_lens: torch.Tensor,
+        block_size: int,
+        max_context_len: int,
+        alibi_slopes: Optional[torch.Tensor],
+        kv_cache_dtype: str,
+        k_scale: float,
+        v_scale: float,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
+    ) -> None:
+        assert kv_cache_dtype == "auto"
+        num_heads = out.size(1)
+        num_queries_per_tokens = num_heads // num_kv_heads
+        head_mapping = torch.arange(
+            0,
+            num_kv_heads,
+            dtype=torch.int32,
+            device=query.device,
+        ).view(num_kv_heads,
+               1).repeat_interleave(num_queries_per_tokens).flatten()
+        # todo: ipex will refactor namespace
+        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
+                                     query.contiguous(),
+                                     key_cache.view_as(value_cache),
+                                     value_cache, head_mapping, block_tables,
+                                     context_lens, scale, block_size,
+                                     max_context_len, alibi_slopes)
+
+    def rotary_embedding(
+        positions: torch.Tensor,  # [batch_size, seq_len]
+        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
+        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
+        head_size: int,
+        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
+        is_neox: bool,
+    ) -> None:
+        if positions.dim() == 1:
+            positions = positions.unsqueeze(0)
+            query = query.unsqueeze(0)
+            key = key.unsqueeze(0)
+
+        rotary_dim = cos_sin_cache.size(1)
+        query = query.view(*query.shape[:-1], -1, head_size)
+        key = key.view(*key.shape[:-1], -1, head_size)
+
+        query_rot = query[..., :rotary_dim]
+        key_rot = key[..., :rotary_dim]
+
+        cos_sin = cos_sin_cache[positions.long()]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+
+        if is_neox:
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
+                                             rotary_dim, is_neox, positions)
+
+    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+                                 key: torch.Tensor, head_size: int,
+                                 cos_sin_cache: torch.Tensor, is_neox: bool,
+                                 rot_dim: int,
+                                 cos_sin_cache_offsets: torch.Tensor) -> None:
+        if positions.dim() == 1:
+            positions = positions.unsqueeze(0)
+            query = query.unsqueeze(0)
+            key = key.unsqueeze(0)
+        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
+        rotary_dim = cos_sin_cache.size(1)
+        query = query.view(*query.shape[:-1], -1, head_size)
+        key = key.view(*key.shape[:-1], -1, head_size)
+
+        query_rot = query[..., :rotary_dim]
+        key_rot = key[..., :rotary_dim]
+
+        cos_sin = cos_sin_cache[torch.add(positions,
+                                          cos_sin_cache_offsets).long()]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+
+        if is_neox:
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
+                                             rotary_dim, is_neox, positions)
+
+    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> None:
+        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
+        out.copy_(tmp)
+
+    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
+                           weight: torch.Tensor, epsilon: float) -> None:
+        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
+                                               epsilon, True)
+        input.copy_(tmp)
+
+    def varlen_attention(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        seqlen_q: torch.Tensor,
+        seqlen_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        pdropout: float,
+        softmax_scale: float,
+        zero_tensors: bool,
+        is_causal: bool,
+        return_softmax: bool,
+        gen_: torch.Generator,
+    ) -> None:
+        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
+                                             seqlen_k, max_seqlen_q,
+                                             max_seqlen_k, pdropout,
+                                             softmax_scale, zero_tensors,
+                                             is_causal, return_softmax, gen_)
+
+    def reshape_and_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: float,
+        v_scale: float,
+    ) -> None:
+        assert kv_cache_dtype == "auto"
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slot_mapping)
+
+    @staticmethod
+    def copy_blocks(key_caches: List[torch.Tensor],
+                    value_caches: List[torch.Tensor],
+                    block_mapping: torch.Tensor) -> None:
+        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
+
+    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
+                    block_mapping: torch.Tensor) -> None:
+        torch.xpu.swap_blocks(src, dst, block_mapping)

+ 0 - 0
aphrodite/adapter_commons/__init__.py


+ 14 - 0
aphrodite/adapter_commons/layers.py

@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import Tuple
+
+
+@dataclass
+class AdapterMapping:
+    # Per every token in input_ids:
+    index_mapping: Tuple[int, ...]
+    # Per sampled token:
+    prompt_mapping: Tuple[int, ...]
+
+    def __post_init__(self):
+        self.index_mapping = tuple(self.index_mapping)
+        self.prompt_mapping = tuple(self.prompt_mapping)
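
A small, hedged illustration of the two mappings (values are arbitrary): index_mapping holds one adapter index per input token, prompt_mapping one per sampled position, and __post_init__ normalizes both to tuples:

    from aphrodite.adapter_commons.layers import AdapterMapping

    mapping = AdapterMapping(index_mapping=[0, 0, 1, 1, 1],
                             prompt_mapping=[0, 1])
    assert mapping.index_mapping == (0, 0, 1, 1, 1)
    assert mapping.prompt_mapping == (0, 1)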

+ 102 - 0
aphrodite/adapter_commons/models.py

@@ -0,0 +1,102 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
+
+from loguru import logger
+from torch import nn
+
+from aphrodite.common.utils import LRUCache
+
+
+class AdapterModel(ABC):
+
+    def __init__(self, model_id=None):
+        self.id = model_id
+
+    @abstractmethod
+    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
+        # Common initialization code
+        # Load weights or embeddings from local checkpoint
+        raise NotImplementedError("Subclasses must implement this method.")
+
+
+T = TypeVar('T')
+
+
+class AdapterLRUCache(LRUCache[T]):
+
+    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
+                                                              None]):
+        super().__init__(capacity)
+        self.deactivate_fn = deactivate_fn
+
+    def _on_remove(self, key: Hashable, value: T):
+        logger.debug(f"Removing adapter int id: {key}")
+        self.deactivate_fn(key)
+        return super()._on_remove(key, value)
+
+
+class AdapterModelManager(ABC):
+
+    def __init__(
+        self,
+        model: nn.Module,
+    ):
+        """Create a AdapterModelManager and adapter for a given model.
+        Args:
+            model: the model to be adapted.
+        """
+        self.model: nn.Module = model
+        self._registered_adapters: Dict[int, Any] = {}
+        # Dict instead of a Set for compatibility with LRUCache.
+        self._active_adapters: Dict[int, None] = {}
+        self.adapter_type = 'Adapter'
+        self._last_mapping = None
+
+    def __len__(self) -> int:
+        return len(self._registered_adapters)
+
+    @property
+    @abstractmethod
+    def adapter_slots(self):
+        ...
+
+    @property
+    @abstractmethod
+    def capacity(self):
+        ...
+
+    @abstractmethod
+    def activate_adapter(self, adapter_id: int) -> bool:
+        ...
+
+    @abstractmethod
+    def deactivate_adapter(self, adapter_id: int) -> bool:
+        ...
+
+    @abstractmethod
+    def add_adapter(self, adapter: Any) -> bool:
+        ...
+
+    @abstractmethod
+    def set_adapter_mapping(self, mapping: Any) -> None:
+        ...
+
+    @abstractmethod
+    def remove_adapter(self, adapter_id: int) -> bool:
+        ...
+
+    @abstractmethod
+    def remove_all_adapters(self):
+        ...
+
+    @abstractmethod
+    def get_adapter(self, adapter_id: int) -> Optional[Any]:
+        ...
+
+    @abstractmethod
+    def list_adapters(self) -> Dict[int, Any]:
+        ...
+
+    @abstractmethod
+    def pin_adapter(self, adapter_id: int) -> bool:
+        ...

+ 25 - 0
aphrodite/adapter_commons/request.py

@@ -0,0 +1,25 @@
+from abc import abstractmethod
+from dataclasses import dataclass
+
+
+@dataclass
+class AdapterRequest:
+    """
+    Base class for adapter requests.
+    """
+
+    @property
+    @abstractmethod
+    def adapter_id(self):
+        ...
+
+    def __post_init__(self):
+        if self.adapter_id < 1:
+            raise ValueError(f"id must be > 0, got {self.adapter_id}")
+
+    def __eq__(self, value: object) -> bool:
+        return isinstance(
+            value, self.__class__) and self.adapter_id == value.adapter_id
+
+    def __hash__(self) -> int:
+        return hash(self.adapter_id)

+ 90 - 0
aphrodite/adapter_commons/utils.py

@@ -0,0 +1,90 @@
+from typing import Any, Callable, Dict, Optional, Set
+
+
+## model functions
+def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
+                       deactivate_func: Callable) -> bool:
+    if adapter_id in active_adapters:
+        deactivate_func(adapter_id)
+        active_adapters.pop(adapter_id)
+        return True
+    return False
+
+
+def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
+                capacity: int, add_func: Callable) -> bool:
+    if adapter.id not in registered_adapters:
+        if len(registered_adapters) >= capacity:
+            raise RuntimeError('No free adapter slots.')
+        add_func(adapter)
+        registered_adapters[adapter.id] = adapter
+        return True
+    return False
+
+
+def set_adapter_mapping(mapping: Any, last_mapping: Any,
+                        set_mapping_func: Callable) -> Any:
+    if last_mapping != mapping:
+        set_mapping_func(mapping)
+        return mapping
+    return last_mapping
+
+
+def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
+                   deactivate_func: Callable) -> bool:
+    deactivate_func(adapter_id)
+    return bool(registered_adapters.pop(adapter_id, None))
+
+
+def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
+    return dict(registered_adapters)
+
+
+def get_adapter(adapter_id: int,
+                registered_adapters: Dict[int, Any]) -> Optional[Any]:
+    return registered_adapters.get(adapter_id, None)
+
+
+## worker functions
+def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
+                               apply_adapters_func,
+                               set_adapter_mapping_func) -> None:
+    apply_adapters_func(requests)
+    set_adapter_mapping_func(mapping)
+
+
+def add_adapter_worker(adapter_request: Any, list_adapters_func,
+                       load_adapter_func, add_adapter_func,
+                       activate_adapter_func) -> bool:
+    if adapter_request.adapter_id in list_adapters_func():
+        return False
+    loaded_adapter = load_adapter_func(adapter_request)
+    loaded = add_adapter_func(loaded_adapter)
+    activate_adapter_func(loaded_adapter.id)
+    return loaded
+
+
+def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
+                          adapter_slots: int, remove_adapter_func,
+                          add_adapter_func) -> None:
+    models_that_exist = list_adapters_func()
+    models_map = {
+        adapter_request.adapter_id: adapter_request
+        for adapter_request in adapter_requests if adapter_request
+    }
+    if len(models_map) > adapter_slots:
+        raise RuntimeError(
+            f"Number of requested models ({len(models_map)}) is greater "
+            f"than the number of GPU model slots "
+            f"({adapter_slots}).")
+    new_models = set(models_map)
+    models_to_add = new_models - models_that_exist
+    models_to_remove = models_that_exist - new_models
+    for adapter_id in models_to_remove:
+        remove_adapter_func(adapter_id)
+    for adapter_id in models_to_add:
+        add_adapter_func(models_map[adapter_id])
+
+
+def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
+    return set(adapter_manager_list_adapters_func())
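
A hedged sketch of the reconciliation that apply_adapters_worker performs, driven with toy callbacks (the Req dataclass and the registry set are stand-ins, not engine objects):

    from dataclasses import dataclass

    from aphrodite.adapter_commons.utils import apply_adapters_worker

    @dataclass(frozen=True)
    class Req:
        adapter_id: int

    registry = {1, 3}  # adapter ids currently loaded on this worker
    apply_adapters_worker(
        adapter_requests={Req(1), Req(2)},
        list_adapters_func=lambda: set(registry),
        adapter_slots=8,
        remove_adapter_func=registry.discard,                       # drops 3
        add_adapter_func=lambda req: registry.add(req.adapter_id),  # adds 2
    )
    assert registry == {1, 2}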

+ 36 - 0
aphrodite/adapter_commons/worker_manager.py

@@ -0,0 +1,36 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Set
+
+import torch
+
+
+class AbstractWorkerManager(ABC):
+
+    def __init__(self, device: torch.device):
+        self.device = device
+
+    @property
+    @abstractmethod
+    def is_enabled(self) -> bool:
+        ...
+
+    @abstractmethod
+    def set_active_adapters(self, requests: Set[Any],
+                            mapping: Optional[Any]) -> None:
+        ...
+
+    @abstractmethod
+    def add_adapter(self, adapter_request: Any) -> bool:
+        ...
+
+    @abstractmethod
+    def remove_adapter(self, adapter_id: int) -> bool:
+        ...
+
+    @abstractmethod
+    def remove_all_adapters(self):
+        ...
+
+    @abstractmethod
+    def list_adapters(self) -> Set[int]:
+        ...

+ 4 - 6
aphrodite/attention/__init__.py

@@ -1,15 +1,13 @@
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataBuilder)
 from aphrodite.attention.layer import Attention
 from aphrodite.attention.selector import get_attn_backend
 
 __all__ = [
     "AttentionBackend",
     "AttentionMetadata",
+    "AttentionMetadataBuilder",
     "Attention",
     "get_attn_backend",
-    "AttentionMetadataPerStage",
 ]

+ 87 - 39
aphrodite/attention/backends/abstract.py

@@ -1,13 +1,30 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
-from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
+from enum import Enum, auto
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
+                    Tuple, Type, TypeVar)
 
 import torch
 
+if TYPE_CHECKING:
+    from aphrodite.task_handler.model_runner_base import (
+        ModelRunnerInputBuilderBase)
+
+
+class AttentionType(Enum):
+    DECODER = auto()  # Decoder attention between previous layer Q/K/V
+    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
+    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V
+
 
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
 
+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
@@ -15,9 +32,23 @@ class AttentionBackend(ABC):
 
     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        raise NotImplementedError
+
+    @classmethod
+    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
+        return cls.get_metadata_cls()(*args, **kwargs)
+
+    @staticmethod
+    @abstractmethod
+    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
+    @classmethod
+    def make_metadata_builder(cls, *args,
+                              **kwargs) -> "AttentionMetadataBuilder":
+        return cls.get_builder_cls()(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
@@ -33,7 +64,7 @@ class AttentionBackend(ABC):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
@@ -41,30 +72,13 @@ class AttentionBackend(ABC):
     @abstractmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
 
 @dataclass
-class AttentionMetadataPerStage:
-    """Attention metadata for a specific stage. I.e., prefill or decode."""
-
-    def asdict_zerocopy(self) -> Dict[str, Any]:
-        """Similar to dataclasses.asdict, but avoids deepcopying."""
-        # Note that if we add dataclasses as fields, they will need
-        # similar handling.
-        return {
-            field.name: getattr(self, field.name)
-            for field in fields(self)
-        }
-
-
-T = TypeVar("T", bound=AttentionMetadataPerStage)
-
-
-@dataclass
-class AttentionMetadata(Generic[T]):
+class AttentionMetadata:
     """Attention metadata for prefill and decode batched together."""
     # Total number of prefill requests.
     num_prefills: int
@@ -73,29 +87,58 @@ class AttentionMetadata(Generic[T]):
     # Number of decode tokens. Note that it is equivalent to the number of
     # decode requests.
     num_decode_tokens: int
-    # The attention metadata for prefill requests in a batch.
-    # None if there's no prefill requests in a batch.
-    prefill_metadata: Optional[T]
-    # The attention metadata for decode requests in a batch.
-    # None if there's no decode requests in a batch.
-    decode_metadata: Optional[T]
     # (num_tokens,). The indices of the token slots that input tokens will be
     # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
     # is 16, the three tokens are stored at offset 3 in block 2, offset 2 in
     # block 0, and offset 1 in block 1, respectively.
     slot_mapping: torch.Tensor
-    # The kv cache's data type.
-    kv_cache_dtype: str
 
-    def __post_init__(self):
-        if self.num_prefill_tokens > 0:
-            assert self.num_prefills > 0
-            assert self.prefill_metadata is not None
-        if self.num_decode_tokens > 0:
-            assert self.decode_metadata is not None
+    @property
+    @abstractmethod
+    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
+        """Return the attention metadata that's required to run prefill
+        attention."""
+        pass
+
+    @property
+    @abstractmethod
+    def decode_metadata(self) -> Optional["AttentionMetadata"]:
+        """Return the attention metadata that's required to run decode
+        attention."""
+        pass
+
+    def asdict_zerocopy(self,
+                        skip_fields: Optional[Set[str]] = None
+                        ) -> Dict[str, Any]:
+        """Similar to dataclasses.asdict, but avoids deepcopying."""
+        if skip_fields is None:
+            skip_fields = set()
+        # Note that if we add dataclasses as fields, they will need
+        # similar handling.
+        return {
+            field.name: getattr(self, field.name)
+            for field in fields(self) if field.name not in skip_fields
+        }
+
+
+T = TypeVar("T", bound=AttentionMetadata)
+
+
+class AttentionMetadataBuilder(ABC, Generic[T]):
+    """Abstract class for attention metadata builders."""
+
+    @abstractmethod
+    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int) -> T:
+        """Build attention metadata with on-device tensors."""
+        raise NotImplementedError
 
 
-class AttentionImpl(ABC):
+class AttentionImpl(ABC, Generic[T]):
 
     @abstractmethod
     def __init__(
@@ -106,6 +149,9 @@ class AttentionImpl(ABC):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         raise NotImplementedError
 
@@ -116,7 +162,9 @@ class AttentionImpl(ABC):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
-        kv_scale: float,
+        attn_metadata: T,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         raise NotImplementedError
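
For orientation, here is a minimal, hypothetical metadata subclass written against the reworked interface above (the `ToyMetadata` name is illustrative, not part of this diff). With `AttentionMetadataPerStage` removed, a backend defines one batched metadata dataclass and derives the prefill and decode views through the now-abstract `prefill_metadata` / `decode_metadata` properties; the real backends in the following diffs additionally cache these views.

from dataclasses import dataclass
from typing import Optional

from aphrodite.attention.backends.abstract import AttentionMetadata

@dataclass
class ToyMetadata(AttentionMetadata):
    # No extra fields: the counters and slot_mapping inherited from
    # AttentionMetadata are enough to slice the batched view.

    @property
    def prefill_metadata(self) -> Optional["ToyMetadata"]:
        if self.num_prefills == 0:
            return None
        return ToyMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens])

    @property
    def decode_metadata(self) -> Optional["ToyMetadata"]:
        if self.num_decode_tokens == 0:
            return None
        return ToyMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:])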

+ 429 - 0
aphrodite/attention/backends/blocksparse_attn.py

@@ -0,0 +1,429 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionType)
+from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.ops.blocksparse_attention.interface import (
+    LocalStridedBlockSparseAttn, get_head_sliding_step)
+from aphrodite.attention.ops.paged_attn import PagedAttention
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+
+
+@dataclass
+class BlocksparseParams:
+    max_seqlen: int
+
+    # Num q heads per tensor-parallel rank/partition
+    num_heads: int  # per TP partition
+    # Num kv heads per tensor-parallel rank/partition
+    num_kv_heads: int
+
+    # block size used for blocksparse attention.
+    # This is the block_size used in `local_blocks`, `vert_stride`.
+    block_size: int
+
+    # Number of blocks for local attention, i.e., the number of
+    # locally attended tokens / `sparse_block_size`.
+    local_blocks: int
+
+    # Attend to one block per every `vert_stride` blocks,
+    # controlling the sparsity.
+    vert_stride: int
+    """
+    Whether to use the same vertical stride offset for all heads,
+    i.e., attend to the same block of tokens on all heads.
+    By default, it is False, i.e., attention on the non-local
+    blocks depends on the `head_idx`, that is, on blocks satisfying
+    `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
+    where `head_sliding_step = max(1, int(vert_stride / num_total_heads))`
+    and `block_idx = position_id // sparse_block_size`.
+    See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
+    for more detail.
+    """
+    homo_head: bool = False
+
+    # Whether, within a group, the kv offsets that each q head attends to
+    # are the same.
+    homo_head_group: bool = False
+
+    # Decided by homo_head and homo_head_group.
+    head_sliding_step: int = field(init=False)
+
+    # Range of q heads for the current TP rank.
+    active_head_range: Tuple = field(init=False)
+
+    def __post_init__(self):
+        assert self.block_size > 0
+        assert self.local_blocks >= 0
+        assert self.vert_stride >= 1
+        assert self.num_heads % self.num_kv_heads == 0
+
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        total_heads = tp_size * self.num_heads
+        total_kv_heads = tp_size * self.num_kv_heads
+
+        if self.homo_head:
+            self.head_sliding_step = 0
+        elif self.homo_head_group:
+            head_sliding_step = get_head_sliding_step(total_kv_heads,
+                                                      self.vert_stride)
+            # negative indicates sliding along kv heads, i.e., homo q group
+            self.head_sliding_step = -head_sliding_step
+        else:
+            self.head_sliding_step = get_head_sliding_step(
+                total_heads, self.vert_stride)
+
+        self.active_head_range = (
+            tp_rank * self.num_heads,
+            (tp_rank + 1) * self.num_heads,
+        )
+
+
+class BlocksparseFlashAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
+        return BlocksparseFlashAttentionImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return BlocksparseFlashAttentionMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
+        return BlocksparseFlashAttentionMetadataBuilder
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class BlocksparseFlashAttentionMetadata(AttentionMetadata):
+    """A copy of Metadata for FlashAttentionBackend,
+    to avoid having to install flash_attn.
+    NOTE: Any python object stored here is not updated when it is
+    cuda-graph replayed. If you have values that need to be changed
+    dynamically, it should be stored in tensor. The tensor has to be
+    updated from `CUDAGraphRunner.forward` API.
+    """
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens. None if it is a decoding.
+    seq_lens: Optional[List[int]]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ----------------------|
+    #                                   |-- query_len ---|
+
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor]
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    _cached_prefill_metadata: Optional[
+        "BlocksparseFlashAttentionMetadata"] = None
+    _cached_decode_metadata: Optional[
+        "BlocksparseFlashAttentionMetadata"] = None
+
+    @property
+    def prefill_metadata(
+            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None
+        assert self.query_start_loc is not None
+        assert self.context_lens_tensor is not None
+        assert self.block_tables is not None
+        assert self.seq_start_loc is not None
+
+        self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            seq_lens=self.seq_lens[:self.num_prefills],
+            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            max_query_len=self.max_query_len,
+            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_seq_len=0,
+            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
+            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
+            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            block_tables=self.block_tables[:self.num_prefills],
+            use_cuda_graph=False,
+        )
+        return self._cached_prefill_metadata
+
+    @property
+    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        if self._cached_decode_metadata is not None:
+            return self._cached_decode_metadata
+        assert self.block_tables is not None
+        assert self.seq_lens_tensor is not None
+
+        self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=self.num_decode_tokens,
+            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            seq_lens=None,
+            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.max_decode_seq_len,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self.block_tables[self.num_prefills:],
+            use_cuda_graph=self.use_cuda_graph,
+        )
+        return self._cached_decode_metadata
+
+
+class BlocksparseFlashAttentionMetadataBuilder(
+        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
+
+    _metadata_cls = BlocksparseFlashAttentionMetadata
+
+
+class BlocksparseFlashAttentionImpl(AttentionImpl):
+    """
+    If the input tensors contain prompt tokens, the layout is as follows:
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+    Otherwise, the layout is as follows:
+    |<------------------ num_generation_tokens (M) ----------------->|
+    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+    Generation tokens can contain padding when cuda-graph is used.
+    Currently, prompt tokens don't contain any padding.
+    The prompts might have different lengths, while the generation tokens
+    always have length 1.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+    ) -> None:
+        assert blocksparse_params is not None
+        assert alibi_slopes is None, ValueError(
+            "Alibi is not supported for blocksparse flash attention.")
+        assert sliding_window is None, ValueError(
+            "sliding_window is invalid for blocksparse attention.")
+        assert logits_soft_cap is None, ValueError(
+            "logits_soft_cap is invalid for blocksparse attention.")
+
+        if "num_heads" not in blocksparse_params:
+            blocksparse_params["num_heads"] = num_heads
+        if "num_kv_heads" not in blocksparse_params:
+            blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
+        self.blocksparse_params = BlocksparseParams(**blocksparse_params)
+        self.kv_cache_dtype = kv_cache_dtype
+
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.alibi_slopes = alibi_slopes
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        self.local_blocks = self.blocksparse_params.local_blocks
+        self.vert_stride = self.blocksparse_params.vert_stride
+        self.sparse_block_size = self.blocksparse_params.block_size
+        self.head_sliding_step = self.blocksparse_params.head_sliding_step
+
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        total_num_heads = num_heads * self.tp_size
+        self.bs_attn = LocalStridedBlockSparseAttn(
+            total_num_heads,
+            self.blocksparse_params.max_seqlen,
+            self.blocksparse_params.local_blocks,
+            self.blocksparse_params.vert_stride,
+            self.blocksparse_params.block_size,
+            homo_head=self.blocksparse_params.homo_head,
+            active_head_range=self.blocksparse_params.active_head_range,
+        )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: BlocksparseFlashAttentionMetadata,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "BlocksparseFlashAttentionImpl")
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+
+            # Prompt run.
+            # normal attention
+            # When block_tables are not filled, it means q and k are the
+            # prompt, and they have the same length.
+
+            assert kv_cache is None \
+                    or prefill_meta.block_tables is None \
+                    or prefill_meta.block_tables.numel() == 0, \
+                "Does not support prefix-enabled attention."
+
+            output = self.bs_attn(
+                q=query,
+                k=key,
+                v=value,
+                cu_seqlens_q=prefill_meta.seq_start_loc,
+                cu_seqlens_k=prefill_meta.seq_start_loc,
+                sm_scale=self.scale,
+            )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output = PagedAttention.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                decode_meta.block_tables,
+                decode_meta.seq_lens_tensor,
+                self.blocksparse_params.max_seqlen,
+                self.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                k_scale,
+                v_scale,
+                tp_rank=self.tp_rank,
+                blocksparse_local_blocks=self.local_blocks,
+                blocksparse_vert_stride=self.vert_stride,
+                blocksparse_block_size=self.sparse_block_size,
+                blocksparse_head_sliding_step=self.head_sliding_step,
+            )
+
+        # Reshape the output tensor.
+        return output.view(num_tokens, hidden_size)
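
To make the sparsity pattern documented in `BlocksparseParams` concrete, here is a small standalone sketch of the selection rule from its docstring (an illustration of the formula only, not code taken from `ops/blocksparse_attention`): each query block attends to its `local_blocks` most recent blocks plus every earlier block whose index satisfies `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`, with `head_sliding_step = max(1, int(vert_stride / num_total_heads))` when heads are not homogeneous.

from typing import List

def attended_blocks(query_block_idx: int, head_idx: int, *, local_blocks: int,
                    vert_stride: int, num_total_heads: int,
                    homo_head: bool = False) -> List[int]:
    """Illustrative only: causal block indices one query block attends to
    under the local + vertical-stride pattern described above."""
    head_sliding_step = (0 if homo_head else
                         max(1, int(vert_stride / num_total_heads)))
    blocks = []
    for block_idx in range(query_block_idx + 1):  # causal: current and past
        is_local = block_idx > query_block_idx - local_blocks
        is_vertical = ((block_idx + head_idx * head_sliding_step + 1)
                       % vert_stride == 0)
        if is_local or is_vertical:
            blocks.append(block_idx)
    return blocks

# e.g. attended_blocks(10, 0, local_blocks=4, vert_stride=4, num_total_heads=8)
# -> [3, 7, 8, 9, 10]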

+ 397 - 97
aphrodite/attention/backends/flash_attn.py

@@ -1,35 +1,48 @@
-"""Attention layer with Flash and PagedAttention.
-NOTE: At the moment, this file includes a lot of duplicated code from
-XFormers backend. The duplicated code will be removed once we use flash-attn or
-flashinfer for all the attention operations.
-"""
+"""Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
-from flash_attn import flash_attn_varlen_func
+from aphrodite_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)
 
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from aphrodite import _custom_ops as ops
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataBuilder,
+                                                   AttentionType)
+from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
+                                                compute_slot_mapping,
+                                                compute_slot_mapping_start_idx,
+                                                is_block_tables_empty)
+from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
+
+if TYPE_CHECKING:
+    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
 
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
-        return FlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashAttentionMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
+        return FlashAttentionMetadataBuilder
 
     @staticmethod
     def get_kv_cache_shape(
@@ -38,89 +51,342 @@ class FlashAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+
+        src_value_cache = src_kv_cache[1]
+        dst_value_cache = dst_kv_cache[1]
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
 
 @dataclass
-class FlashAttentionMetadata(AttentionMetadataPerStage,
-                             PagedAttentionMetadata):
+class FlashAttentionMetadata(AttentionMetadata):
     """Metadata for FlashAttentionBackend.
+
     NOTE: Any python object stored here is not updated when it is
     cuda-graph replayed. If you have values that need to be changed
     dynamically, it should be stored in tensor. The tensor has to be
     updated from `CUDAGraphRunner.forward` API.
     """
-    # Currently, input sequences can only contain all prompts
-    # or all decoding. True if all sequences are prompts.
-    is_prompt: bool
-    # (batch_size,). The prompt length per sequence. None if it is a decoding.
-    prompt_lens: Optional[List[int]]
-    # prompt_lens stored as a tensor.
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    # NOTE: Definition of context_len, subquery_len, and seqlen.
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens. None if it is a decoding.
+    seq_lens: Optional[List[int]]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE: Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
     # |- tokenA -|......................|-- newTokens ---|
     # |---------- context_len ----------|
-    # |-------------------- seqlen ----------------------|
-    #                                   |- subquery_len -|
-
-    # WARNING: context_len has different definition depending on if it is
-    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
-    # When it is for decoding, it includes a new token.
+    # |-------------------- seq_len ----------------------|
+    #                                   |-- query_len ---|
 
-    # Maximum subquery length in the batch.
-    max_subquery_len: Optional[int]
-    # Maximum prompt length in the batch.
-    max_prompt_len: Optional[int]
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
     # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
+    query_start_loc: Optional[torch.Tensor]
     # (batch_size + 1,). The cumulative sequence lengths of the sequences in
     # the batch, used to index into sequence. E.g., if the sequence length is
     # [4, 6], it is [0, 4, 10].
     seq_start_loc: Optional[torch.Tensor]
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
 
     # Whether or not cuda graph is enabled.
     # Cuda-graph is currently enabled for decoding only.
     # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
+    _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
+    _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
+
+    @property
+    def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None
+        assert self.query_start_loc is not None
+        assert self.context_lens_tensor is not None
+        assert self.block_tables is not None
+        assert self.seq_start_loc is not None
+
+        self._cached_prefill_metadata = FlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            seq_lens=self.seq_lens[:self.num_prefills],
+            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            max_query_len=self.max_query_len,
+            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_seq_len=0,
+            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
+            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
+            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            block_tables=self.block_tables[:self.num_prefills],
+            use_cuda_graph=False,
+        )
+        return self._cached_prefill_metadata
+
+    @property
+    def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        if self._cached_decode_metadata is not None:
+            return self._cached_decode_metadata
+        assert self.block_tables is not None
+        assert self.seq_lens_tensor is not None
+
+        self._cached_decode_metadata = FlashAttentionMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=self.num_decode_tokens,
+            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            seq_lens=None,
+            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.max_decode_seq_len,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self.block_tables[self.num_prefills:],
+            use_cuda_graph=self.use_cuda_graph,
+        )
+        return self._cached_decode_metadata
+
+
+class FlashAttentionMetadataBuilder(
+        AttentionMetadataBuilder[FlashAttentionMetadata]):
+
+    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.slot_mapping: List[int] = []
+        self.prefill_seq_lens: List[int] = []
+        self.context_lens: List[int] = []
+        self.block_tables: List[List[int]] = []
+        self.curr_seq_lens: List[int] = []
+        self.num_prefills = 0
+        self.num_prefill_tokens = 0
+        self.num_decode_tokens = 0
+        self.has_prefix_cache_hit = False
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+        self.use_v2_block_manager = (
+            input_builder.scheduler_config.use_v2_block_manager)
+
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
+        """Add a sequence group to the metadata. Specifically update/append
+        1. context length.
+        2. block table.
+        3. slot mapping.
+        """
+        is_prompt = inter_data.is_prompt
+        block_tables = inter_data.block_tables
+
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
+            self.context_lens.append(context_len)
+
+            if is_prompt:
+                self.num_prefills += 1
+                self.num_prefill_tokens += token_len
+                self.prefill_seq_lens.append(seq_len)
+            else:
+                assert query_len == 1, (
+                    "seq_len: {}, context_len: {}, query_len: {}".format(
+                        seq_len, context_len, query_len))
+                self.num_decode_tokens += query_len
+                self.curr_seq_lens.append(curr_seq_len)
+
+            # Compute block table.
+            # TODO: Combine chunked prefill and prefix caching by
+            # only allowing multiple of block_size chunk size.
+            # NOTE: This only works for oooooooxxx style attention.
+            block_table = []
+            if prefix_cache_hit:
+                # NOTE: For flash-attn, the block table should
+                # include the entries for the incoming prefill tokens.
+                block_table = block_tables[seq_id]
+            elif ((chunked_prefill_enabled or not is_prompt)
+                  and block_tables is not None):
+                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+            self.block_tables.append(block_table)
+
+            # Compute slot mapping.
+            is_profile_run = is_block_tables_empty(block_tables)
+            start_idx = compute_slot_mapping_start_idx(
+                is_prompt, query_len, context_len, self.sliding_window,
+                self.use_v2_block_manager)
+            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                 seq_len, context_len, start_idx,
+                                 self.block_size, inter_data.block_tables)
+
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
+        prefix_cache_hit = any([
+            inter_data.prefix_cache_hit
+            for inter_data in self.input_builder.inter_data_list
+        ])
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled,
+                                prefix_cache_hit)
+
+        device = self.runner.device
+        use_captured_graph = cuda_graph_pad_size != -1
+
+        max_query_len = max(query_lens)
+        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+        max_decode_seq_len = max(self.curr_seq_lens, default=0)
+        num_decode_tokens = self.num_decode_tokens
+
+        if use_captured_graph:
+            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+            self.block_tables.extend([] * cuda_graph_pad_size)
+            num_decode_tokens = batch_size
+
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = self.runner.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(self.block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device=device, non_blocking=True)
+        else:
+            block_tables = make_tensor_with_pad(
+                self.block_tables,
+                pad=0,
+                dtype=torch.int,
+                device=device,
+            )
+        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+
+        assert device is not None
+        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
+                                               device, self.runner.pin_memory)
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        return FlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            slot_mapping=slot_mapping_tensor,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_prefill_seq_len=max_prefill_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            seq_start_loc=seq_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
+            use_cuda_graph=use_captured_graph,
+        )
+
 
 class FlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
     |<--------------- num_prefill_tokens ----------------->|
     |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
+
     Otherwise, the layout is as follows:
     |<----------------- num_decode_tokens ------------------>|
     |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
+
     Generation tokens can contain padding when cuda-graph is used.
     Currently, prompt tokens don't contain any padding.
+
     The prompts might have different lengths, while the generation tokens
     always have length 1.
+
     If chunked prefill is enabled, prefill tokens and decode tokens can be
     batched together in a flattened 1D query.
+
     |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
     |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
+
     Currently, cuda graph is disabled for chunked prefill, meaning there's no
     padding between prefill and decode tokens.
     """
@@ -130,28 +396,45 @@ class FlashAttentionImpl(AttentionImpl):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError(
+                "FlashAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = ((sliding_window, sliding_window)
-                               if sliding_window is not None else (-1, -1))
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
+        self.sliding_window = ((sliding_window, sliding_window)
+                               if sliding_window is not None else (-1, -1))
+        self.kv_cache_dtype = kv_cache_dtype
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
+            logits_soft_cap = 0
+        self.logits_soft_cap = logits_soft_cap
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        if sliding_window is not None:
+            # NOTE: flash-attn's sliding window does not work with
+            # paged KV cache.
+            raise ValueError(
+                "Sliding window is not supported in FlashAttention.")
+
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
-                f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Head size {head_size} is not supported by FlashAttention. "
+                f"Supported head sizes are: {support_head_sizes}.")
 
     def forward(
         self,
@@ -159,19 +442,31 @@ class FlashAttentionImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
-        kv_scale: float,
+        attn_metadata: FlashAttentionMetadata,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention and PagedAttention.
+        """Forward pass with FlashAttention.
+
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashAttentionImpl")
+        # NOTE: FlashAttention does not support FP8 KV cache.
+        assert k_scale == 1.0 and v_scale == 1.0, (
+            "key/v_scale is not supported in FlashAttention.")
+
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -179,17 +474,22 @@ class FlashAttentionImpl(AttentionImpl):
         value = value.view(-1, self.num_kv_heads, self.head_size)
 
         if kv_cache is not None:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size)
+            key_cache = kv_cache[0]
+            value_cache = kv_cache[1]
 
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype,
-                                                kv_scale)
+            ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping.flatten(),
+                self.kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
@@ -209,7 +509,8 @@ class FlashAttentionImpl(AttentionImpl):
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            if (kv_cache is None or prefill_meta.block_tables is None
+                    or prefill_meta.block_tables.numel() == 0):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
@@ -219,48 +520,47 @@ class FlashAttentionImpl(AttentionImpl):
                     v=value,
                     cu_seqlens_q=prefill_meta.seq_start_loc,
                     cu_seqlens_k=prefill_meta.seq_start_loc,
-                    max_seqlen_q=prefill_meta.max_prompt_len,
-                    max_seqlen_k=prefill_meta.max_prompt_len,
+                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                     softmax_scale=self.scale,
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
+                    softcap=self.logits_soft_cap,
                 )
                 assert output[:num_prefill_tokens].shape == out.shape
                 output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
-                # TODO: this triton kernel has regression issue (broke) to
-                # deal with different data types between KV and FP8 KV cache,
-                # to be addressed separately.
-                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.subquery_start_loc,
-                    prefill_meta.prompt_lens_tensor,
-                    prefill_meta.context_lens,
-                    prefill_meta.max_subquery_len,
-                    self.alibi_slopes,
+                assert prefill_meta.seq_lens is not None
+                max_seq_len = max(prefill_meta.seq_lens)
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.query_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
+                    softcap=self.logits_soft_cap,
                 )
+
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = PagedAttention.forward_decode(
-                decode_query,
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
-                attn_metadata.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_scale,
-            )
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+            ).squeeze(1)
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
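
The `[0, 4, 10]` examples in the metadata comments above, and the `cu_seqlens_q` / `cu_seqlens_k` arguments passed to `flash_attn_varlen_func`, come from the exclusive prefix-sum pattern that `FlashAttentionMetadataBuilder.build` uses for `query_start_loc` and `seq_start_loc`. A standalone sketch of just that pattern (plain PyTorch, no Aphrodite imports; the helper name is made up):

from typing import List

import torch

def make_start_loc(lengths: List[int], device: str = "cpu") -> torch.Tensor:
    """Exclusive prefix sum over per-sequence lengths, e.g. [4, 6] -> [0, 4, 10]."""
    lengths_t = torch.tensor(lengths, dtype=torch.int32, device=device)
    start_loc = torch.zeros(len(lengths) + 1, dtype=torch.int32, device=device)
    # Same idiom as the builder: write the cumulative sum into the tail of a
    # zero-initialized tensor so that element 0 stays 0.
    torch.cumsum(lengths_t, dim=0, dtype=start_loc.dtype, out=start_loc[1:])
    return start_loc

print(make_start_loc([4, 6]))  # tensor([ 0,  4, 10], dtype=torch.int32)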

+ 550 - 0
aphrodite/attention/backends/flashinfer.py

@@ -0,0 +1,550 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
+
+try:
+    from aphrodite_flash_attn import flash_attn_varlen_func
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+except ImportError:
+    flash_attn_varlen_func = None
+    BatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+
+import torch
+
+from aphrodite import _custom_ops as ops
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataBuilder,
+                                                   AttentionType)
+from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
+                                                compute_slot_mapping,
+                                                compute_slot_mapping_start_idx,
+                                                is_block_tables_empty)
+from aphrodite.attention.ops.paged_attn import PagedAttention
+from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
+                                    make_tensor_with_pad)
+
+if TYPE_CHECKING:
+    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+
+
+class FlashInferBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
+    @staticmethod
+    def get_impl_cls() -> Type["FlashInferImpl"]:
+        return FlashInferImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashInferMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
+        return FlashInferMetadataBuilder
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (num_blocks, 2, block_size, num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [64, 128, 256]
+
+
+@dataclass
+class FlashInferMetadata(AttentionMetadata):
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+
+    use_cuda_graph: bool = True
+
+    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
+    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
+
+    # Metadata for the prefill stage
+    seq_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: Optional[torch.Tensor] = None
+    block_tables: Optional[torch.Tensor] = None
+
+    # An example for paged_kv_indices, paged_kv_indptr:
+    # request 1, page indices [0, 5, 8]
+    # request 2, page indices [1, 6, 7]
+    # request 3, page indices [3, 4]
+    # paged_kv_indices is a concatenation of page indices of all requests:
+    # [0, 5, 8, 1, 6, 7, 3, 4]
+    # paged_kv_indptr is used to index into paged_kv_indices:
+    # [0, 3, 6, 8]
+    # The indptr of the paged kv cache, shape: [batch_size + 1]
+    paged_kv_indptr: Optional[torch.Tensor] = None
+    # The page indices of the paged kv cache
+    paged_kv_indices: Optional[torch.Tensor] = None
+    # The number of entries in the last page of each request in
+    # the paged kv cache, shape: [batch_size]
+    paged_kv_last_page_len: Optional[torch.Tensor] = None
+    # The number of query/output heads
+    num_qo_heads: Optional[int] = None
+    # The number of key/value heads
+    num_kv_heads: Optional[int] = None
+    # The dimension of the attention heads
+    head_dim: Optional[int] = None
+    # Block size of Aphrodite
+    page_size: Optional[int] = None
+    # The data type of the paged kv cache
+    data_type: torch.dtype = None
+    device: torch.device = torch.device("cuda")
+
+    def __post_init__(self):
+        # Refer to
+        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
+        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
+        if self.head_dim is not None and self.head_dim \
+                not in supported_head_sizes:
+            raise ValueError(
+                f"Only {supported_head_sizes} are supported for head_dim,",
+                f"received {self.head_dim}.")
+
+    def begin_forward(self):
+        if self.num_prefill_tokens > 0:
+            if self.paged_kv_indices is None:
+                return
+
+            assert self.prefill_wrapper is not None
+            assert self.query_start_loc is not None
+            assert self.paged_kv_indices is not None
+            assert self.paged_kv_indptr is not None
+            assert self.paged_kv_last_page_len is not None
+            batch_size = self.query_start_loc.shape[0] - 1
+            assert batch_size >= 0
+            # The prefill stage does not read kv cache.
+            # Both paged_kv_indices and paged_kv_last_page_len are empty.
+            # paged_kv_indptr is a zero tensor with size batch_size + 1.
+            self.paged_kv_indptr = torch.zeros(batch_size + 1,
+                                               device=self.device)
+            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                self.device)
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+            self.prefill_wrapper.end_forward()
+            self.prefill_wrapper.begin_forward(
+                self.query_start_loc, self.paged_kv_indptr,
+                self.paged_kv_indices, self.paged_kv_last_page_len,
+                self.num_qo_heads, self.num_kv_heads, self.head_dim,
+                self.page_size)
+        else:
+            if not self.use_cuda_graph:
+                assert self.paged_kv_indices is not None
+                assert self.paged_kv_indptr is not None
+                assert self.paged_kv_last_page_len is not None
+                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                    self.device)
+
+            assert self.decode_wrapper is not None
+            self.decode_wrapper.end_forward()
+            self.decode_wrapper.begin_forward(
+                self.paged_kv_indptr,
+                self.paged_kv_indices,
+                self.paged_kv_last_page_len,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                self.page_size,
+                # Disable flashinfer's pos encoding and use Aphrodite's rope.
+                pos_encoding_mode="NONE",
+                data_type=self.data_type)
+
+    def asdict_zerocopy(self,
+                        skip_fields: Optional[Set[str]] = None
+                        ) -> Dict[str, Any]:
+        if skip_fields is None:
+            skip_fields = set()
+        # We need to skip the prefill/decode_wrapper field since it cannot be
+        # broadcast with NCCL when TP is enabled.
+        skip_fields.add('prefill_wrapper')
+        skip_fields.add('decode_wrapper')
+        return super().asdict_zerocopy(skip_fields)
+
+    @property
+    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
+        # Currently chunked prefill is not supported
+        if self.num_decode_tokens == 0:
+            assert self.num_prefills > 0
+            return self
+
+        return None
+
+    @property
+    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
+        # Currently chunked prefill is not supported
+        if self.num_prefills > 0:
+            assert self.num_decode_tokens == 0
+            return None
+
+        return self
+
+
+class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
+
+    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.slot_mapping: List[int] = []
+        self.prefill_seq_lens: List[int] = []
+        self.context_lens: List[int] = []
+        self.block_tables: List[List[int]] = []
+        self.curr_seq_lens: List[int] = []
+        self.num_prefills = 0
+        self.num_prefill_tokens = 0
+        self.num_decode_tokens = 0
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+        self.use_v2_block_manager = (
+            input_builder.scheduler_config.use_v2_block_manager)
+
+        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
+        # for the precise definition of the following fields.
+        # An example:
+        # request 1, page indices [0, 5, 8]
+        # request 2, page indices [1, 6, 7]
+        # request 3, page indices [3, 4]
+        # paged_kv_indices is a concatenation of page indices of all requests:
+        # [0, 5, 8, 1, 6, 7, 3, 4]
+        # paged_kv_indptr is used to index into paged_kv_indices:
+        # [0, 3, 6, 8]
+        self.paged_kv_indices: List[int] = []
+        # 0 at the beginning of paged_kv_indptr indicates the start of the
+        # first request’s page indices in the paged_kv_indices list.
+        self.paged_kv_indptr: List[int] = [0]
+        # paged_kv_last_page_len is the length of the last page of each request
+        self.paged_kv_last_page_len: List[int] = []
+
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata. Specifically update/append
+        1. context length.
+        2. block table.
+        3. slot mapping.
+        """
+        is_prompt = inter_data.is_prompt
+        block_tables = inter_data.block_tables
+        computed_block_nums = inter_data.computed_block_nums
+
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
+            self.context_lens.append(context_len)
+            if is_prompt:
+                self.num_prefills += 1
+                self.num_prefill_tokens += token_len
+                self.prefill_seq_lens.append(seq_len)
+            else:
+                assert query_len == 1, (
+                    "seq_len: {}, context_len: {}, query_len: {}".format(
+                        seq_len, context_len, query_len))
+                self.num_decode_tokens += query_len
+                self.curr_seq_lens.append(curr_seq_len)
+
+            # Compute block table.
+            # TODO: Combine chunked prefill and prefix caching by
+            # only allowing multiple of block_size chunk size.
+            # NOTE: This only works for oooooooxxx style attention.
+            block_table = []
+            if inter_data.prefix_cache_hit:
+                block_table = computed_block_nums
+            elif ((chunked_prefill_enabled or not is_prompt)
+                  and block_tables is not None):
+                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+            self.block_tables.append(block_table)
+
+            is_profile_run = is_block_tables_empty(block_tables)
+
+            # Compute slot mapping.
+            start_idx = compute_slot_mapping_start_idx(
+                is_prompt, query_len, context_len, self.sliding_window,
+                self.use_v2_block_manager)
+            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                 seq_len, context_len, start_idx,
+                                 self.block_size, inter_data.block_tables)
+
+            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
+            # and paged_kv_last_page_len for profile run because we will
+            # create dummy inputs.
+            if is_profile_run:
+                return
+
+            block_table = block_tables[seq_id]
+            self._update_paged_kv_tensors(block_table, seq_len)
+
+    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
+        # Get the number of valid blocks based on sequence length.
+        # If seq_len = 16, block_size = 16,
+        # block_table_bound is 1 with 1 valid block.
+        # If seq_len = 15, block_size = 16,
+        # block_table_bound is 0 + 1 with 1 valid block.
+        block_table_bound = seq_len // self.block_size + 1 \
+                            if seq_len % self.block_size != 0 \
+                            else seq_len // self.block_size
+        self.paged_kv_indices.extend(block_table[:block_table_bound])
+        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
+                                    block_table_bound)
+
+        last_page_len = seq_len % self.block_size
+        if last_page_len == 0:
+            last_page_len = self.block_size
+        self.paged_kv_last_page_len.append(last_page_len)
+
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled)
+
+        device = self.runner.device
+        use_captured_graph = cuda_graph_pad_size != -1
+
+        max_query_len = max(query_lens)
+        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+        num_decode_tokens = self.num_decode_tokens
+
+        if use_captured_graph:
+            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+            self.block_tables.extend([] * cuda_graph_pad_size)
+            num_decode_tokens = batch_size
+
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = self.runner.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(self.block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device, non_blocking=True)
+
+            last_paged_kv_indptr = self.paged_kv_indptr[-1]
+            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
+                                        cuda_graph_pad_size)
+            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
+        else:
+            block_tables = make_tensor_with_pad(
+                self.block_tables,
+                pad=0,
+                dtype=torch.int,
+                device=device,
+            )
+        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+
+        assert device is not None
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        if len(self.paged_kv_indptr) > 0:
+            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
+                                                   device="cpu",
+                                                   dtype=torch.int)
+            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
+                                                  device="cpu",
+                                                  dtype=torch.int)
+            paged_kv_last_page_len_tensor = torch.tensor(
+                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
+        else:
+            paged_kv_indices_tensor = None
+            paged_kv_indptr_tensor = None
+            paged_kv_last_page_len_tensor = None
+
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        return FlashInferMetadata(
+            num_prefills=self.num_prefills,
+            slot_mapping=slot_mapping_tensor,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            max_prefill_seq_len=max_prefill_seq_len,
+            block_tables=block_tables,
+            paged_kv_indptr=paged_kv_indptr_tensor,
+            paged_kv_indices=paged_kv_indices_tensor,
+            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
+            num_qo_heads=self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config),
+            num_kv_heads=self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config),
+            head_dim=self.runner.model_config.get_head_size(),
+            page_size=self.block_size,
+            seq_start_loc=seq_start_loc,
+            query_start_loc=query_start_loc,
+            device=device,
+            data_type=kv_cache_dtype,
+            use_cuda_graph=use_captured_graph)
+
+
+class FlashInferImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+    ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "FlashInfer does not support block-sparse attention.")
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_kv_heads
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+        if sliding_window is not None:
+            raise ValueError("Sliding window is not supported in FlashInfer.")
+        self.sliding_window = (-1, -1)
+        self.kv_cache_dtype = kv_cache_dtype
+        self.logits_soft_cap = logits_soft_cap
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: FlashInferMetadata,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ) -> torch.Tensor:
+        assert k_scale == 1.0 and v_scale == 1.0, (
+            "key/v_scale is not supported in FlashInfer.")
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashInferImpl")
+        num_tokens, hidden_size = query.shape
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if attn_metadata.num_prefill_tokens > 0:
+            assert attn_metadata.num_decode_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
+        if attn_metadata.num_decode_tokens > 0:
+            assert attn_metadata.num_prefill_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
+
+        if kv_cache is not None:
+            # Use the same reshape and cache kernel as flash attention.
+            ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                attn_metadata.slot_mapping.flatten(),
+                self.kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+        query = query.contiguous(
+        )  # Flashinfer requires query to be contiguous
+        if prefill_meta := attn_metadata.prefill_metadata:
+            # We will use flash attention for prefill
+            # when kv_cache is not provided.
+            # This happens when Aphrodite runs the profiling to
+            # determine the number of blocks.
+            if kv_cache is None:
+                output = flash_attn_varlen_func(
+                    q=query,
+                    k=key,
+                    v=value,
+                    cu_seqlens_q=prefill_meta.seq_start_loc,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
+                )
+            else:
+                assert prefill_meta is not None
+                assert prefill_meta.prefill_wrapper is not None
+                output = prefill_meta.prefill_wrapper.forward(
+                    query,
+                    kv_cache,
+                    logits_soft_cap=self.logits_soft_cap,
+                    causal=True)
+        else:
+            assert attn_metadata.decode_metadata is not None
+            assert attn_metadata.decode_metadata.decode_wrapper is not None
+            output = attn_metadata.decode_metadata.decode_wrapper.forward(
+                query,
+                kv_cache,
+                sm_scale=self.scale,
+                logits_soft_cap=self.logits_soft_cap)
+        return output.view(num_tokens, hidden_size)
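As a reference for the bookkeeping in FlashInferMetadataBuilder, here is a minimal standalone sketch of the paged KV layout it produces; the helper name and the example sequence lengths are illustrative, and only the indptr/indices/last-page-len rules mirror _update_paged_kv_tensors:

from typing import List, Tuple

def build_paged_kv(block_tables: List[List[int]], seq_lens: List[int],
                   block_size: int) -> Tuple[List[int], List[int], List[int]]:
    indices: List[int] = []
    indptr: List[int] = [0]  # 0 marks the start of the first request
    last_page_len: List[int] = []
    for block_table, seq_len in zip(block_tables, seq_lens):
        # Number of blocks that actually hold tokens for this sequence.
        bound = (seq_len + block_size - 1) // block_size
        indices.extend(block_table[:bound])
        indptr.append(indptr[-1] + bound)
        # Entries used in the last block (block_size if the block is full).
        last_page_len.append(seq_len % block_size or block_size)
    return indices, indptr, last_page_len

# The example from the comments above: three requests with page indices
# [0, 5, 8], [1, 6, 7], [3, 4]; the sequence lengths are made up, block_size 16.
indices, indptr, last = build_paged_kv(
    [[0, 5, 8], [1, 6, 7], [3, 4]], [40, 48, 17], block_size=16)
assert indices == [0, 5, 8, 1, 6, 7, 3, 4]
assert indptr == [0, 3, 6, 8]
assert last == [8, 16, 1]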

+ 380 - 0
aphrodite/attention/backends/ipex_attn.py

@@ -0,0 +1,380 @@
+""" Attention layer with torch scaled_dot_product_attention
+    and PagedAttention."""
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from aphrodite._ipex_ops import ipex_ops
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionType)
+from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.ops.paged_attn import (PagedAttention,
+                                                PagedAttentionMetadata)
+
+_PARTITION_SIZE = 512
+
+
+class IpexAttnBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "ipex-attn"
+
+    @staticmethod
+    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
+        return IpexAttnBackendImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
+        return IpexAttnMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["IpexAttnMetadataBuilder"]:
+        return IpexAttnMetadataBuilder
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
+    """Metadata for IpexAttnBackend.
+    """
+    # Currently, input sequences can only contain all prompts
+    # or all decoding. True if all sequences are prompts.
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    seq_lens: Optional[List[int]]
+    seqlen_q: Optional[torch.Tensor]
+    max_seqlen: Optional[int]
+
+    def __post_init__(self):
+        # Set during the execution of the first attention op.
+        # It is a list because it needs to be set per prompt when alibi
+        # slopes are used, due to a limitation of the xFormers API.
+        # It will not appear in __repr__ or __init__.
+        self.attn_bias: Optional[List[torch.Tensor]] = None
+
+    @property
+    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
+        # Currently chunked prefill is not supported
+        if self.num_decode_tokens == 0:
+            assert self.num_prefills > 0
+            return self
+
+        return None
+
+    @property
+    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
+        # Currently chunked prefill is not supported
+        if self.num_prefills > 0:
+            assert self.num_decode_tokens == 0
+            return None
+
+        return self
+
+
+class IpexAttnMetadataBuilder(CommonMetadataBuilder[IpexAttnMetadata]):
+
+    _metadata_cls = IpexAttnMetadata
+
+
+class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+    ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError(
+                "IPEX backend does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("IPEX backend does not support logits_soft_cap.")
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_kv_heads
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+        self.sliding_window = sliding_window
+        self.kv_cache_dtype = kv_cache_dtype
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.need_mask = (self.alibi_slopes is not None
+                          or self.sliding_window is not None)
+
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+        if kv_cache_dtype != "auto":
+            raise NotImplementedError(
+                "IPEX backend does not support FP8 KV cache. "
+                "Please use xFormers backend instead.")
+
+    def split_kv_cache(
+        self,
+        kv_cache: torch.Tensor,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = 1
+        num_blocks = kv_cache.shape[1]
+
+        key_cache = kv_cache[0]
+        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
+                                   -1, x)
+        value_cache = kv_cache[1]
+        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
+        return key_cache, value_cache
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: IpexAttnMetadata,  # type: ignore
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ) -> torch.Tensor:
+        """Forward pass with IPEX varlen_attention and PagedAttention.
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        assert k_scale == 1.0 and v_scale == 1.0
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "IpexAttnBackendImpl")
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = self.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+            ipex_ops.reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping.flatten(),
+                self.kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+        if attn_metadata.is_prompt:
+            assert attn_metadata.seq_lens is not None
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+                if self.num_kv_heads != self.num_heads:
+                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+                    value = value.repeat_interleave(self.num_queries_per_kv,
+                                                    dim=1)
+
+                if attn_metadata.attn_bias is None:
+                    if self.alibi_slopes is not None:
+                        att_masks = _make_alibi_bias(
+                            self.alibi_slopes, query.dtype,
+                            attn_metadata.seq_lens)  # type: ignore
+                    elif self.sliding_window is not None:
+                        att_masks = _make_sliding_window_bias(
+                            attn_metadata.seq_lens, self.sliding_window,
+                            query.dtype)  # type: ignore
+                    else:
+                        att_masks = _make_sliding_window_bias(
+                            attn_metadata.seq_lens, None, dtype=query.dtype)
+                    attn_metadata.attn_bias = att_masks
+
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype,
+                    device=query.device)
+                ipex_ops.varlen_attention(query,
+                                          key,
+                                          value,
+                                          output,
+                                          attn_metadata.seqlen_q,
+                                          attn_metadata.seqlen_q,
+                                          attn_metadata.max_seqlen,
+                                          attn_metadata.max_seqlen,
+                                          pdropout=0.0,
+                                          softmax_scale=self.scale,
+                                          zero_tensors=False,
+                                          is_causal=True,
+                                          return_softmax=False,
+                                          gen_=None)
+            else:
+                # prefix-enabled attention
+                raise RuntimeError(
+                    "IPEX backend doesn't support prefix decoding.")
+
+        else:
+            # Decoding run.
+            max_seq_len = attn_metadata.max_decode_seq_len
+            output = torch.empty_like(query)
+            block_size = value_cache.shape[3]
+            num_seqs, num_heads, head_size = query.shape
+            max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
+                                  _PARTITION_SIZE)
+            # NOTE: We use a simple heuristic to decide whether to use
+            # PagedAttention V1 or V2. If the number of partitions is 1, we use
+            # V1 to avoid the overhead of reduction. Also, if the number of
+            # sequences or heads is large, we use V1 since there is enough work
+            # to parallelize.
+            # TODO: Tune this heuristic.
+            # For context len > 8192, use V2 kernel to avoid shared memory
+            # shortage.
+            use_v1 = (max_seq_len <= 8192 and
+                      (max_num_partitions == 1 or num_seqs * num_heads > 512))
+            if use_v1:
+                # Run PagedAttention V1.
+                ipex_ops.paged_attention_v1(
+                    output,
+                    query,
+                    key_cache,
+                    value_cache,
+                    self.num_kv_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens_tensor,
+                    block_size,
+                    max_seq_len,
+                    self.alibi_slopes,
+                    self.kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+            else:
+                # Run PagedAttention V2.
+                assert _PARTITION_SIZE % block_size == 0
+                tmp_output = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions, head_size),
+                    dtype=output.dtype,
+                    device=output.device,
+                )
+                exp_sums = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions),
+                    dtype=torch.float32,
+                    device=output.device,
+                )
+                max_logits = torch.empty_like(exp_sums)
+                ipex_ops.paged_attention_v2(
+                    output,
+                    exp_sums,
+                    max_logits,
+                    tmp_output,
+                    query,
+                    key_cache,
+                    value_cache,
+                    self.num_kv_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens_tensor,
+                    block_size,
+                    max_seq_len,
+                    self.alibi_slopes,
+                    self.kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+
+        # Reshape the output tensor.
+        return output.view(-1, self.num_heads * self.head_size)
+
+
+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    dtype: torch.dtype,
+    seq_lens: List[int],
+) -> List[torch.Tensor]:
+    attn_biases = []
+    for seq_len in seq_lens:
+        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+        # NOTE: HF uses
+        #     `bias = bias[None, :].repeat(seq_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        bias = bias[None, :] - bias[:, None]
+
+        num_heads = alibi_slopes.shape[0]
+        bias = bias[None, :].repeat((num_heads, 1, 1))
+        bias.mul_(alibi_slopes[:, None, None])
+        inf_mask = torch.empty(
+            (1, seq_len, seq_len),
+            dtype=bias.dtype,
+            device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
+        attn_biases.append((bias + inf_mask).to(dtype))
+
+    return attn_biases
+
+
+def _make_sliding_window_bias(
+    seq_lens: List[int],
+    window_size: Optional[int],
+    dtype: torch.dtype,
+) -> List[torch.Tensor]:
+    attn_biases = []
+    for seq_len in seq_lens:
+        tensor = torch.full(
+            (1, seq_len, seq_len),
+            dtype=dtype,
+            fill_value=1,
+        )
+        shift = 0
+        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
+        if window_size is not None:
+            mask = torch.triu(mask, diagonal=shift - window_size + 1)
+        mask = torch.log(mask)
+        attn_biases.append(mask.to(dtype))
+
+    return attn_biases
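A quick standalone check of what _make_sliding_window_bias produces (the numbers are illustrative, and the snippet assumes the module above is importable as aphrodite.attention.backends.ipex_attn): with seq_len = 4 and window_size = 2, each token may attend to itself and one previous token, and every other position becomes -inf after the log.

import torch

from aphrodite.attention.backends.ipex_attn import _make_sliding_window_bias

bias = _make_sliding_window_bias([4], window_size=2, dtype=torch.float32)[0]
print(bias[0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [-inf, 0., 0., -inf],
#         [-inf, -inf, 0., 0.]])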

+ 100 - 0
aphrodite/attention/backends/openvino.py

@@ -0,0 +1,100 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import openvino as ov
+import torch
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionMetadata)
+
+
+class OpenVINOAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "openvino"
+
+    @staticmethod
+    def get_impl_cls():
+        # OpenVINO implements PagedAttention as part of the Optimum
+        # exported model
+        raise NotImplementedError
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+        raise NotImplementedError
+
+    @staticmethod
+    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
+        return OpenVINOAttentionMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (2, num_blocks, num_kv_heads, block_size, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: ov.Tensor,
+        dst_kv_cache: ov.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        # OpenVINO currently supports only CPU, which does not require
+        # swap of KV cache blocks
+        raise NotImplementedError
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
+        src_to_dists: List[Tuple[int, int]],
+    ) -> None:
+        for src, dst in src_to_dists:
+            for key_cache, value_cache in kv_caches:
+                key_cache.data[dst, :] = key_cache.data[src, :]
+                value_cache.data[dst, :] = value_cache.data[src, :]
+
+
+@dataclass
+class OpenVINOAttentionMetadata:
+    """Metadata for OpenVINOAttentionBackend.
+    Basic terms used below:
+    - batch_size_in_sequences - total number of sequences to execute
+    - prompt_lens - per-sequence number of scheduled tokens
+    - batch_size_in_tokens = sum(prompt_lens)
+    - max_context_len = max(context_lens)
+    - max_num_blocks = div_up(max_context_len, BLOCK_SIZE)
+    - num_blocks - total number of blocks in block_indices
+    """
+
+    # Describes past KV cache size for each sequence within a batch
+    # Shape: [batch_size_in_sequences]
+    # Type: i32
+    past_lens: torch.Tensor
+
+    # Describes start indices of input / speculative tokens from
+    # current sequences within a batch sequence
+    # Shape: [batch_size_in_sequences + 1]
+    # Type: i32
+    subsequence_begins: torch.Tensor
+
+    # Describes block tables for each sequence within a batch -
+    # indices along 0th dimension in key_cache and value_cache inputs
+    # Shape: [num_blocks]
+    # Type: i32
+    block_indices: torch.Tensor
+
+    # Describes block tables for each sequence within a batch -
+    # for i-th element, it is an index in block_indices with the
+    # first block belonging to i-th sequence
+    # Shape: [batch_size_in_sequences + 1]
+    # Type: i32
+    block_indices_begins: torch.Tensor
+
+    # Describes max context length
+    # Shape: scalar
+    # Type: i32
+    max_context_len: torch.Tensor

+ 248 - 0
aphrodite/attention/backends/pallas.py

@@ -0,0 +1,248 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+import torch_xla.experimental.custom_kernel  # Required to register custom ops.
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionType)
+
+
+class PallasAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
+        return PallasAttentionBackendImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["PallasMetadata"]:
+        return PallasMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (num_kv_heads, num_blocks, block_size, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
+
+    @torch.compile(backend="openxla")
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
+    ) -> None:
+        src_indices, dst_indices = src_to_dists
+        for k_cache, v_cache in kv_caches:
+            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
+            k_cache[:, dst_indices] = k_cache[:, src_indices]
+            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
+            v_cache[:, dst_indices] = v_cache[:, src_indices]
+
+
+@dataclass
+class PallasMetadata(AttentionMetadata):
+
+    # Currently, input sequences can only contain all prefills
+    # or all decoding.
+    block_tables: Optional[torch.Tensor] = None
+    context_lens: Optional[torch.Tensor] = None
+
+    @property
+    def prefill_metadata(self) -> Optional["PallasMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        assert self.num_decode_tokens == 0
+        assert self.block_tables is None
+        assert self.context_lens is None
+        return self
+
+    @property
+    def decode_metadata(self) -> Optional["PallasMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.block_tables is not None
+        assert self.context_lens is not None
+        return self
+
+
+class PallasAttentionBackendImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        if head_size % 128 != 0:
+            raise NotImplementedError("Head size must be a multiple of 128.")
+        if alibi_slopes is not None:
+            raise NotImplementedError("Alibi slopes is not supported.")
+        if sliding_window is not None:
+            raise NotImplementedError("Sliding window is not supported.")
+        if kv_cache_dtype != "auto":
+            raise NotImplementedError("FP8 KV cache dtype is not supported.")
+        if blocksparse_params is not None:
+            raise NotImplementedError("Blocksparse is not supported.")
+        if logits_soft_cap is not None:
+            raise NotImplementedError(
+                "Attention logits soft-capping is not supported.")
+
+        if torch_xla.tpu.version() < 4:
+            raise NotImplementedError("TPU version must be 4 or higher.")
+
+        self.megacore_mode = None
+        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
+        if "lite" not in tpu_type:
+            if self.num_kv_heads % 2 == 0:
+                self.megacore_mode = "kv_head"
+            else:
+                # NOTE: If the batch size is not a multiple of 2, the
+                # megacore mode will be None.
+                self.megacore_mode = "batch"
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
+        attn_metadata: PallasMetadata,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ) -> torch.Tensor:
+        """Forward pass with Pallas attention.
+        Args:
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            key_cache = [num_kv_heads, num_blocks, block_size, head_size]
+            value_cache = [num_kv_heads, num_blocks, block_size, head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [batch_size, seq_len, num_heads * head_size]
+        """
+        assert k_scale == 1.0 and v_scale == 1.0
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "PallasAttentionBackendImpl")
+        batch_size, seq_len, hidden_size = query.shape
+        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
+        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
+        value = value.view(batch_size, seq_len, self.num_kv_heads,
+                           self.head_size)
+
+        if kv_cache[0] is not None:
+            slot_mapping = attn_metadata.slot_mapping
+            key_cache, value_cache = kv_cache
+            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
+
+        query = query * self.scale
+        if attn_metadata.num_prefills > 0:
+            assert seq_len % 16 == 0, (
+                "Pallas FlashAttention kernel requires seq_len to be a "
+                f"multiple of 16 but got {seq_len}")
+
+            # Handle GQA/MQA.
+            if self.num_kv_heads != self.num_heads:
+                key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
+                key = key.view(batch_size, seq_len, self.num_heads,
+                               self.head_size)
+                value = value.repeat_interleave(self.num_queries_per_kv,
+                                                dim=-2)
+                value = value.view(batch_size, seq_len, self.num_heads,
+                                   self.head_size)
+            # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
+            # while the input is [batch_size, seq_len, num_heads, d_model].
+            # Permute the input to match the required format.
+            output = torch.ops.xla.flash_attention(
+                query.permute(0, 2, 1, 3),
+                key.permute(0, 2, 1, 3),
+                value.permute(0, 2, 1, 3),
+                True,
+            )
+            output = output.permute(0, 2, 1, 3)
+        else:
+            # Decoding run.
+            assert kv_cache is not None
+
+            pages_per_compute_block = 16  # TODO: Tune this value.
+            if self.megacore_mode == "batch" and batch_size % 2 != 0:
+                megacore_mode = None
+            else:
+                megacore_mode = self.megacore_mode
+
+            # NOTE: A temporary workaround to avoid the error:
+            # "xla::paged_attention() Expected a value of type 'str' for
+            # argument 'megacore_mode' but instead found type 'NoneType'."
+            if megacore_mode is not None:
+                output = torch.ops.xla.paged_attention(
+                    query.squeeze(dim=1),
+                    key_cache,
+                    value_cache,
+                    attn_metadata.context_lens,
+                    attn_metadata.block_tables,
+                    pages_per_compute_block,
+                    megacore_mode=megacore_mode,
+                )
+            else:
+                output = torch.ops.xla.paged_attention(
+                    query.squeeze(dim=1),
+                    key_cache,
+                    value_cache,
+                    attn_metadata.context_lens,
+                    attn_metadata.block_tables,
+                    pages_per_compute_block,
+                )
+
+        # Reshape the output tensor.
+        return output.reshape(batch_size, seq_len, hidden_size)
+
+
+def write_to_kv_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
+
+    key = key.flatten(0, 2)
+    value = value.flatten(0, 2)
+    key_cache = key_cache.flatten(0, 2)
+    value_cache = value_cache.flatten(0, 2)
+    key_cache.index_copy_(0, slot_mapping, key)
+    value_cache.index_copy_(0, slot_mapping, value)

+ 296 - 142
aphrodite/attention/backends/rocm_flash_attn.py

@@ -1,32 +1,37 @@
 """Attention layer ROCm GPUs."""
 import os
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from loguru import logger
 
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionType)
+from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.ops.paged_attn import (PagedAttention,
+                                                PagedAttentionMetadata)
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
-        return ROCmFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return ROCmFlashAttentionMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
+        return ROCmFlashAttentionMetadataBuilder
 
     @staticmethod
     def get_kv_cache_shape(
@@ -42,55 +47,53 @@ class ROCmFlashAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
 
 @dataclass
-class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
-                                 PagedAttentionMetadata):
+class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     """Metadata for FlashAttentionBackend.
+
     NOTE: Any python object stored here is not updated when it is
     cuda-graph replayed. If you have values that need to be changed
     dynamically, it should be stored in tensor. The tensor has to be
     updated from `CUDAGraphRunner.forward` API.
     """
-    # Currently, input sequences can only contain all prompts
-    # or all decoding. True if all sequences are prompts.
-    is_prompt: bool
-    # (batch_size,). The prompt length per sequence. None if it is a decoding.
-    prompt_lens: Optional[List[int]]
-    # prompt_lens stored as a tensor.
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    # NOTE: Definition of context_len, subquery_len, and seqlen.
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens. None if it is a decoding.
+    seq_lens: Optional[List[int]]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE: Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
     # |- tokenA -|......................|-- newTokens ---|
     # |---------- context_len ----------|
-    # |-------------------- seqlen ----------------------|
-    #                                   |- subquery_len -|
-
-    # WARNING: context_len has different definition depending on if it is
-    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
-    # When it is for decoding, it includes a new token.
-
-    # Maximum subquery length in the batch.
-    max_subquery_len: Optional[int]
-    # Maximum prompt length in the batch.
-    max_prompt_len: Optional[int]
+    # |-------------------- seq_len ----------------------|
+    #                                   |-- query_len ---|
+
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
     # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
+    query_start_loc: Optional[torch.Tensor]
     # (batch_size + 1,). The cumulative sequence lengths of the sequences in
     # the batch, used to index into sequence. E.g., if the sequence length is
     # [4, 6], it is [0, 4, 10].
@@ -100,6 +103,109 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
     # Cuda-graph is currently enabled for decoding only.
     # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
+    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
+
+    @property
+    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None
+        assert self.query_start_loc is not None
+        assert self.context_lens_tensor is not None
+        assert self.block_tables is not None
+        assert self.seq_start_loc is not None
+
+        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            seq_lens=self.seq_lens[:self.num_prefills],
+            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            max_query_len=self.max_query_len,
+            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_seq_len=0,
+            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
+            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
+            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            block_tables=self.block_tables[:self.num_prefills],
+            use_cuda_graph=False,
+        )
+        return self._cached_prefill_metadata
+
+    @property
+    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        if self._cached_decode_metadata is not None:
+            return self._cached_decode_metadata
+        assert self.block_tables is not None
+        assert self.seq_lens_tensor is not None
+
+        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=self.num_decode_tokens,
+            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            seq_lens=None,
+            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.max_decode_seq_len,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self.block_tables[self.num_prefills:],
+            use_cuda_graph=self.use_cuda_graph,
+        )
+        return self._cached_decode_metadata
+
+
+class ROCmFlashAttentionMetadataBuilder(
+        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
+
+    _metadata_cls = ROCmFlashAttentionMetadata
+
+
+def _make_alibi_bias(alibi_slopes: torch.Tensor,
+                     dtype: torch.dtype,
+                     seq_lens: Optional[List[int]],
+                     make_attn_mask: bool = True) -> List[torch.Tensor]:
+    attn_biases = []
+    if seq_lens:
+        for seq_len in seq_lens:
+            bias = torch.arange(seq_len, dtype=dtype)
+            # NOTE(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(seq_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
+            bias = bias[None, :] - bias[:, None]
+
+            num_heads = alibi_slopes.shape[0]
+            bias = bias[None, :].repeat(
+                (num_heads, 1, 1)).to(alibi_slopes.device)
+            bias.mul_(alibi_slopes[:, None, None])
+            if make_attn_mask:
+                inf_mask = torch.empty(
+                    (1, seq_len, seq_len),
+                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
+                        alibi_slopes.device)
+                attn_biases.append((bias + inf_mask).to(dtype))
+            else:
+                attn_biases.append(bias.to(dtype))
+
+    return attn_biases
 
 
 class ROCmFlashAttentionImpl(AttentionImpl):
@@ -107,18 +213,23 @@ class ROCmFlashAttentionImpl(AttentionImpl):
     If the input tensors contain prompt tokens, the layout is as follows:
     |<--------------- num_prompt_tokens -------------->|
     |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+
     Otherwise, the layout is as follows:
     |<------------------ num_generation_tokens (M) ----------------->|
     |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+
     Generation tokens can contain padding when cuda-graph is used.
     Currently, prompt tokens don't contain any padding.
+
     The prompts might have different lengths, while the generation tokens
     always have length 1.
 
     If chunked prefill is enabled, prefill tokens and decode tokens can be
     batched together in a flattened 1D query.
+
     |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|	
     |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
+
     Currently, cuda graph is disabled for chunked prefill, meaning there's no
     padding between prefill and decode tokens.
     """
@@ -128,49 +239,72 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support blocksparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support attention logits soft "
+                "capping.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = ((sliding_window, sliding_window)
-                               if sliding_window is not None else (-1, -1))
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
+        self.sliding_window = ((sliding_window, sliding_window)
+                               if sliding_window is not None else (-1, -1))
+        self.kv_cache_dtype = kv_cache_dtype
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Supported head sizes are: {supported_head_sizes}.")
 
-        self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+        self.use_naive_attn = False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = (os.environ.get(
             "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
                                       in ("true", "1"))
-        if self.use_naive_attn:
-            # AMD Radeon 7900 series (gfx1100) currently does not support
-            # xFormers nor FlashAttention. As a temporary workaround, we use
-            # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention
-            logger.debug("Using naive attention in ROCmBackend")
-        elif self.use_triton_flash_attn:
+        if self.use_triton_flash_attn:
             from aphrodite.attention.ops.triton_flash_attn import (  # noqa: F401
-                triton_attention, )
+                triton_attention)
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
+            if self.sliding_window != (-1, -1):
+                logger.warning("ROCm Triton FA does not currently support "
+                               "sliding window attention. If using half "
+                               "precision, please try using the ROCm CK "
+                               "FA backend instead by setting the env var "
+                               "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
         else:
-            from flash_attn import flash_attn_varlen_func  # noqa: F401
-            self.attn_func = flash_attn_varlen_func
-            logger.debug("Using CK FA in ROCmBackend")
+            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
+            # either
+            if torch.cuda.get_device_capability()[0] != 9:
+                self.use_naive_attn = True
+            else:
+                try:
+                    from flash_attn import flash_attn_varlen_func  # noqa: F401
+                    self.attn_func = flash_attn_varlen_func
+                    logger.debug("Using CK FA in ROCmBackend")
+                except ModuleNotFoundError:
+                    self.use_naive_attn = True
+
+            if self.use_naive_attn:
+                self.attn_func = _sdpa_attention
+                logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -186,10 +320,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
-        kv_scale: float = 1.0,
+        attn_metadata: ROCmFlashAttentionMetadata,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
+
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
@@ -199,6 +336,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "ROCmFlashAttentionImpl")
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -218,8 +360,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 key_cache,
                 value_cache,
                 attn_metadata.slot_mapping,
-                attn_metadata.kv_cache_dtype,
-                kv_scale,
+                self.kv_cache_dtype,
+                k_scale,
+                v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -240,40 +383,59 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
+            assert prefill_meta.seq_lens is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_naive_attn or self.use_triton_flash_attn:
+                attn_masks = None
+                if self.use_triton_flash_attn:
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=False)  # type: ignore
+                    out, _ = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        None,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.max_prefill_seq_len,
+                        prefill_meta.max_prefill_seq_len,
+                        True,
+                        self.scale,
+                        attn_masks[0][None]
+                        if attn_masks is not None else None,
+                    )
+                elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
-                    if self.use_naive_attn:
-                        out = self.attn_fuc(
-                            query,
-                            key,
-                            value,
-                            prefill_meta.prompt_lens,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
-                    else:
-                        out, _ = self.attn_func(
-                            query,
-                            key,
-                            value,
-                            None,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.max_prompt_len,
-                            prefill_meta.max_prompt_len,
-                            True,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=True)  # type: ignore
+                    query = query.movedim(0, query.dim() - 2)
+                    key = key.movedim(0, key.dim() - 2)
+                    value = value.movedim(0, value.dim() - 2)
+                    # sdpa math backend attention
+                    out = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        prefill_meta.seq_lens,
+                        num_tokens,
+                        self.num_heads,
+                        self.head_size,
+                        self.scale,
+                        attn_masks,
+                    )
                 else:
                     out = self.attn_func(
                         q=query,
@@ -281,13 +443,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         v=value,
                         cu_seqlens_q=prefill_meta.seq_start_loc,
                         cu_seqlens_k=prefill_meta.seq_start_loc,
-                        max_seqlen_q=prefill_meta.max_prompt_len,
-                        max_seqlen_k=prefill_meta.max_prompt_len,
+                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                         softmax_scale=self.scale,
                         causal=True,
+                        window_size=self.sliding_window,
+                        alibi_slopes=self.alibi_slopes,
                     )
-                    assert output[:num_prefill_tokens].shape == out.shape
-                    output[:num_prefill_tokens] = out
+
+                # common code for prefill
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 output[:num_prefill_tokens] = PagedAttention.forward_prefix(
@@ -297,11 +463,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     key_cache,
                     value_cache,
                     prefill_meta.block_tables,
-                    prefill_meta.subquery_start_loc,
-                    prefill_meta.prompt_lens_tensor,
-                    prefill_meta.context_lens,
-                    prefill_meta.max_subquery_len,
+                    prefill_meta.query_start_loc,
+                    prefill_meta.seq_lens_tensor,
+                    prefill_meta.context_lens_tensor,
+                    prefill_meta.max_query_len,
                     self.alibi_slopes,
+                    self.sliding_window[0],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -311,63 +478,50 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 key_cache,
                 value_cache,
                 decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
-                attn_metadata.kv_cache_dtype,
+                decode_meta.seq_lens_tensor,
+                decode_meta.max_decode_seq_len,
+                self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
+
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
 
 
-def _naive_attention(
+def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    prompt_lens: List[int],
+    seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
     scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
-    num_tokens = query.shape[0]  # noqa: F841
-    output = torch.empty_like(query)
     start = 0
-    for _, prompt_len in enumerate(prompt_lens):
-        end = start + prompt_len
-        out = _naive_masked_attention(
-            query[start:end],
-            key[start:end],
-            value[start:end],
-            scale,
-        )
-        # TODO: Unnecessary copy. Optimize.
-        output[start:end].copy_(out)
-        start += prompt_len
-
-    # Using view got RuntimeError: view size is not compatible
-    # with input tensor's size and stride (at least one
-    # dimension spans across two contiguous subspaces).
-    # Use reshape instead.
-    return output
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for i, seq_len in enumerate(seq_lens):
+        end = start + seq_len
+        with torch.backends.cuda.sdp_kernel(enable_math=True,
+                                            enable_flash=False,
+                                            enable_mem_efficient=False):
+            sub_out = torch.nn.functional.scaled_dot_product_attention(
+                query[:, start:end, :],
+                key[:, start:end, :],
+                value[:, start:end, :],
+                dropout_p=0.0,
+                is_causal=attn_masks is None,
+                attn_mask=attn_masks[i] if attn_masks else None,
+                scale=scale).movedim(query.dim() - 2, 0)
+            output[start:end, :, :] = sub_out
+            start = end
 
-
-def _naive_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    scale: float,
-) -> torch.Tensor:
-    seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out
+    return output
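A standalone sketch (not part of the patch) of the layout round-trip used by `_sdpa_attention` above: tokens arrive as `[num_tokens, num_heads, head_size]`, are moved to a heads-first layout for `scaled_dot_product_attention`, and the result is moved back before being written into the flat output buffer.

```python
import torch

num_tokens, num_heads, head_size = 6, 2, 4
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Move the token dimension to second-to-last: [num_heads, num_tokens, head_size].
q_h = q.movedim(0, q.dim() - 2)
k_h = k.movedim(0, k.dim() - 2)
v_h = v.movedim(0, v.dim() - 2)

out = torch.nn.functional.scaled_dot_product_attention(q_h, k_h, v_h,
                                                        is_causal=True)
# Move tokens back to the front: [num_tokens, num_heads, head_size].
out = out.movedim(out.dim() - 2, 0)
assert out.shape == (num_tokens, num_heads, head_size)
```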

aphrodite/attention/backends/sdpa.py → aphrodite/attention/backends/torch_sdpa.py (+117, -57)

@@ -1,32 +1,45 @@
 """ Attention layer with torch scaled_dot_product_attention
     and PagedAttention."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionType)
+from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.ops.paged_attn import PagedAttentionMetadata
+from aphrodite.common.utils import is_cpu
+
+if is_cpu():
+    try:
+        from aphrodite.attention.ops.ipex_attn import PagedAttention
+    except ImportError:
+        from aphrodite.attention.ops.paged_attn import PagedAttention
+else:
+    from aphrodite.attention.ops.paged_attn import PagedAttention
 
 
 class TorchSDPABackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
-        return TorchSDPAMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return TorchSDPAMetadata
+    
+    @staticmethod
+    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
+        return TorchSDPAMetadataBuilder
 
     @staticmethod
     def get_kv_cache_shape(
@@ -42,28 +55,27 @@ class TorchSDPABackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
 
 @dataclass
-class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
-                        AttentionMetadataPerStage):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
     slot_mapping: torch.Tensor
-    prompt_lens: Optional[List[int]]
+    seq_lens: Optional[List[int]]
 
     def __post_init__(self):
         # Set during the execution of the first attention op.
@@ -73,37 +85,74 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
         # will not appear in the __repr__ and __init__
         self.attn_bias: Optional[List[torch.Tensor]] = None
 
+    @property
+    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
+        # Currently chunked prefill is not supported
+        if self.num_decode_tokens == 0:
+            assert self.num_prefills > 0
+            return self
+
+        return None
+
+    @property
+    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
+        # Currently chunked prefill is not supported
+        if self.num_prefills > 0:
+            assert self.num_decode_tokens == 0
+            return None
+
+        return self
+
+
+class TorchSDPAMetadataBuilder(CommonMetadataBuilder[TorchSDPAMetadata]):
+
+    _metadata_cls = TorchSDPAMetadata
+
 
-class TorchSDPABackendImpl(AttentionImpl):
+
+class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 
     def __init__(
         self,
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError(
+                "Torch SPDA does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("Torch SPDA does not support logits soft cap.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
-            assert len(alibi_slopes) == num_heads
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        self.need_mask = (self.alibi_slopes is not None
-                          or self.sliding_window is not None)
+        self.sliding_window = sliding_window
+        self.kv_cache_dtype = kv_cache_dtype
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        self.need_mask = (self.alibi_slopes is not None
+                          or self.sliding_window is not None)
+
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Supported head sizes are: {supported_head_sizes}.")
+        if kv_cache_dtype != "auto":
+            raise NotImplementedError(
+                "Torch SDPA backend does not support FP8 KV cache. "
+                "Please use xFormers backend instead.")
 
     def forward(
         self,
@@ -111,8 +160,10 @@ class TorchSDPABackendImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: TorchSDPAMetadata,
-        kv_scale: float,
+        attn_metadata: TorchSDPAMetadata,  # type: ignore
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
 
@@ -125,6 +176,12 @@ class TorchSDPABackendImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "TorchSDPABackendImpl")
+        assert k_scale == 1.0 and v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -137,10 +194,11 @@ class TorchSDPABackendImpl(AttentionImpl):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype,
-                                                kv_scale)
+                                                self.kv_cache_dtype, k_scale,
+                                                v_scale)
 
         if attn_metadata.is_prompt:
+            assert attn_metadata.seq_lens is not None
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
@@ -151,13 +209,13 @@ class TorchSDPABackendImpl(AttentionImpl):
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
-                            attn_metadata.prompt_lens)  # type: ignore
+                            attn_metadata.seq_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            attn_metadata.prompt_lens, self.sliding_window,
+                            attn_metadata.seq_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                        att_masks = [None] * len(attn_metadata.seq_lens)
                     attn_metadata.attn_bias = att_masks
 
                 query = query.movedim(0, query.dim() - 2)
@@ -168,17 +226,18 @@ class TorchSDPABackendImpl(AttentionImpl):
                 output = torch.empty(
                     (num_tokens, self.num_heads, self.head_size),
                     dtype=query.dtype)
-                for prompt_len, mask in zip(attn_metadata.prompt_lens,
-                                            attn_metadata.attn_bias):
-                    end = start + prompt_len
+                for seq_len, mask in zip(attn_metadata.seq_lens,
+                                         attn_metadata.attn_bias):
+                    end = start + seq_len
                     sub_out = scaled_dot_product_attention(
-                        query[:, start:end, :],
-                        key[:, start:end, :],
-                        value[:, start:end, :],
+                        query[None, :, start:end, :],
+                        key[None, :, start:end, :],
+                        value[None, :, start:end, :],
                         attn_mask=mask,
                         dropout_p=0.0,
                         is_causal=not self.need_mask,
-                        scale=self.scale).movedim(query.dim() - 2, 0)
+                        scale=self.scale).squeeze(0).movedim(
+                            query.dim() - 2, 0)
                     output[start:end, :, :] = sub_out
                     start = end
             else:
@@ -193,13 +252,14 @@ class TorchSDPABackendImpl(AttentionImpl):
                 key_cache,
                 value_cache,
                 attn_metadata.block_tables,
-                attn_metadata.context_lens,
-                attn_metadata.max_context_len,
-                attn_metadata.kv_cache_dtype,
+                attn_metadata.seq_lens_tensor,
+                attn_metadata.max_decode_seq_len,
+                self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         # Reshape the output tensor.
@@ -209,13 +269,13 @@ class TorchSDPABackendImpl(AttentionImpl):
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
     dtype: torch.dtype,
-    prompt_lens: List[int],
+    seq_lens: List[int],
 ) -> List[torch.Tensor]:
     attn_biases = []
-    for prompt_len in prompt_lens:
-        bias = torch.arange(prompt_len, dtype=dtype)
+    for seq_len in seq_lens:
+        bias = torch.arange(seq_len, dtype=dtype)
         # NOTE: HF uses
-        #     `bias = bias[None, :].repeat(prompt_len, 1)`
+        #     `bias = bias[None, :].repeat(seq_len, 1)`
         # here. We find that both biases give the same results, but
         # the bias below more accurately follows the original ALiBi
         # paper.
@@ -223,9 +283,9 @@ def _make_alibi_bias(
 
         num_heads = alibi_slopes.shape[0]
         bias = bias[None, :].repeat((num_heads, 1, 1))
-        bias.mul_(alibi_slopes[:, None, None])
+        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
         inf_mask = torch.empty(
-            (1, prompt_len, prompt_len),
+            (1, seq_len, seq_len),
             dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
         attn_biases.append((bias + inf_mask).to(dtype))
 
@@ -233,14 +293,14 @@ def _make_alibi_bias(
 
 
 def _make_sliding_window_bias(
-    prompt_lens: List[int],
+    seq_lens: List[int],
     window_size: Optional[int],
     dtype: torch.dtype,
 ) -> List[torch.Tensor]:
     attn_biases = []
-    for prompt_len in prompt_lens:
+    for seq_len in seq_lens:
         tensor = torch.full(
-            (1, prompt_len, prompt_len),
+            (1, seq_len, seq_len),
             dtype=dtype,
             fill_value=1,
         )

aphrodite/attention/backends/utils.py (+232, -0)

@@ -0,0 +1,232 @@
+from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
+
+import torch
+
+from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
+from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
+
+# Error string(s) for encoder/decoder
+# unsupported attention scenarios
+STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
+                                 "with encoder/decoder models.")
+
+PAD_SLOT_ID = -1
+
+if TYPE_CHECKING:
+    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+
+
+def is_block_tables_empty(block_tables: Union[None, Dict]):
+    """
+    Check if block_tables is None or a dictionary with all None values.
+    """
+    if block_tables is None:
+        return True
+    if isinstance(block_tables, dict) and all(
+            value is None for value in block_tables.values()):
+        return True
+    return False
+
+
+def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
+                                   context_len: int, sliding_window: int,
+                                   use_v2_block_manager: bool):
+    """
+    Compute the start index of slot mapping.
+    """
+    start_idx = 0
+    if is_prompt and sliding_window is not None:
+        assert use_v2_block_manager or context_len == 0, (
+            "Prefix caching is currently not supported with "
+            "sliding window attention in V1 block manager")
+        # During prefill, this is used to avoid writing slots to the KV
+        # cache, to save memory.
+        start_idx = max(0, query_len - sliding_window)
+    return start_idx
+
+
+def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
+                         seq_id: int, seq_len: int, context_len: int,
+                         start_idx: int, block_size: int,
+                         block_tables: Dict[int, List[int]]):
+    """
+    Compute slot mapping.
+    """
+    if is_profile_run:
+        # During memory profiling, the block tables are not
+        # initialized yet. In this case, we just use a dummy
+        # slot mapping.
+        # In embeddings, the block tables are {seq_id: None}.
+        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
+        return
+
+    # Mask the [0, start_idx) tokens of the prompt with
+    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
+    # sliding_window). For example, if the prompt len is 10,
+    # sliding window is 8, and block size is 4, the first two
+    # tokens are masked and the slot mapping will be
+    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
+    block_table = block_tables[seq_id]
+    slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
+    for i in range(max(start_idx, context_len), seq_len):
+        block_number = block_table[i // block_size]
+        block_offset = i % block_size
+        slot = block_number * block_size + block_offset
+        slot_mapping.append(slot)
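To make the comment's example concrete, here is a small usage sketch (assuming these helpers are imported from `aphrodite.attention.backends.utils`); the block-table contents are invented so that the resulting slot mapping matches the values quoted above.

```python
from aphrodite.attention.backends.utils import (
    compute_slot_mapping, compute_slot_mapping_start_idx)

# Prompt of 10 tokens, sliding window of 8, block size of 4.
start_idx = compute_slot_mapping_start_idx(is_prompt=True,
                                           query_len=10,
                                           context_len=0,
                                           sliding_window=8,
                                           use_v2_block_manager=True)
assert start_idx == 2  # the first two tokens are never written to the cache

slot_mapping = []
compute_slot_mapping(is_profile_run=False,
                     slot_mapping=slot_mapping,
                     seq_id=0,
                     seq_len=10,
                     context_len=0,
                     start_idx=start_idx,
                     block_size=4,
                     # Hypothetical block table: physical blocks 0 and 1, with
                     # block 0 reused once the window has slid past the start.
                     block_tables={0: [0, 1, 0]})
assert slot_mapping == [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]
```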
+
+
+TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
+
+
+class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
+
+    _metadata_cls: Type[TAttentionMetadata]
+
+    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.slot_mapping: List[int] = []
+        self.prefill_seq_lens: List[int] = []
+        self.context_lens: List[int] = []
+        self.block_tables: List[List[int]] = []
+        self.curr_seq_lens: List[int] = []
+        self.num_prefills = 0
+        self.num_prefill_tokens = 0
+        self.num_decode_tokens = 0
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+        self.use_v2_block_manager = (
+            input_builder.scheduler_config.use_v2_block_manager)
+
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
+        is_prompt = inter_data.is_prompt
+        block_tables = inter_data.block_tables
+        computed_block_nums = inter_data.computed_block_nums
+
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
+            self.context_lens.append(context_len)
+            if is_prompt:
+                self.num_prefills += 1
+                self.num_prefill_tokens += token_len
+                self.prefill_seq_lens.append(seq_len)
+            else:
+                assert query_len == 1, (
+                    "seq_len: {}, context_len: {}, query_len: {}".format(
+                        seq_len, context_len, query_len))
+                self.num_decode_tokens += query_len
+                self.curr_seq_lens.append(curr_seq_len)
+
+            # Compute block table.
+            # TODO: Combine chunked prefill and prefix caching by
+            # only allowing multiple of block_size chunk size.
+            # NOTE: This only works for oooooooxxx style attention.
+            block_table = []
+            if inter_data.prefix_cache_hit:
+                block_table = computed_block_nums
+            elif ((chunked_prefill_enabled or not is_prompt)
+                  and block_tables is not None):
+                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+            self.block_tables.append(block_table)
+
+            # Compute slot mapping.
+            is_profile_run = is_block_tables_empty(block_tables)
+            start_idx = compute_slot_mapping_start_idx(
+                is_prompt, query_len, context_len, self.sliding_window,
+                self.use_v2_block_manager)
+            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                 seq_len, context_len, start_idx,
+                                 self.block_size, inter_data.block_tables)
+
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled)
+
+        device = self.runner.device
+        use_captured_graph = cuda_graph_pad_size != -1
+
+        max_query_len = max(query_lens)
+        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+        max_decode_seq_len = max(self.curr_seq_lens, default=0)
+        num_decode_tokens = self.num_decode_tokens
+
+        if use_captured_graph:
+            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+            self.block_tables.extend([[]] * cuda_graph_pad_size)
+            num_decode_tokens = batch_size
+
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = self.runner.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(self.block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device, non_blocking=True)
+        else:
+            block_tables = make_tensor_with_pad(
+                self.block_tables,
+                pad=0,
+                dtype=torch.int,
+                device=device,
+            )
+        assert max_query_len > 0, "query_lens: {}".format(query_lens)
+
+        assert device is not None
+        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
+                                               device, self.runner.pin_memory)
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        return self._metadata_cls(  # type: ignore
+            num_prefills=self.num_prefills,
+            slot_mapping=slot_mapping_tensor,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_prefill_seq_len=max_prefill_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            seq_start_loc=seq_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
+            use_cuda_graph=use_captured_graph,
+        )
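For orientation, a tiny standalone sketch (values invented) of the two cumulative-offset tensors produced at the end of `build()`: the exclusive prefix sums give each sequence's start offset in the flattened query and key streams.

```python
import torch

query_lens = torch.tensor([4, 6], dtype=torch.long)
seq_lens = torch.tensor([7, 9], dtype=torch.int)

query_start_loc = torch.zeros(query_lens.shape[0] + 1, dtype=torch.int32)
seq_start_loc = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(query_lens, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])
torch.cumsum(seq_lens, dim=0, dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])

print(query_start_loc.tolist())  # [0, 4, 10]
print(seq_start_loc.tolist())    # [0, 7, 16]
```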

aphrodite/attention/backends/xformers.py (+509, -111)

@@ -1,36 +1,40 @@
 """Attention layer with xFormers and PagedAttention."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (
-    AttentionBias,
-    BlockDiagonalCausalMask,
-    LowerTriangularMaskWithTensorBias,
-)
-
-from aphrodite.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataPerStage,
-)
-from aphrodite.attention.ops.paged_attn import (
-    PagedAttention,
-    PagedAttentionMetadata,
-)
+from xformers.ops.fmha.attn_bias import (AttentionBias,
+                                         BlockDiagonalCausalMask,
+                                         BlockDiagonalMask,
+                                         LowerTriangularMaskWithTensorBias)
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionType)
+from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.ops.paged_attn import (PagedAttention,
+                                                PagedAttentionMetadata)
 
 
 class XFormersBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
-        return XFormersMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return XFormersMetadata
+
+    @staticmethod
+    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
+        return XFormersMetadataBuilder
 
     @staticmethod
     def get_kv_cache_shape(
@@ -53,13 +57,13 @@ class XFormersBackend(AttentionBackend):
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
 
 @dataclass
-class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     """Metadata for XFormersbackend.
 
     NOTE: Any python object stored here is not updated when it is
@@ -67,46 +71,73 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
     dynamically, it should be stored in tensor. The tensor has to be
     updated from `CUDAGraphRunner.forward` API.
     """
-    # Currently, input sequences can only contain all prompts
-    # or all decoding. True if all sequences are prompts.
-    is_prompt: bool
-    # (batch_size,). The prompt length per sequence. None if it is a decoding.
-    prompt_lens: Optional[List[int]]
-    # prompt_lens stored as a tensor.
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    # NOTE: Definition of context_len, subquery_len, and seqlen.
+
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
     # |- tokenA -|......................|-- newTokens ---|
     # |---------- context_len ----------|
-    # |-------------------- seqlen ----------------------|
-    #                                   |- subquery_len -|
+    # |-------------------- seq_len ----------------------|
+    #                                   |-- query_len ---|
 
-    # WARNING(sang): context_len has different definition depending on if it is
-    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
-    # When it is for decoding, it includes a new token.
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
 
-    # Maximum subquery length in the batch.
-    max_subquery_len: Optional[int]
     # FIXME: It is for flash attn.
-    # Maximum prompt length in the batch.
-    max_prompt_len: Optional[int]
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    subquery_start_loc: Optional[torch.Tensor]
-    # FIXME: It is for flash attn.
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
 
     # Whether or not if cuda graph is enabled.
     # Cuda-graph is currently enabled for decoding only.
     # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens. None if it is a decoding.
+    seq_lens: Optional[List[int]] = None
+
+    # FIXME: It is for flash attn.
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor] = None
+
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor] = None
+
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int] = None
+
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor] = None
+
+    # Self-attention prefill/decode metadata cache
+    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
+    _cached_decode_metadata: Optional["XFormersMetadata"] = None
+
+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # Encoder sequence lengths representation
+    encoder_seq_lens: Optional[List[int]] = None
+    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
+
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+
+    # Number of tokens input to encoder
+    num_encoder_tokens: Optional[int] = None
+
+    # Cross-attention memory-mapping data structures: slot mapping
+    # and block tables
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    cross_block_tables: Optional[torch.Tensor] = None
+
     def __post_init__(self):
         # Set during the execution of the first attention op.
         # It is a list because it is needed to set per prompt
@@ -114,9 +145,233 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
         # from xformer API.
         # will not appear in the __repr__ and __init__
         self.attn_bias: Optional[List[AttentionBias]] = None
-
-
-class XFormersImpl(AttentionImpl):
+        self.encoder_attn_bias: Optional[List[AttentionBias]] = None
+        self.cross_attn_bias: Optional[List[AttentionBias]] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self):
+        '''
+        All attention metadata required for encoder attention is set.
+        '''
+        return ((self.encoder_seq_lens is not None)
+                and (self.encoder_seq_lens_tensor is not None)
+                and (self.max_encoder_seq_len is not None))
+
+    @property
+    def is_all_cross_attn_metadata_set(self):
+        '''
+        All attention metadata required for enc/dec cross-attention is set.
+
+        Superset of encoder attention required metadata.
+        '''
+        return (self.is_all_encoder_attn_metadata_set
+                and (self.cross_slot_mapping is not None)
+                and (self.cross_block_tables is not None))
+
+    @property
+    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            # Recover cached prefill-phase attention
+            # metadata structure
+            return self._cached_prefill_metadata
+
+        assert ((self.seq_lens is not None)
+                or (self.encoder_seq_lens is not None))
+        assert ((self.seq_lens_tensor is not None)
+                or (self.encoder_seq_lens_tensor is not None))
+
+        # Compute some attn_metadata fields which default to None
+        query_start_loc = (None if self.query_start_loc is None else
+                           self.query_start_loc[:self.num_prefills + 1])
+        slot_mapping = (None if self.slot_mapping is None else
+                        self.slot_mapping[:self.num_prefill_tokens])
+        seq_lens = (None if self.seq_lens is None else
+                    self.seq_lens[:self.num_prefills])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[:self.num_prefills])
+        context_lens_tensor = (None if self.context_lens_tensor is None else
+                               self.context_lens_tensor[:self.num_prefills])
+        block_tables = (None if self.block_tables is None else
+                        self.block_tables[:self.num_prefills])
+
+        # Construct & cache prefill-phase attention metadata structure
+        self._cached_prefill_metadata = XFormersMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=slot_mapping,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=self.max_query_len,
+            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_seq_len=0,
+            query_start_loc=query_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
+            use_cuda_graph=False,
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
+        return self._cached_prefill_metadata
+
+    @property
+    def decode_metadata(self) -> Optional["XFormersMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        if self._cached_decode_metadata is not None:
+            # Recover cached decode-phase attention
+            # metadata structure
+            return self._cached_decode_metadata
+        assert ((self.seq_lens_tensor is not None)
+                or (self.encoder_seq_lens_tensor is not None))
+
+        # Compute some attn_metadata fields which default to None
+        slot_mapping = (None if self.slot_mapping is None else
+                        self.slot_mapping[self.num_prefill_tokens:])
+        seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                           self.seq_lens_tensor[self.num_prefills:])
+        block_tables = (None if self.block_tables is None else
+                        self.block_tables[self.num_prefills:])
+
+        # Construct & cache decode-phase attention metadata structure
+        self._cached_decode_metadata = XFormersMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=self.num_decode_tokens,
+            slot_mapping=slot_mapping,
+            seq_lens_tensor=seq_lens_tensor,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.max_decode_seq_len,
+            block_tables=block_tables,
+            use_cuda_graph=self.use_cuda_graph,
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
+        return self._cached_decode_metadata
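Illustrative numbers (not taken from the patch) for how the two properties above slice a mixed batch: per-token fields split at `num_prefill_tokens`, per-sequence fields at `num_prefills`.

```python
num_prefills = 2          # two prompt sequences ...
num_prefill_tokens = 7    # ... contributing 3 + 4 tokens
num_decode_tokens = 3     # three sequences decoding one token each

slot_mapping = list(range(10))   # one entry per token
seq_lens = [3, 4, 9, 5, 12]      # one entry per sequence

prefill_slot_mapping = slot_mapping[:num_prefill_tokens]  # tokens 0..6
decode_slot_mapping = slot_mapping[num_prefill_tokens:]   # tokens 7..9
prefill_seq_lens = seq_lens[:num_prefills]                # [3, 4]
decode_seq_lens = seq_lens[num_prefills:]                 # [9, 5, 12]
```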
+
+
+def _get_attn_bias(
+    attn_metadata: XFormersMetadata,
+    attn_type: AttentionType,
+) -> Optional[AttentionBias]:
+    '''
+    Extract appropriate attention bias from attention metadata
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+    * Appropriate attention bias value given the attention type
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        return attn_metadata.attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        return attn_metadata.encoder_attn_bias
+    else:
+        # attn_type == AttentionType.ENCODER_DECODER
+        return attn_metadata.cross_attn_bias
+
+
+def _set_attn_bias(
+    attn_metadata: XFormersMetadata,
+    attn_bias: List[Optional[AttentionBias]],
+    attn_type: AttentionType,
+) -> None:
+    '''
+    Update appropriate attention bias field of attention metadata,
+    according to attention type.
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention
+    * attn_bias: The desired attention bias value
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        attn_metadata.attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER:
+        attn_metadata.encoder_attn_bias = attn_bias
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        attn_metadata.cross_attn_bias = attn_bias
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
+def _get_seq_len_block_table_args(
+    attn_metadata: XFormersMetadata,
+    is_prompt: bool,
+    attn_type: AttentionType,
+) -> tuple:
+    '''
+    The particular choice of sequence-length- and block-table-related
+    attributes which should be extracted from attn_metadata is dependent
+    on the type of attention operation.
+
+    Decoder attn -> select entirely decoder self-attention-related fields
+    Encoder/decoder cross-attn -> select encoder sequence lengths & 
+                                  cross-attn block-tables fields
+    Encoder attn -> select encoder sequence lengths fields & no block tables
+    
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention op
+    * is_prompt: True if prefill, False otherwise
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+
+    * Appropriate sequence-lengths tensor
+    * Appropriate max sequence-length scalar
+    * Appropriate block tables (or None)
+    '''
+
+    if attn_type == AttentionType.DECODER:
+        # Decoder self-attention
+        # Choose max_seq_len based on whether we are in prompt_run
+        if is_prompt:
+            max_seq_len = attn_metadata.max_prefill_seq_len
+        else:
+            max_seq_len = attn_metadata.max_decode_seq_len
+        return (attn_metadata.seq_lens_tensor, max_seq_len,
+                attn_metadata.block_tables)
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        # Enc/dec cross-attention KVs match encoder sequence length;
+        # cross-attention utilizes special "cross" block tables
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len,
+                attn_metadata.cross_block_tables)
+    elif attn_type == AttentionType.ENCODER:
+        # No block tables associated with encoder attention
+        return (attn_metadata.encoder_seq_lens_tensor,
+                attn_metadata.max_encoder_seq_len, None)
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
+class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
+
+    _metadata_cls = XFormersMetadata
+
+
+class XFormersImpl(AttentionImpl[XFormersMetadata]):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
     |<--------------- num_prefill_tokens ----------------->|	
@@ -147,18 +402,28 @@ class XFormersImpl(AttentionImpl):
         num_heads: int,
         head_size: int,
         scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError(
+                "XFormers does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "XFormers does not support attention logits soft capping.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
+        self.num_kv_heads = num_kv_heads
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
+        self.sliding_window = sliding_window
+        self.kv_cache_dtype = kv_cache_dtype
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -172,53 +437,145 @@ class XFormersImpl(AttentionImpl):
     def forward(
         self,
         query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[XFormersMetadata],
-        kv_scale: float,
+        attn_metadata: "XFormersMetadata",
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
 
+        For decoder-only models: query, key and value must be non-None.
+
+        For encoder/decoder models:
+        * XFormersImpl.forward() may be invoked for both self- and cross-
+          attention layers.
+        * For self-attention: query, key and value must be non-None.
+        * For cross-attention:
+            * Query must be non-None
+            * During prefill, key and value must be non-None; key and value
+              get cached for use during decode.
+            * During decode, key and value may be None, since:
+              (1) key and value tensors were cached during prefill, and
+              (2) cross-attention key and value tensors do not grow during
+                  decode
+        
+        A note on how the attn_type (attention type enum) argument impacts
+        attention forward() behavior:
+    
+            * DECODER: normal decoder-only behavior;
+                use decoder self-attention block table
+            * ENCODER: no KV caching; pass encoder sequence
+                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len) to kernel, in lieu of decoder
+                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
+            * ENCODER_DECODER: cross-attention behavior;
+                use cross-attention block table for caching KVs derived
+                from encoder hidden states; since KV sequence lengths
+                will match encoder sequence lengths, pass encoder sequence
+                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len)
+    
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
             attn_metadata: Metadata for attention.
+            attn_type: Select the attention type: encoder attention,
+                       decoder self-attention, or encoder/decoder
+                       cross-attention. Defaults to decoder self-attention,
+                       which is generally the Aphrodite default.
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        num_tokens, hidden_size = query.shape
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
 
-        if kv_cache is not None:
+        # Check that appropriate attention metadata attributes are
+        # selected for the desired attention type
+        if (attn_type == AttentionType.ENCODER
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+        elif (attn_type == AttentionType.ENCODER_DECODER
+              and (not attn_metadata.is_all_cross_attn_metadata_set)):
+            raise AttributeError("Encoder/decoder cross-attention "
+                                 "requires setting cross-attention "
+                                 "metadata attributes.")
+
+        query = query.view(-1, self.num_heads, self.head_size)
+        if key is not None:
+            assert value is not None
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+        else:
+            assert value is None
+
+        # Self-attention vs. cross-attention determines which KV-cache
+        # memory mapping and which seq-len data structures we use.
+
+        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
+            # KVs are cached during decoder self- or encoder/decoder
+            # cross-attention, but not during encoder attention.
+            #
+            # Even if there are no new key/value pairs to cache,
+            # we still need to break out key_cache and value_cache
+            # i.e. for later use by paged attention
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype,
-                                                kv_scale)
-
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+            if (key is not None) and (value is not None):
+
+                if attn_type == AttentionType.ENCODER_DECODER:
+                    # Update cross-attention KV cache (prefill-only)
+                    # During cross-attention decode, key & value will be None,
+                    # preventing this IF-statement branch from running
+                    updated_slot_mapping = attn_metadata.cross_slot_mapping
+                else:
+                    # Update self-attention KV cache (prefill/decode)
+                    updated_slot_mapping = attn_metadata.slot_mapping
+
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory
+                # profiling run.
+                PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                    value_cache,
+                                                    updated_slot_mapping,
+                                                    self.kv_cache_dtype,
+                                                    k_scale, v_scale)
+
+        if attn_type != AttentionType.ENCODER:
+            # Decoder self-attention supports chunked prefill.
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
+        else:
+            # Encoder attention - chunked prefill is not applicable;
+            # derive the token count from the query shape and treat
+            # them as 100% prefill tokens
+            assert attn_metadata.num_encoder_tokens is not None
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_decode_tokens = 0
+
+        if attn_type == AttentionType.DECODER:
+            # Only enforce this shape-constraint for decoder
+            # self-attention
+            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+            assert value.shape[0] == num_prefill_tokens + num_decode_tokens
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
         query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
+        if key is not None and value is not None:
+            key = key[:num_prefill_tokens]
+            value = value[:num_prefill_tokens]
 
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
@@ -230,10 +587,14 @@ class XFormersImpl(AttentionImpl):
                 # block tables are empty if the prompt does not have a cached
                 # prefix.
                 out = self._run_memory_efficient_xformers_forward(
-                    query, key, value, prefill_meta)
+                    query, key, value, prefill_meta, attn_type=attn_type)
                 assert out.shape == output[:num_prefill_tokens].shape
                 output[:num_prefill_tokens] = out
             else:
+
+                assert prefill_meta.query_start_loc is not None
+                assert prefill_meta.max_query_len is not None
+
                 # prefix-enabled attention
                # TODO: this triton kernel has a regression issue (it broke)
                # when dealing with different data types between KV and FP8 KV cache,
@@ -245,28 +606,37 @@ class XFormersImpl(AttentionImpl):
                     key_cache,
                     value_cache,
                     prefill_meta.block_tables,
-                    prefill_meta.subquery_start_loc,
-                    prefill_meta.prompt_lens_tensor,
-                    prefill_meta.context_lens,
-                    prefill_meta.max_subquery_len,
+                    prefill_meta.query_start_loc,
+                    prefill_meta.seq_lens_tensor,
+                    prefill_meta.context_lens_tensor,
+                    prefill_meta.max_query_len,
                     self.alibi_slopes,
+                    self.sliding_window,
                 )
                 assert output[:num_prefill_tokens].shape == out.shape
                 output[:num_prefill_tokens] = out
 
         if decode_meta := attn_metadata.decode_metadata:
+
+            (
+                seq_lens_arg,
+                max_seq_len_arg,
+                block_tables_arg,
+            ) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
+
             output[num_prefill_tokens:] = PagedAttention.forward_decode(
                 decode_query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
-                attn_metadata.kv_cache_dtype,
+                block_tables_arg,
+                seq_lens_arg,
+                max_seq_len_arg,
+                self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                kv_scale,
+                k_scale,
+                v_scale,
             )
 
         # Reshape the output tensor.
@@ -278,6 +648,7 @@ class XFormersImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: XFormersMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
         """Attention for 1D query of multiple prompts. Multiple prompt
         tokens are flattened in to `query` input.
@@ -291,7 +662,12 @@ class XFormersImpl(AttentionImpl):
             key: shape = [num_prefill_tokens, num_kv_heads, head_size]
             value: shape = [num_prefill_tokens, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
             attn_type: Select the attention type: encoder attention,
                        decoder self-attention, or encoder/decoder
                        cross-attention. Defaults to decoder self-attention,
                        which is generally the Aphrodite default.
         """
+
         original_query = query
         if self.num_kv_heads != self.num_heads:
             # GQA/MQA requires the shape [B, M, G, H, K].
@@ -309,18 +685,39 @@ class XFormersImpl(AttentionImpl):
         # Set the attention bias if not provided. This typically happens
         # at the first attention layer of every iteration.
         # FIXME: This is a hack.
-        if attn_metadata.attn_bias is None:
+        attn_bias = _get_attn_bias(attn_metadata, attn_type)
+        if attn_bias is None:
             if self.alibi_slopes is None:
-                attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                    attn_metadata.prompt_lens)
+                if (attn_type == AttentionType.ENCODER_DECODER):
+                    assert attn_metadata.seq_lens is not None
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default enc/dec cross-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+                elif attn_type == AttentionType.ENCODER:
+                    assert attn_metadata.encoder_seq_lens is not None
+
+                    # Default encoder self-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.encoder_seq_lens)
+                else:
+                    assert attn_metadata.seq_lens is not None
+
+                    # Default decoder self-attention mask is causal
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                        attn_metadata.seq_lens)
                 if self.sliding_window is not None:
                     attn_bias = attn_bias.make_local_attention(
                         self.sliding_window)
-                attn_metadata.attn_bias = [attn_bias]
+                attn_bias = [attn_bias]
             else:
-                attn_metadata.attn_bias = _make_alibi_bias(
-                    self.alibi_slopes, self.num_kv_heads, query.dtype,
-                    attn_metadata.prompt_lens)
+                assert attn_metadata.seq_lens is not None
+                attn_bias = _make_alibi_bias(self.alibi_slopes,
+                                             self.num_kv_heads, query.dtype,
+                                             attn_metadata.seq_lens)
+
+            _set_attn_bias(attn_metadata, attn_bias, attn_type)
 
         # No alibi slopes.
         # TODO: Too many view operations. Let's try to reduce
@@ -334,7 +731,7 @@ class XFormersImpl(AttentionImpl):
                 query,
                 key,
                 value,
-                attn_bias=attn_metadata.attn_bias[0],
+                attn_bias=attn_bias[0],
                 p=0.0,
                 scale=self.scale)
             return out.view_as(original_query)
@@ -343,20 +740,21 @@ class XFormersImpl(AttentionImpl):
         # FIXME: Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
+        assert attn_metadata.seq_lens is not None
         output = torch.empty_like(original_query)
         start = 0
-        for i, prompt_len in enumerate(attn_metadata.prompt_lens):
-            end = start + prompt_len
+        for i, seq_len in enumerate(attn_metadata.seq_lens):
+            end = start + seq_len
             out = xops.memory_efficient_attention_forward(
                 query[None, start:end],
                 key[None, start:end],
                 value[None, start:end],
-                attn_bias=attn_metadata.attn_bias[i],
+                attn_bias=attn_bias[i],
                 p=0.0,
                 scale=self.scale)
             # TODO: Unnecessary copy. Optimize.
             output[start:end].copy_(out.view_as(original_query[start:end]))
-            start += prompt_len
+            start += seq_len
         return output
 
 
@@ -364,13 +762,13 @@ def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
     num_kv_heads: int,
     dtype: torch.dtype,
-    prompt_lens: List[int],
-) -> LowerTriangularMaskWithTensorBias:
-    attn_biases = []
-    for prompt_len in prompt_lens:
-        bias = torch.arange(prompt_len, dtype=dtype)
+    seq_lens: List[int],
+) -> List[AttentionBias]:
+    attn_biases: List[AttentionBias] = []
+    for seq_len in seq_lens:
+        bias = torch.arange(seq_len, dtype=dtype)
         # NOTE: HF uses
-        #     `bias = bias[None, :].repeat(prompt_len, 1)`
+        #     `bias = bias[None, :].repeat(seq_len, 1)`
         # here. We find that both biases give the same results, but
         # the bias below more accurately follows the original ALiBi
         # paper.
@@ -378,16 +776,16 @@ def _make_alibi_bias(
         # element.
         bias = bias[None, :] - bias[:, None]
 
-        padded_len = (prompt_len + 7) // 8 * 8
+        padded_len = (seq_len + 7) // 8 * 8
         num_heads = alibi_slopes.shape[0]
         bias = torch.empty(
             1,  # batch size
             num_heads,
-            prompt_len,
+            seq_len,
             padded_len,
             device=alibi_slopes.device,
             dtype=dtype,
-        )[:, :, :, :prompt_len].copy_(bias)
+        )[:, :, :, :seq_len].copy_(bias)
         bias.mul_(alibi_slopes[:, None, None])
         if num_heads != num_kv_heads:
             bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))

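A minimal sketch, separate from the diff itself, of how the three attention types map onto xFormers bias objects inside _run_memory_efficient_xformers_forward; the sequence lengths below are arbitrary example values.

from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         BlockDiagonalMask)

decoder_seq_lens = [5, 3]   # decoder tokens per sequence (example values)
encoder_seq_lens = [7, 4]   # encoder tokens per sequence (example values)

# DECODER: causal mask over decoder sequence lengths only.
decoder_bias = BlockDiagonalCausalMask.from_seqlens(decoder_seq_lens)

# ENCODER: non-causal mask over encoder sequence lengths only.
encoder_bias = BlockDiagonalMask.from_seqlens(encoder_seq_lens)

# ENCODER_DECODER: non-causal mask; queries follow decoder lengths,
# keys/values follow encoder lengths.
cross_bias = BlockDiagonalMask.from_seqlens(decoder_seq_lens,
                                            encoder_seq_lens)
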
+ 76 - 10
aphrodite/attention/layer.py

@@ -1,19 +1,24 @@
 """Attention layer."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn as nn
 
 from aphrodite.attention.backends.abstract import (AttentionMetadata,
-                                                   AttentionMetadataPerStage)
+                                                   AttentionType)
 from aphrodite.attention.selector import get_attn_backend
+from aphrodite.common.config import CacheConfig
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.quantization.kv_cache import BaseKVCacheMethod
 
 
 class Attention(nn.Module):
     """Attention layer.
+
     This class takes query, key, and value tensors as input. The input tensors
     can either contain prompt tokens or generation tokens.
     The class does the following:
+
     1. Store the input key and value tensors in the KV cache.
     2. Perform (multi-head/multi-query/grouped-query) attention.
     3. Return the output tensor.
@@ -26,13 +31,59 @@ class Attention(nn.Module):
         scale: float,
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.backend = get_attn_backend(torch.get_default_dtype())
-        impl_cls = self.backend.get_impl_cls()
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+            sliding_window = cache_config.sliding_window
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+            sliding_window = None
+        if num_kv_heads is None:
+            num_kv_heads = num_heads
+
+        # The default k/v_scale is set to 1.0. This is ignored
+        # when kv-cache is not fp8, and should be used with
+        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+        # expect the pre-quantized k/v_scale to be loaded along
+        # with the model weights.
+        self.kv_cache_dtype = kv_cache_dtype
+        self._k_scale = 1.0
+        self._v_scale = 1.0
+        quant_method = quant_config.get_quant_method(
+            self, prefix=prefix) if quant_config else None
+        if quant_method is not None:
+            assert isinstance(quant_method, BaseKVCacheMethod)
+            # TODO: kv cache dtype should be specified in the FP8
+            # checkpoint config and become the "auto" behavior
+            if self.kv_cache_dtype == "fp8_e5m2":
+                raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                 "fp8 checkpoints.")
+            # If quantization is enabled, we make "k_scale" and "v_scale"
+            # parameters so that they can be loaded from the model checkpoint.
+            # The k/v_scale will then be converted back to native float32
+            # values after weight loading.
+            self.quant_method = quant_method
+            self.quant_method.create_weights(self)
+
+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
+        dtype = torch.get_default_dtype()
+        attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
+                                        sliding_window, dtype, kv_cache_dtype,
+                                        block_size, blocksparse_params
+                                        is not None)
+        impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
-                             alibi_slopes, sliding_window)
+                             alibi_slopes, sliding_window, kv_cache_dtype,
+                             blocksparse_params, logits_soft_cap)
 
     def forward(
         self,
@@ -40,8 +91,23 @@ class Attention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
-        kv_scale: float = 1.0,
+        attn_metadata: AttentionMetadata,
+        attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                                 kv_scale)
+
+        return self.impl.forward(query,
+                                 key,
+                                 value,
+                                 kv_cache,
+                                 attn_metadata,
+                                 self._k_scale,
+                                 self._v_scale,
+                                 attn_type=attn_type)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.impl.head_size}"  # type: ignore
+        s += f", num_heads={self.impl.num_heads}"  # type: ignore
+        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
+        s += f", scale={self.impl.scale}"  # type: ignore
+        s += f", backend={self.impl.__class__.__name__}"
+        return s

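A minimal sketch of constructing the refactored Attention layer; the head counts are hypothetical, leaving cache_config as None falls back to the "auto" KV-cache dtype and a block size of 16 as shown above, and the backend that gets picked depends on the installed hardware.

from aphrodite.attention.layer import Attention

# Hypothetical GQA shape: 32 query heads sharing 8 KV heads.
attn = Attention(num_heads=32, head_size=128, scale=128 ** -0.5,
                 num_kv_heads=8)
print(attn.extra_repr())  # head counts, scale and the selected backend

# At runtime the layer forwards its own k/v scales to the backend and
# takes an attention-type flag instead of an explicit kv_scale, e.g.:
#   out = attn(q, k, v, kv_cache, attn_metadata,
#              attn_type=AttentionType.DECODER)
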
+ 0 - 0
aphrodite/attention/ops/blocksparse_attention/__init__.py


+ 422 - 0
aphrodite/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

@@ -0,0 +1,422 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def blocksparse_flash_attn_varlen_fwd(
+        q,
+        k,
+        v,  # (#tokens, n_heads, head_size)
+        cu_seqlens_k,
+        cu_seqlens_q,
+        sm_scale,
+        sparse_layout,
+        *,
+        block_size=64,
+        q_block_size=None,
+        max_seqlen=None):
+    # split q to blocks
+
+    assert isinstance(sparse_layout, (list, tuple))
+
+    _, n_heads, head_size = q.shape
+    batch_size = cu_seqlens_k.size(0) - 1
+    q_block_size = q_block_size or block_size
+
+    assert q.dim() == k.dim() == v.dim() == 3
+    assert q.size(1) % k.size(1) == 0
+    assert q.size(2) == k.size(2)
+    # TODO: allow k, v to have different head_size
+    assert k.shape == v.shape
+    assert cu_seqlens_k.dim() == 1
+
+    q_k_ratio = q.size(1) // k.size(1)
+
+    if cu_seqlens_q is None:
+        if q.size(0) == batch_size:  # decoding only
+            cu_seqlens_q = torch.arange(
+                0,
+                batch_size + 1,
+                dtype=cu_seqlens_k.dtype,
+                device=cu_seqlens_k.device,
+            )
+        elif q.size(0) == k.size(0):
+            cu_seqlens_q = cu_seqlens_k
+        else:
+            raise ValueError("cu_seqlens_q must be specified\
+                    if it mix of prefilling and decoding.")
+    else:
+        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
+
+    # switch to CPU to avoid too many kernel launches when iterating over them
+    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
+    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
+
+    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
+        "length of q should either be 1 (decoding) or same as k (prefilling).")
+
+    if max_seqlen:
+        assert k_lens.max() <= max_seqlen
+
+    n_blocks = (q_lens + q_block_size - 1) // q_block_size
+
+    q_batch_ids = torch.tensor(
+        [i for i, n in enumerate(n_blocks) for _ in range(n)],
+        dtype=cu_seqlens_q.dtype,
+        device=cu_seqlens_q.device,
+    )
+    q_start_sids = torch.tensor(
+        [i * q_block_size for n in n_blocks for i in range(n)],
+        dtype=cu_seqlens_q.dtype,
+        device=cu_seqlens_q.device,
+    )
+
+    out = q.new_empty(q.shape)
+    cu_seqlens_q = cu_seqlens_q.contiguous()
+    cu_seqlens_k = cu_seqlens_k.contiguous()
+
+    layout_crow_indices, layout_col_indices = sparse_layout
+    block_d = triton.next_power_of_2(head_size)
+
+    decoding_only = (q_lens == 1).all().item()
+    grid = (len(q_start_sids), n_heads, 1)
+
+    _fwd_kernel_batch_inference[grid](
+        q,
+        k,
+        v,
+        out,
+        sm_scale,
+        cu_seqlens_q[:-1],
+        cu_seqlens_q[1:],
+        cu_seqlens_k[:-1],
+        cu_seqlens_k[1:],
+        q_batch_ids,
+        q_start_sids,
+        0,
+        *q.stride(),
+        0,
+        *k.stride(),
+        0,
+        *v.stride(),
+        0,
+        *out.stride(),
+        layout_crow_indices,
+        layout_col_indices,
+        *layout_crow_indices.stride(),
+        *layout_col_indices.stride(),
+        q_k_ratio,
+        HAS_BATCH_DIM=False,
+        D_HEAD=head_size,
+        BLOCK_M=q_block_size,
+        BLOCK_N=block_size,
+        BLOCK_D=block_d,
+        BLOCK_M_LOADING=(16 if decoding_only else
+                         q_block_size),  # smaller for decoding
+        EVEN_D=block_d == head_size,
+        num_warps=1 if decoding_only else 4,
+        num_stages=3)
+
+    return out
+
+
+@triton.jit
+def _fwd_kernel_inner(
+    acc,
+    l_i,
+    m_i,
+    q,
+    Q,
+    k_block_col_idx,
+    layout_col_ptr,
+    layout_col_stride_h,
+    layout_col_stride_m,
+    k_ptrs,
+    v_ptrs,
+    off_h,
+    offs_m,
+    offs_n,
+    offs_d,
+    stride_kt,
+    stride_vt,
+    sm_scale,
+    k_seqlen,
+    past_len,
+    LAST_K_BLOCK: tl.constexpr,
+    BLOCK_M_LOADING: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    D_HEAD: tl.constexpr,
+    EVEN_D: tl.constexpr,
+    M_LT_N: tl.constexpr,
+):
+    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
+                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
+    start_n = k_block_id * BLOCK_N
+    if LAST_K_BLOCK:
+        if EVEN_D:
+            k = tl.load(
+                k_ptrs + start_n * stride_kt,
+                mask=offs_n[None, :] + start_n < k_seqlen,
+            )
+        else:
+            k = tl.load(
+                k_ptrs + start_n * stride_kt,
+                mask=(offs_n[None, :] + start_n < k_seqlen) &
+                (offs_d[:, None] < D_HEAD),
+            )
+    else:
+        if EVEN_D:
+            k = tl.load(k_ptrs + start_n * stride_kt)
+        else:
+            k = tl.load(k_ptrs + start_n * stride_kt,
+                        mask=offs_d[:, None] < D_HEAD)
+
+    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
+    qk += tl.dot(q, k)
+    qk *= sm_scale
+
+    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
+    if LAST_K_BLOCK | M_LT_N:
+        qk += tl.where(
+            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
+            0,
+            float("-inf"),
+        )
+
+    # flash-attn2
+    m_ij = tl.maximum(m_i, tl.max(qk, 1))
+    p = tl.math.exp2(qk - m_ij[:, None])
+    l_ij = tl.sum(p, 1)
+    alpha = tl.math.exp2(m_i - m_ij)
+    acc = acc * alpha[:, None]
+    # update m_i
+    m_i = m_ij
+    l_i = l_i * alpha + l_ij
+
+    p = p.to(Q.dtype.element_ty)
+    # update acc
+    if LAST_K_BLOCK:
+        if EVEN_D:
+            v = tl.load(
+                v_ptrs + start_n * stride_vt,
+                mask=offs_n[:, None] + start_n < k_seqlen,
+            )
+        else:
+            v = tl.load(
+                v_ptrs + start_n * stride_vt,
+                mask=(offs_n[:, None] + start_n < k_seqlen) &
+                (offs_d[None, :] < D_HEAD),
+            )
+    else:
+        if EVEN_D:
+            v = tl.load(v_ptrs + start_n * stride_vt)
+        else:
+            v = tl.load(v_ptrs + start_n * stride_vt,
+                        mask=offs_d[None, :] < D_HEAD)
+
+    acc += tl.dot(p, v)
+
+    return acc, l_i, m_i
+
+
+@triton.heuristics({
+    "M_LT_N":
+    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
+})
+@triton.jit
+def _fwd_kernel_batch_inference(
+    Q,
+    K,
+    V,
+    Out,
+    sm_scale,
+    q_batch_starts,
+    q_batch_ends,
+    k_batch_starts,
+    k_batch_ends,
+    q_batch_ids,
+    q_start_sids,
+    stride_qb,
+    stride_qt,
+    stride_qh,
+    stride_qd,
+    stride_kb,
+    stride_kt,
+    stride_kh,
+    stride_kd,
+    stride_vb,
+    stride_vt,
+    stride_vh,
+    stride_vd,
+    stride_ob,
+    stride_ot,
+    stride_oh,
+    stride_od,
+    layout_crow_ptr,
+    layout_col_ptr,
+    layout_crow_stride_h,
+    layout_crow_stride_m,
+    layout_col_stride_h,
+    layout_col_stride_m,
+    q_k_ratio,
+    HAS_BATCH_DIM: tl.constexpr,
+    D_HEAD: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    BLOCK_M_LOADING: tl.constexpr,
+    EVEN_D: tl.constexpr,
+    M_LT_N: tl.constexpr,
+):
+    """
+    NOTATION:
+    pid: position id
+    sid: storage id
+    sbid: storage block id
+    pbid: position block id
+    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
+    TODO:
+    Optimize grouped-attn
+    """
+    off_zm = tl.program_id(0)
+    off_h = tl.program_id(1)
+
+    off_h_for_kv = off_h // q_k_ratio
+
+    if HAS_BATCH_DIM:
+        off_z = tl.program_id(2)
+        Q += off_z * stride_qb
+        K += off_z * stride_kb
+        V += off_z * stride_vb
+        Out += off_z * stride_ob
+        start_m = off_zm
+        q_start_sid = start_m * BLOCK_M  # always 0 for decoding
+    else:
+        off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
+        q_start_sid = tl.load(q_start_sids + off_zm)
+        start_m = q_start_sid // BLOCK_M  # q_sbid
+
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_D)
+
+    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
+    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
+    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
+    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
+    past_len = k_seqlen - q_seqlen
+
+    Q += q_cu_start * stride_qt + off_h * stride_qh
+    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
+    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
+    Out += q_cu_start * stride_ot + off_h * stride_oh
+
+    q_pbid = (past_len + q_start_sid) // BLOCK_M
+
+    if EVEN_D:
+        q = tl.load(
+            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
+            mask=offs_m[:, None] < q_seqlen,
+        )
+    else:
+        q = tl.load(
+            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
+            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
+            other=0,
+        )
+
+    sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
+                       q_pbid * layout_crow_stride_m)
+
+    # TODO: load at once, with any Triton version
+    # that supports `tl.split`, e.g., Triton 3.0
+    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
+    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
+
+    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
+
+    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
+    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
+
+    sm_scale *= (
+        1.44269504  # log2(e) = 1/ln(2); we use base-2 exp/log internally
+    )
+
+    for k_block_col_idx in range(k_block_start, k_block_end - 1):
+        acc, l_i, m_i = _fwd_kernel_inner(
+            acc,
+            l_i,
+            m_i,
+            q,
+            Q,
+            k_block_col_idx,
+            layout_col_ptr,
+            layout_col_stride_h,
+            layout_col_stride_m,
+            k_ptrs,
+            v_ptrs,
+            off_h,
+            offs_m,
+            offs_n,
+            offs_d,
+            stride_kt,
+            stride_vt,
+            sm_scale,
+            k_seqlen,
+            past_len,
+            False,
+            BLOCK_M_LOADING,
+            BLOCK_N,
+            D_HEAD,
+            EVEN_D,
+            M_LT_N,
+        )
+
+    acc, l_i, m_i = _fwd_kernel_inner(
+        acc,
+        l_i,
+        m_i,
+        q,
+        Q,
+        k_block_end - 1,
+        layout_col_ptr,
+        layout_col_stride_h,
+        layout_col_stride_m,
+        k_ptrs,
+        v_ptrs,
+        off_h,
+        offs_m,
+        offs_n,
+        offs_d,
+        stride_kt,
+        stride_vt,
+        sm_scale,
+        k_seqlen,
+        past_len,
+        True,
+        BLOCK_M_LOADING,
+        BLOCK_N,
+        D_HEAD,
+        EVEN_D,
+        M_LT_N,
+    )
+
+    # flash-attn 2
+    m_i += tl.math.log2(l_i)
+    acc = acc / l_i[:, None]
+
+    # write output
+    if EVEN_D:
+        tl.store(
+            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
+            acc,
+            mask=offs_m[:, None] < q_seqlen,
+        )
+    else:
+        tl.store(
+            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
+            acc,
+            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
+        )

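A rough sketch of driving the varlen kernel directly, assuming an Ampere-or-newer GPU with Triton available; the sparse layout comes from get_sparse_attn_mask in utils.py (added further below), and all sizes are arbitrary example values.

import torch

from aphrodite.attention.ops.blocksparse_attention.blocksparse_attention_kernel import (  # noqa: E501
    blocksparse_flash_attn_varlen_fwd)
from aphrodite.attention.ops.blocksparse_attention.utils import (
    get_sparse_attn_mask)

n_heads, head_size, block_size = 8, 64, 64
seq_lens = [128, 256]                       # two prefill sequences
device, dtype = torch.device("cuda"), torch.float16

q = torch.randn(sum(seq_lens), n_heads, head_size, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens_k = torch.tensor([0, 128, 384], dtype=torch.int32, device=device)

# (crow, col) CSR layout describing which key blocks each query block sees.
sparse_layout, _, _ = get_sparse_attn_mask(
    n_heads, max(seq_lens), max(seq_lens), dtype, device,
    block_size=block_size, local_blocks=4, vert_stride=8)

out = blocksparse_flash_attn_varlen_fwd(
    q, k, v, cu_seqlens_k, cu_seqlens_k, head_size ** -0.5, sparse_layout,
    block_size=block_size, max_seqlen=max(seq_lens))
assert out.shape == q.shape
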
+ 236 - 0
aphrodite/attention/ops/blocksparse_attention/interface.py

@@ -0,0 +1,236 @@
+import math
+
+import torch
+
+from aphrodite.attention.ops.blocksparse_attention.utils import (
+    dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask)
+from aphrodite.common.utils import is_cpu, is_hip
+from aphrodite.platforms import current_platform
+
+IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
+                         and current_platform.get_device_capability()[0] >= 8)
+
+if IS_COMPUTE_8_OR_ABOVE:
+    from aphrodite.attention.ops.blocksparse_attention.blocksparse_attention_kernel import (  # noqa: E501
+        blocksparse_flash_attn_varlen_fwd)
+
+
+class LocalStridedBlockSparseAttn(torch.nn.Module):
+
+    def __init__(
+        self,
+        n_heads,
+        max_seqlen,
+        local_blocks,
+        vert_stride,
+        block_size,
+        device=None,
+        dtype=None,
+        homo_head=False,
+        active_head_range=None,
+        q_block_size=None,
+        use_spda=None,
+    ):
+        super().__init__()
+        if use_spda is None:
+            use_spda = is_hip() or is_cpu() or not \
+                       IS_COMPUTE_8_OR_ABOVE
+        device = device or (torch.cuda.current_device()
+                            if torch.cuda.is_available() else "cpu")
+        device = torch.device(device)
+        # NOTE: aphrodite CPU backend support BF16 instead of FP16.
+        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
+                          or device.type == "cpu" else torch.half)
+
+        self.n_heads = n_heads
+        self.max_seqlen = max_seqlen
+        self.local_blocks = local_blocks
+        self.vert_stride = vert_stride
+        self.use_spda = use_spda
+        self.dtype = dtype
+        self.device = device
+        self.block_size = block_size
+        self.q_block_size = q_block_size
+        self.homo_head = homo_head
+        self.active_head_range = active_head_range
+        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
+                                                       homo_head)
+
+        sparse_layout, sparse_pattern, self.dense_attn_mask = (
+            self.get_attn_pattern(dtype, device))
+
+        if q_block_size is not None and q_block_size != block_size:
+            if q_block_size > block_size:
+                assert q_block_size % block_size == 0
+                blocks_to_merge = q_block_size // block_size
+                shape = sparse_pattern.shape
+                sparse_pattern = sparse_pattern.view(shape[0], -1,
+                                                     blocks_to_merge,
+                                                     shape[-1])
+                sparse_pattern = sparse_pattern.sum(2)
+                sparse_layout = dense_to_crow_col(sparse_pattern)
+            else:
+                raise ValueError(
+                    "q_block_size smaller than block_size is not "
+                    "supported, as it would be slower.")
+
+        self.sparse_layout = sparse_layout
+
+    def get_attn_pattern(self, dtype, device):
+        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
+            self.n_heads,
+            self.max_seqlen,
+            self.max_seqlen,
+            dtype,
+            device,
+            block_size=self.block_size,
+            local_blocks=self.local_blocks,
+            vert_stride=self.vert_stride,
+            homo_head=self.homo_head,
+            return_dense=self.use_spda,
+            dense_mask_type="bias",
+        )
+        if (not self.homo_head) and (self.active_head_range is not None):
+            assert isinstance(self.active_head_range, tuple)
+            assert (len(self.active_head_range) == 2)
+            h_start, h_end = self.active_head_range
+            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
+            if self.use_spda:
+                dense_attn_mask = dense_attn_mask[h_start:h_end]
+        return sparse_layout, sparse_pattern, dense_attn_mask
+
+    def varlen_attn(self,
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_k,
+                    cu_seqlens_q=None,
+                    sm_scale=None):
+        """
+        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
+        Support grouped attention, with `q[:, i*r:(i*r + r)]`
+        is correspondent to `k[:, i]`, where `r` is the q/k ratio.
+        cu_seqlens_k: shape=(batch_size + 1,), 
+        indicating segment of samples, 
+        e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
+        cu_seqlens_q: shape=(batch_size + 1, ).
+        Default None: same as cu_seqlens_k for prefilling or
+        [0, 1, .., batch_size] for decoding.
+        The only case you need to specify is when q is a mix of 
+        prefilling and decoding.
+        sm_scale: softmax scale, default to 1/sqrt(head_size).
+        return: tensor of shape as q.
+        """
+        assert IS_COMPUTE_8_OR_ABOVE, (
+            "Requires compute capability 8 or above (Ampere or newer) "
+            "to use the Triton kernel.")
+
+        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
+
+        return blocksparse_flash_attn_varlen_fwd(
+            q,
+            k,
+            v,
+            cu_seqlens_k,
+            cu_seqlens_q,
+            sm_scale,
+            self.sparse_layout,
+            block_size=self.block_size,
+            q_block_size=self.q_block_size,
+            max_seqlen=self.max_seqlen,
+        )
+
+    @staticmethod
+    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
+        """
+        :param x: (total_tokens, n_heads, head_size)
+        :return: (batch, n_heads, length, head_size)
+        """
+        x_padded = x.new_empty(
+            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
+        cu_seqlens = cu_seqlens.cpu()
+        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
+            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
+                                                             1).unsqueeze(1))
+        return x_padded.flatten(1, 2)
+
+    @staticmethod
+    def transpose_and_unpad(x_padded, cu_seqlens):
+        """
+        :param x_padded: (batch, n_heads, length, head_size)
+        :return: (total_tokens, n_heads, head_size)
+        """
+        cu_seqlens = cu_seqlens.cpu()
+        total_n_tokens = cu_seqlens[-1]
+        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
+                               x_padded.size(3))
+        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
+            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
+        return x
+
+    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
+        """For CPU, V100 or other older GPUs.
+        NOTE: torch SPDA supports nested tensor, 
+        but seems extremely slow. Choose to pad instead.
+        """
+        assert (cu_seqlens_q is None or
+                (cu_seqlens_q
+                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
+        assert q.size(0) == k.size(0), "can only handle prompt with SPDA."
+
+        assert q.size(1) % k.size(1) == 0
+        q_k_ratio = q.size(1) // k.size(1)
+        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
+        cu_seqlens = cu_seqlens_k.cpu()
+        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+
+        if (self.dense_attn_mask.dtype != q.dtype
+                or self.dense_attn_mask.device != q.device):
+            _, _, self.dense_attn_mask = self.get_attn_pattern(
+                q.dtype, q.device)
+        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
+
+        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
+        k2, v2 = [
+            self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
+            for x in [k, v]
+        ]
+        spda_output = torch.nn.functional.scaled_dot_product_attention(
+            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
+        return self.transpose_and_unpad(spda_output, cu_seqlens)
+
+    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
+        """Dispatch to `varlen_attn` (Ampere or newer) or 
+        `self.spda`(cpu, Volta, Turing or older)based on 
+        the type of device used and cuda compute capability.
+        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
+                Support grouped attention, with `q[:, i*r:(i*r + r)]`
+                is correspondent to `k[:, i]`, where `r` is the q/k ratio.
+        cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
+                    e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
+        cu_seqlens_q: shape=(batch_size + 1, ).
+                    Default None: same as cu_seqlens_k for prefilling or
+                    [0, 1, .., batch_size] for decoding.
+                    The only case you need to specify 
+                    is when q is a mix of prefilling 
+                    and decoding.
+        sm_scale: softmax scale, default to 1/sqrt(head_size).
+        return: tensor of shape as q.
+        """
+        assert k.dim() == 3
+        if self.use_spda:
+            return self.spda(
+                q,
+                k,
+                v,
+                cu_seqlens_k,
+                cu_seqlens_q=cu_seqlens_q,
+                sm_scale=sm_scale,
+            )
+        return self.varlen_attn(q,
+                                k,
+                                v,
+                                cu_seqlens_k,
+                                cu_seqlens_q=cu_seqlens_q,
+                                sm_scale=sm_scale)

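A minimal usage sketch of the wrapper class; on hardware without an Ampere-class GPU it transparently falls back to the padded SDPA path, and the sizes below are arbitrary example values.

import torch

from aphrodite.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn)

n_heads, head_size, max_seqlen = 8, 64, 512
attn = LocalStridedBlockSparseAttn(n_heads, max_seqlen, local_blocks=4,
                                   vert_stride=8, block_size=64)

seq_lens = [128, 256]                         # two prompt sequences
q = torch.randn(sum(seq_lens), n_heads, head_size,
                device=attn.device, dtype=attn.dtype)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens_k = torch.tensor([0, 128, 384], dtype=torch.int32,
                            device=attn.device)

out = attn(q, k, v, cu_seqlens_k)             # same shape as q
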
+ 242 - 0
aphrodite/attention/ops/blocksparse_attention/utils.py

@@ -0,0 +1,242 @@
+# Helper functions for the 3D sparse pattern.
+# These functions are not optimized and are very inefficient.
+# Avoid calling them too frequently, or use a caching mechanism.
+
+from functools import lru_cache
+
+import numpy as np
+import torch
+import triton
+
+
+class csr_matrix:
+    """Simple implementation of CSR matrix conversion without scipy.
+    This replaced scipy.sparse.csr_matrix() previously used."""
+
+    def __init__(self, input_array):
+        if not isinstance(input_array, np.ndarray):
+            raise ValueError("Input must be a NumPy array")
+
+        self.shape = input_array.shape
+        rows, cols = self.shape
+        data = []
+        indices = []
+        indptr = [0]
+
+        for i in range(rows):
+            for j in range(cols):
+                if input_array[i, j]:
+                    data.append(input_array[i, j])
+                    indices.append(j)
+            indptr.append(len(indices))
+
+        self.data = np.array(data)
+        self.indices = np.array(indices)
+        self.indptr = np.array(indptr)
+
+
+def dense_to_crow_col(x: torch.Tensor):
+    """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
+    NOTE: col_indices padded -1
+    """
+    device = x.device
+    pad = -1
+    dim = x.dim()
+    assert x.dim() in (2, 3)
+    if x.dim() == 2:
+        x = x[None]
+    x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
+    crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
+    cols = [torch.from_numpy(xi.indices) for xi in x]
+    max_cols = max(len(xi) for xi in cols)
+    cols = [
+        torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
+        for xi in cols
+    ]
+    cols = torch.vstack(cols)
+    if dim == 2:
+        crows = crows[0]
+        cols = cols[0]
+    return crows.to(device), cols.to(device)
+
+
+def crow_col_to_dense(crows: torch.Tensor,
+                      cols: torch.Tensor,
+                      dtype: torch.dtype = torch.float16):
+    dim = crows.dim()
+    if dim == 1:
+        crows = crows[None]
+        cols = cols[None]
+    device = crows.device
+    crows, cols = crows.cpu(), cols.cpu()  # faster in cpu
+    shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
+    x = torch.zeros(shape, dtype=dtype)
+    for i in range(shape[0]):
+        for j in range(shape[1]):
+            x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
+    if dim == 1:
+        x = x[0]
+    return x.to(device)
+
+
+def dense_to_ccol_row(x: torch.Tensor):
+    """Similar, but to CSC format"""
+    x = x.transpose(-2, -1)
+    return dense_to_crow_col(x)
+
+
+def ccol_row_to_dense(ccol: torch.Tensor,
+                      rows: torch.Tensor,
+                      dtype: torch.dtype = torch.float16):
+    return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
+
+
+def _get_sparse_attn_mask_homo_head(
+    q_len: int,
+    max_seqlen: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    block_size: int = 128,
+    local_blocks: int = 4,
+    vert_stride: int = 4,
+    return_dense: bool = False,
+):
+    """
+    :return: a tuple of 3:
+        - a (crow_indices, col_indices) pair, i.e. the CSR
+            representation of the block mask.
+        - the block-level dense mask
+        - the all-token dense mask (be aware that it can cause OOM
+            if it is too big) if `return_dense==True`, otherwise None
+    """
+    with torch.no_grad():
+        num_blocks = triton.cdiv(max_seqlen, block_size)
+        q_pos = torch.arange(num_blocks)[:, None]
+        k_pos = torch.arange(num_blocks)[None]
+        mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
+        block_mask_dense = (((q_pos >= k_pos)
+                             & ((q_pos - k_pos < local_blocks)
+                                | mask_vert_strided)).to(device).to(dtype))
+        num_blocks_q = triton.cdiv(q_len, block_size)
+        block_mask_dense_output = (dense_to_crow_col(
+            block_mask_dense[-num_blocks_q:].contiguous()))
+    if return_dense:
+        mask_dense = torch.kron(
+            block_mask_dense,
+            block_mask_dense.new_ones((block_size, block_size)),
+        )
+        causal_mask = torch.tril(torch.ones(
+            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
+        mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
+        return (
+            block_mask_dense_output,
+            block_mask_dense,
+            mask_dense,
+        )
+    else:
+        return (
+            block_mask_dense_output,
+            block_mask_dense,
+            None,
+        )
+
+
+def binary_mask_to_bias(mask_dense: torch.Tensor):
+    mask_dense = 1 - mask_dense
+    mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
+    return mask_dense
+
+
+def get_head_sliding_step(n_heads: int,
+                          vert_stride: int,
+                          homo_head: bool = False):
+    if homo_head:
+        return 0
+    return max(1, int(vert_stride / n_heads))
+
+
+@lru_cache
+def get_sparse_attn_mask(
+    n_heads: int,
+    q_len: int,
+    max_seqlen: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    block_size: int = 64,
+    local_blocks: int = 4,
+    vert_stride: int = 4,
+    homo_head: bool = True,
+    return_dense: bool = False,
+    dense_mask_type: str = "binary",
+):
+    """
+    :param dense_mask_type: "binary" (0 for skipped tokens, 1 for others)
+        or "bias" (-inf for skipped tokens, 0 for others)
+    :return: a tuple of 3:
+        - a (crow_indices, col_indices) pair, i.e. the CSR
+            representation of the block mask.
+        - the block-level dense mask
+        - the all-token dense mask (be aware that it can cause OOM
+            if it is too big) if `return_dense==True`, otherwise None
+    """
+    assert dense_mask_type in ("binary", "bias")
+    if homo_head:
+        with torch.no_grad():
+            (crow, col), block_mask_dense, mask_dense = (
+                _get_sparse_attn_mask_homo_head(
+                    q_len,
+                    max_seqlen,
+                    dtype,
+                    device,
+                    block_size,
+                    local_blocks,
+                    vert_stride,
+                    return_dense,
+                ))
+            crow = crow[None].expand(n_heads, crow.shape[0])
+            col = col[None].expand(n_heads, col.shape[0])
+            if return_dense:
+                mask_dense = mask_dense[None].expand(n_heads,
+                                                     *mask_dense.shape)
+                if dense_mask_type == "bias":
+                    mask_dense = binary_mask_to_bias(mask_dense)
+            return (crow, col), block_mask_dense, mask_dense
+
+    with torch.no_grad():
+        num_blocks = triton.cdiv(max_seqlen, block_size)
+        q_pos = torch.arange(num_blocks)[None, :, None]
+        k_pos = torch.arange(num_blocks)[None, None]
+        head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
+        mask_vert_strided = [
+            (torch.arange(num_blocks) + h * head_sliding_step + 1) %
+            vert_stride == 0 for h in range(n_heads)
+        ]
+        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
+        block_mask_dense = (((q_pos >= k_pos)
+                             & ((q_pos - k_pos < local_blocks)
+                                | mask_vert_strided)).to(device).to(dtype))
+        num_blocks_q = triton.cdiv(q_len, block_size)
+        block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
+    if return_dense:
+        mask_dense = torch.kron(
+            block_mask_dense,
+            block_mask_dense.new_ones((block_size, block_size)),
+        )
+        causal_mask = torch.tril(torch.ones(
+            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
+        mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
+        if dense_mask_type == "bias":
+            mask_dense = binary_mask_to_bias(mask_dense)
+
+        return (
+            dense_to_crow_col(block_mask_dense_output),
+            block_mask_dense,
+            mask_dense,
+        )
+    else:
+        return (
+            dense_to_crow_col(block_mask_dense_output),
+            block_mask_dense,
+            None,
+        )

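A small round-trip sketch for the CSR helpers, assuming Triton is installed (the module imports it at the top); the block-level mask below is a made-up example.

import torch

from aphrodite.attention.ops.blocksparse_attention.utils import (
    crow_col_to_dense, dense_to_crow_col)

# Block-level mask: 4 query blocks x 4 key blocks.
dense = torch.tensor([[1, 0, 0, 0],
                      [1, 1, 0, 0],
                      [0, 1, 1, 0],
                      [1, 0, 1, 1]], dtype=torch.float16)

crow, col = dense_to_crow_col(dense)       # CSR indices; col is padded with -1
roundtrip = crow_col_to_dense(crow, col)   # back to a dense 0/1 mask
assert torch.equal(roundtrip, dense)
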
+ 122 - 0
aphrodite/attention/ops/ipex_attn.py

@@ -0,0 +1,122 @@
+from typing import Dict, List, Optional, Tuple
+
+import intel_extension_for_pytorch.llm.modules as ipex_modules
+import torch
+
+from aphrodite import _custom_ops as ops
+
+
+class PagedAttention:
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [64, 80, 96, 112, 128, 256]
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+        *args,
+    ) -> Tuple[int, ...]:
+        return (2, num_blocks, block_size * num_kv_heads * head_size)
+
+    @staticmethod
+    def split_kv_cache(
+        kv_cache: torch.Tensor,
+        num_kv_heads: int,
+        head_size: int,
+        *args,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        num_blocks = kv_cache.shape[1]
+
+        key_cache = kv_cache[0]
+        key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
+        value_cache = kv_cache[1]
+        value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
+        return key_cache, value_cache
+
+    @staticmethod
+    def write_to_paged_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: float,
+        v_scale: float,
+        *args,
+    ) -> None:
+        ipex_modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache,
+            slot_mapping.flatten().int())
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        context_lens: torch.Tensor,
+        max_context_len: int,
+        kv_cache_dtype: str,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+        k_scale: float,
+        v_scale: float,
+        *args,
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+        block_size = value_cache.shape[2]
+        head_mapping = torch.arange(
+            0,
+            num_kv_heads,
+            device="cpu",
+            dtype=torch.int32,
+        ).view(num_kv_heads,
+               1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
+        ipex_modules.PagedAttention.single_query_cached_kv_attention(
+            output, query.contiguous(), key_cache, value_cache, head_mapping,
+            scale, block_tables, context_lens, block_size, max_context_len,
+            alibi_slopes)
+
+        return output
+
+    @staticmethod
+    def forward_prefix(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        subquery_start_loc: torch.Tensor,
+        prompt_lens_tensor: torch.Tensor,
+        context_lens: torch.Tensor,
+        max_subquery_len: int,
+        alibi_slopes: Optional[torch.Tensor],
+        *args,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+        *args,
+    ) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+        *args,
+    ) -> None:
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
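
A hedged usage sketch for the cache layout above (not part of the diff; importing this module requires intel_extension_for_pytorch, and the sizes are placeholders):

# Sketch only: allocate a KV cache with the advertised shape and split it.
import torch

from aphrodite.attention.ops.ipex_attn import PagedAttention

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64
shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                          num_kv_heads, head_size)
kv_cache = torch.zeros(shape)  # (2, num_blocks, block_size * heads * head_size)
key_cache, value_cache = PagedAttention.split_kv_cache(
    kv_cache, num_kv_heads, head_size)
assert key_cache.shape == (num_blocks, num_kv_heads, block_size, head_size)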

+ 68 - 37
aphrodite/attention/ops/paged_attn.py

@@ -1,11 +1,13 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 
-from aphrodite._C import cache_ops
-from aphrodite._C import ops
-from aphrodite.attention.ops.prefix_prefill import context_attention_fwd
+from aphrodite import _custom_ops as ops
+from aphrodite.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from aphrodite.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE = 512
@@ -14,12 +16,11 @@ _PARTITION_SIZE = 512
 @dataclass
 class PagedAttentionMetadata:
     """Metadata for PagedAttention."""
-    # (batch_size,). The length of context (tokens stored in KV cache) per
-    # sequence. WARNING: When it is a prefill request, it doesn't include new
-    # tokens. When it is for decoding, it includes a new token.
-    context_lens: Optional[torch.Tensor]
-    # Maximum context length in the batch.
-    max_context_len: Optional[int]
+    # (batch_size,). The length of sequences (entire tokens seen so far) per
+    # sequence.
+    seq_lens_tensor: Optional[torch.Tensor]
+    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
+    max_decode_seq_len: int
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
@@ -33,7 +34,7 @@ class PagedAttention:
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 256]
+        return [64, 80, 96, 112, 120, 128, 192, 256]
 
     @staticmethod
     def get_kv_cache_shape(
@@ -68,16 +69,18 @@ class PagedAttention:
         value_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
         kv_cache_dtype: str,
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
     ) -> None:
-        cache_ops.reshape_and_cache(
+        ops.reshape_and_cache(
             key,
             value,
             key_cache,
             value_cache,
             slot_mapping.flatten(),
             kv_cache_dtype,
-            kv_scale,
+            k_scale,
+            v_scale,
         )
 
     @staticmethod
@@ -86,19 +89,33 @@ class PagedAttention:
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         block_tables: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_context_len: int,
+        seq_lens: torch.Tensor,
+        max_seq_len: int,
         kv_cache_dtype: str,
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
-        kv_scale: float,
+        k_scale: float,
+        v_scale: float,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
     ) -> torch.Tensor:
+        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
+            # use blocksparse paged attention
+            block_size = value_cache.size(-1)
+            assert (blocksparse_block_size > 0 and
+                    blocksparse_block_size % block_size == 0), \
+                (f"{blocksparse_block_size=} needs to be a multiple of"
+                 f"{block_size=} used in block_tables.")
+
         output = torch.empty_like(query)
 
         block_size = value_cache.shape[3]
         num_seqs, num_heads, head_size = query.shape
-        max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
+        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                               _PARTITION_SIZE)
         # NOTE: We use a simple heuristic to decide whether to use
         # PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -107,7 +124,7 @@ class PagedAttention:
         # to parallelize.
         # TODO: Tune this heuristic.
         # For context len > 8192, use V2 kernel to avoid shared memory shortage.
-        use_v1 = (max_context_len <= 8192
+        use_v1 = (max_seq_len <= 8192
                   and (max_num_partitions == 1 or num_seqs * num_heads > 512))
         if use_v1:
             # Run PagedAttention V1.
@@ -119,12 +136,18 @@ class PagedAttention:
                 num_kv_heads,
                 scale,
                 block_tables,
-                context_lens,
+                seq_lens,
                 block_size,
-                max_context_len,
+                max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
+                tp_rank,
+                blocksparse_local_blocks,
+                blocksparse_vert_stride,
+                blocksparse_block_size,
+                blocksparse_head_sliding_step,
             )
         else:
             # Run PagedAttention V2.
@@ -151,12 +174,18 @@ class PagedAttention:
                 num_kv_heads,
                 scale,
                 block_tables,
-                context_lens,
+                seq_lens,
                 block_size,
-                max_context_len,
+                max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
-                kv_scale,
+                k_scale,
+                v_scale,
+                tp_rank,
+                blocksparse_local_blocks,
+                blocksparse_vert_stride,
+                blocksparse_block_size,
+                blocksparse_head_sliding_step,
             )
         return output
 
@@ -168,11 +197,12 @@ class PagedAttention:
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         block_tables: torch.Tensor,
-        subquery_start_loc: torch.Tensor,
-        prompt_lens_tensor: torch.Tensor,
+        query_start_loc: torch.Tensor,
+        seq_lens_tensor: torch.Tensor,
         context_lens: torch.Tensor,
-        max_subquery_len: int,
+        max_query_len: int,
         alibi_slopes: Optional[torch.Tensor],
+        sliding_window: Optional[int],
     ) -> torch.Tensor:
         output = torch.empty_like(query)
         context_attention_fwd(
@@ -183,12 +213,13 @@ class PagedAttention:
             key_cache,
             value_cache,
             block_tables,
-            # subquery_start_loc is (batch_size + 1,)
-            subquery_start_loc[:-1],
-            prompt_lens_tensor,
+            # query_start_loc is (batch_size + 1,)
+            query_start_loc[:-1],
+            seq_lens_tensor,
             context_lens,
-            max_subquery_len,
+            max_query_len,
             alibi_slopes,
+            sliding_window,
         )
         return output
 
@@ -196,21 +227,21 @@ class PagedAttention:
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
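
For illustration (not part of the diff), the V1/V2 dispatch heuristic above reduces to a few integer operations; the sizes below are made up:

# Sketch of the V1/V2 heuristic using the file's _PARTITION_SIZE constant.
_PARTITION_SIZE = 512
max_seq_len, num_seqs, num_heads = 4096, 8, 32

max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
use_v1 = (max_seq_len <= 8192
          and (max_num_partitions == 1 or num_seqs * num_heads > 512))
print(max_num_partitions, use_v1)  # 8 False -> the V2 kernel is used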

+ 113 - 51
aphrodite/attention/ops/prefix_prefill.py

@@ -5,6 +5,8 @@ import torch
 import triton
 import triton.language as tl
 
+from aphrodite.platforms import current_platform
+
 if triton.__version__ >= "2.1.0":
 
     @triton.jit
@@ -47,8 +49,10 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,  # head size
+        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
+        SLIDING_WINDOW: tl.constexpr,
     ):
         cur_batch = tl.program_id(0)
         cur_head = tl.program_id(1)
@@ -59,64 +63,96 @@ if triton.__version__ >= "2.1.0":
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
 
+        # start position inside of the query
+        # generally, N goes over kv, while M goes over query_len
         block_start_loc = BLOCK_M * start_m
 
         # initialize offsets
+        # [N]; starts at 0
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
+        # [D]; starts at 0
+        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
+        # [M]; starts at current position in query
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        # [M,D]
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
-        q = tl.load(
-            Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+        dim_mask = tl.where(
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
+            0).to(tl.int1)  # [D]
 
-        # # initialize pointer to m and l
-        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        q = tl.load(Q + off_q,
+                    mask=dim_mask[None, :] &
+                    (offs_m[:, None] < cur_batch_query_len),
+                    other=0.0)  # [M,D]
+
+        # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
+                       dtype=tl.float32)  # [M,D]
 
+        # compute query against context (no causal mask here)
         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                          ((start_n + offs_n) // block_size) * stride_b_loc_s,
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
-                         other=0)
+                         other=0)  # [N]
+            # [D,N]
             off_k = (bn[None, :] * stride_k_cache_bs +
                      cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
+            # [N,D]
             off_v = (
                 bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-                        other=0.0)
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
+                        other=0.0)  # [D,N]
 
-            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
             qk += tl.dot(q, k)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
+            if SLIDING_WINDOW > 0:
+                # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
+                # Q entries in sequence
+                # (start_n + offs_n[None, :]) are the positions of
+                # KV entries in sequence
+                # So the condition makes sure each entry in Q only attends
+                # to KV entries not more than SLIDING_WINDOW away.
+                #
+                # We can't use -inf here, because the
+                # sliding window may lead to the entire row being masked.
+                # This then makes m_ij contain -inf, which causes NaNs in
+                # exp().
+                qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
+                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
+                              -10000)
 
             # -- compute m_ij, p, l_ij
-            m_ij = tl.max(qk, 1)
-            p = tl.exp(qk - m_ij[:, None])
-            l_ij = tl.sum(p, 1)
+            m_ij = tl.max(qk, 1)  # [M]
+            p = tl.exp(qk - m_ij[:, None])  # [M,N]
+            l_ij = tl.sum(p, 1)  # [M]
             # -- update m_i and l_i
-            m_i_new = tl.maximum(m_i, m_ij)
-            alpha = tl.exp(m_i - m_i_new)
-            beta = tl.exp(m_ij - m_i_new)
-            l_i_new = alpha * l_i + beta * l_ij
+            m_i_new = tl.maximum(m_i, m_ij)  # [M]
+            alpha = tl.exp(m_i - m_i_new)  # [M]
+            beta = tl.exp(m_ij - m_i_new)  # [M]
+            l_i_new = alpha * l_i + beta * l_ij  # [M]
+
             # -- update output accumulator --
             # scale p
             p_scale = beta / l_i_new
@@ -126,8 +162,9 @@ if triton.__version__ >= "2.1.0":
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
-                        other=0.0)
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
+                        other=0.0)  # [N,D]
 
             p = p.to(v.dtype)
             acc += tl.dot(p, v)
@@ -142,23 +179,29 @@ if triton.__version__ >= "2.1.0":
         k_ptrs = K + off_k
         v_ptrs = V + off_v
 
-        block_mask = tl.where(
-            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+        # block_mask is 0 when we're already past the current query length
+        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
+        # compute query against itself (with causal mask)
         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
             qk *= sm_scale
+            # apply causal mask
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                           float("-inf"))
+            if SLIDING_WINDOW > 0:
+                qk = tl.where(
+                    offs_m[:, None] -
+                    (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)
 
             # -- compute m_ij, p, l_ij
             m_ij = tl.max(qk, 1)
@@ -179,8 +222,8 @@ if triton.__version__ >= "2.1.0":
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -195,7 +238,8 @@ if triton.__version__ >= "2.1.0":
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+                 mask=dim_mask[None, :] &
+                 (offs_m[:, None] < cur_batch_query_len))
         return
 
     @triton.jit
@@ -430,7 +474,8 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,  # head size
+        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
     ):
         # attn_bias[]
@@ -451,21 +496,24 @@ if triton.__version__ >= "2.1.0":
 
         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
-        q = tl.load(
-            Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+        dim_mask = tl.where(
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+        q = tl.load(Q + off_q,
+                    mask=dim_mask[None, :] &
+                    (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
+                    other=0.0)
 
         # # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
 
         alibi_slope = tl.load(Alibi_slopes + cur_head)
         alibi_start_q = tl.arange(
@@ -490,8 +538,9 @@ if triton.__version__ >= "2.1.0":
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-                        other=0.0)
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
+                        other=0.0)  # [D,N]
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
@@ -525,7 +574,8 @@ if triton.__version__ >= "2.1.0":
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -558,8 +608,9 @@ if triton.__version__ >= "2.1.0":
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) <
+                         cur_batch_seq_len - cur_batch_ctx_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -595,8 +646,9 @@ if triton.__version__ >= "2.1.0":
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) <
+                         cur_batch_seq_len - cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -614,7 +666,8 @@ if triton.__version__ >= "2.1.0":
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+                 mask=dim_mask[None, :] &
+                 (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len))
         return
 
     @torch.inference_mode()
@@ -629,14 +682,16 @@ if triton.__version__ >= "2.1.0":
                               b_seq_len,
                               b_ctx_len,
                               max_input_len,
-                              alibi_slopes=None):
+                              alibi_slopes=None,
+                              sliding_window=None):
 
-        cap = torch.cuda.get_device_capability()
+        cap = current_platform.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
-        assert Lk in {16, 32, 64, 128}
+        # round up Lk to a power of 2 - this is required for Triton block size
+        Lk_padded = triton.next_power_of_2(Lk)
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
@@ -644,6 +699,10 @@ if triton.__version__ >= "2.1.0":
 
         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
 
+        # 0 means "disable"
+        if sliding_window is None or sliding_window <= 0:
+            sliding_window = 0
+
         num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
             _fwd_kernel_alibi[grid](
@@ -659,7 +718,7 @@ if triton.__version__ >= "2.1.0":
                 b_ctx_len,
                 alibi_slopes,
                 v_cache.shape[3],
-                8,
+                k_cache.shape[4],
                 o,
                 b_loc.stride(0),
                 b_loc.stride(1),
@@ -690,6 +749,7 @@ if triton.__version__ >= "2.1.0":
                 num_queries_per_kv=num_queries_per_kv,
                 BLOCK_M=BLOCK,
                 BLOCK_DMODEL=Lk,
+                BLOCK_DMODEL_PADDED=Lk_padded,
                 BLOCK_N=BLOCK,
                 num_warps=num_warps,
                 num_stages=1,
@@ -708,7 +768,7 @@ if triton.__version__ >= "2.1.0":
             b_seq_len,
             b_ctx_len,
             v_cache.shape[3],
-            8,
+            k_cache.shape[4],
             o,
             b_loc.stride(0),
             b_loc.stride(1),
@@ -738,8 +798,10 @@ if triton.__version__ >= "2.1.0":
             num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
+            BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
+            SLIDING_WINDOW=sliding_window,
             num_warps=num_warps,
             num_stages=1,
         )
-        return
+        return
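
The -10000 constant used by the sliding-window mask (instead of -inf) avoids the NaN case described in the kernel comment above; a standalone illustration, not part of the diff:

# If every entry in a row is -inf, softmax produces NaNs; a large negative
# finite value keeps the row finite (uniform here) instead.
import torch

row_inf = torch.full((4,), float("-inf"))
row_big = torch.full((4,), -10000.0)
print(torch.softmax(row_inf, dim=-1))  # tensor([nan, nan, nan, nan])
print(torch.softmax(row_big, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])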

+ 28 - 13
aphrodite/attention/ops/triton_flash_attn.py

@@ -2,16 +2,22 @@
 """
 Fused Attention
 ===============
+
 This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
 (https://tridao.me/publications/flash2/flash2.pdf)
 Credits: OpenAI kernel team, AMD ML Frameworks Triton team
+
 Features supported:
+
 1) Fwd with causal masking
 2) Any sequence lengths without padding (currently fwd kernel only)
 3) Support for different sequence lengths for q and k
 4) Nested tensor API currently does not support dropout or bias.
+
 Not currently supported:
+
 1) Non power of two head dims
+
 """
 
 import torch
@@ -233,6 +239,16 @@ def _attn_fwd_inner(
             num_stages=1,
             num_warps=8,
         ),
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 64,
+                "waves_per_eu": 1,
+                "PRE_LOAD_V": False,
+            },
+            num_stages=1,
+            num_warps=4,
+        ),
         triton.Config(
             {
                 "BLOCK_M": 128,
@@ -287,7 +303,7 @@ def _attn_fwd_inner(
             num_warps=4,
         ),
     ],
-    key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
 )
 @triton.jit
 def attn_fwd(
@@ -324,8 +340,8 @@ def attn_fwd(
     philox_seed,
     philox_offset_base,
     encoded_softmax,
-    hq,
-    hk,
+    HQ: tl.constexpr,
+    HK: tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
     MAX_SEQLENS_Q: tl.constexpr,
     MAX_SEQLENS_K: tl.constexpr,
@@ -397,7 +413,7 @@ def attn_fwd(
             # We still need to write 0s to the result
             # tl.store(O_block_ptr,
             # acc.to(Out.type.element_ty), boundary_check=(0,1))
-            # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
+            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
             #          + offs_m
             # We store inf to LSE, not -inf because in the bwd pass,
             # we subtract this
@@ -408,11 +424,10 @@ def attn_fwd(
             # TODO: Should dropout and return encoded softmax be handled here?
             return
 
-    is_mqa = hq != hk
-    if is_mqa:  # noqa: SIM108
-        off_h_k = off_h_q % hk
-    else:
-        off_h_k = off_h_q
+    # If MQA / GQA, set the K and V head offsets appropriately.
+    GROUP_SIZE: tl.constexpr = HQ // HK
+    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
+
     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
         n_extra_tokens = BLOCK_N - seqlen_k
@@ -464,7 +479,7 @@ def attn_fwd(
         bias_ptr = None
     if ENABLE_DROPOUT:
         batch_philox_offset = philox_offset_base \
-                              + (off_z * hq + off_h_q) \
+                              + (off_z * HQ + off_h_q) \
                               * seqlen_q * seqlen_k
     else:
         batch_philox_offset = 0
@@ -617,7 +632,7 @@ def attn_fwd(
             z = 0.0
             acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
     # write back LSE
-    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
+    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
     # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
     # few rows. This is only true for the last M block. For others,
     # overflow_size will be -ve
@@ -777,8 +792,8 @@ class _attention(torch.autograd.Function):
             philox_seed=philox_seed,
             philox_offset_base=philox_offset,
             encoded_softmax=encoded_softmax,
-            hq=nheads_q,
-            hk=nheads_k,
+            HQ=nheads_q,
+            HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
             MAX_SEQLENS_Q=max_seqlens_q,
             MAX_SEQLENS_K=max_seqlens_k,
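
For illustration (not part of the diff), the constexpr GROUP_SIZE mapping above replaces the modulo-based MQA branch; with 8 query heads sharing 2 KV heads, each group of four query heads reads the same KV head:

# Sketch of the GQA head mapping: off_h_k = off_h_q // (HQ // HK).
HQ, HK = 8, 2
GROUP_SIZE = HQ // HK
print([off_h_q // GROUP_SIZE for off_h_q in range(HQ)])  # [0, 0, 0, 0, 1, 1, 1, 1]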

+ 156 - 34
aphrodite/attention/selector.py

@@ -1,12 +1,16 @@
 import enum
+import os
 from functools import lru_cache
-from typing import Type
+from typing import Optional, Type
 
 import torch
 from loguru import logger
 
 from aphrodite.attention.backends.abstract import AttentionBackend
-from aphrodite.common.utils import is_cpu, is_hip
+from aphrodite.common.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
+from aphrodite.platforms import current_platform
+
+APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"
 
 
 class _Backend(enum.Enum):
@@ -14,61 +18,179 @@ class _Backend(enum.Enum):
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
+    OPENVINO = enum.auto()
+    FLASHINFER = enum.auto()
+    PALLAS = enum.auto()
+    IPEX = enum.auto()
 
 
 @lru_cache(maxsize=None)
-def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
-    backend = _which_attn_to_use(dtype)
+def get_attn_backend(
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    sliding_window: Optional[int],
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[str],
+    block_size: int,
+    is_blocksparse: bool = False,
+) -> Type[AttentionBackend]:
+
+    """Determine which attention backend to use and only import
+    the selected backend module.
+    """
+    if is_blocksparse:
+        logger.info("Using BlocksparseFlashAttention backend.")
+        from aphrodite.attention.backends.blocksparse_attn import (
+            BlocksparseFlashAttentionBackend)
+        return BlocksparseFlashAttentionBackend
+    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
+                                sliding_window, dtype, kv_cache_dtype,
+                                block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention backend.")
-        from aphrodite.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
+        from aphrodite.attention.backends.flash_attn import (  # noqa: F401
+            FlashAttentionBackend)
         return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
-        from aphrodite.attention.backends.xformers import XFormersBackend  # noqa: F501
+        from aphrodite.attention.backends.xformers import (  # noqa: F401
+            XFormersBackend)
         return XFormersBackend
     elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCm FlashAttention backend.")
+        logger.info("Using ROCmFlashAttention backend.")
         from aphrodite.attention.backends.rocm_flash_attn import (  # noqa: F401
             ROCmFlashAttentionBackend)
         return ROCmFlashAttentionBackend
     elif backend == _Backend.TORCH_SDPA:
+        assert is_cpu(), RuntimeError(
+            "Torch SDPA backend is only used for CPU devices.")
         logger.info("Using Torch SDPA backend.")
-        from aphrodite.attention.backends.sdpa import TorchSDPABackend
+        from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
         return TorchSDPABackend
+    elif backend == _Backend.OPENVINO:
+        logger.info("Using OpenVINO attention backend.")
+        from aphrodite.attention.backends.openvino import (
+            OpenVINOAttentionBackend)
+        return OpenVINOAttentionBackend
+    elif backend == _Backend.IPEX:
+        assert is_xpu(), RuntimeError(
+            "IPEX attention backend is only used for the XPU device.")
+        logger.info("Using IPEX attention backend.")
+        from aphrodite.attention.backends.ipex_attn import IpexAttnBackend
+        return IpexAttnBackend
+    elif backend == _Backend.FLASHINFER:
+        logger.info("Using Flashinfer backend.")
+        from aphrodite.attention.backends.flashinfer import FlashInferBackend
+        return FlashInferBackend
+    elif backend == _Backend.PALLAS:
+        logger.info("Using Pallas backend.")
+        from aphrodite.attention.backends.pallas import PallasAttentionBackend
+        return PallasAttentionBackend
     else:
         raise ValueError("Invalid attention backend.")
 
 
-def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
+def which_attn_to_use(
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    sliding_window: Optional[int],
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[str],
+    block_size: int,
+) -> _Backend:
     """Returns which flash attention backend to use."""
+
+    # Default case.
+    selected_backend = _Backend.FLASH_ATTN
+
+    # Check the environment variable and override if specified
+    backend_by_env_var: Optional[str] = os.getenv(APHRODITE_ATTENTION_BACKEND)
+    if backend_by_env_var is not None:
+        backend_members = _Backend.__members__
+        if backend_by_env_var.upper() not in backend_members:
+            raise ValueError(
+                f"Invalid attention backend '{backend_by_env_var}'. "
+                f"Available backends: {', '.join(backend_members)} ")
+        selected_backend = _Backend[backend_by_env_var.upper()]
     if is_cpu():
+        if selected_backend != _Backend.TORCH_SDPA:
+            logger.info(f"Cannot use {selected_backend} backend on CPU.")
         return _Backend.TORCH_SDPA
 
+    if is_openvino():
+        if selected_backend != _Backend.OPENVINO:
+            logger.info(f"Cannot use {selected_backend} backend on OpenVINO.")
+        return _Backend.OPENVINO
+
+    if is_xpu():
+        if selected_backend != _Backend.IPEX:
+            logger.info(f"Cannot use {selected_backend} backend on XPU.")
+        return _Backend.IPEX
+
+    if is_tpu():
+        if selected_backend != _Backend.PALLAS:
+            logger.info(f"Cannot use {selected_backend} backend on TPU.")
+        return _Backend.PALLAS
+
     if is_hip():
         # AMD GPUs.
-        if torch.cuda.get_device_capability()[0] != 9:
-            # not Instinct series GPUs.
-            logger.info("flash_atten is not supported on NAVI GPUs.")
+        selected_backend = (_Backend.ROCM_FLASH if selected_backend
+                            == _Backend.FLASH_ATTN else selected_backend)
+        if selected_backend == _Backend.ROCM_FLASH:
+            if current_platform.get_device_capability()[0] != 9:
+                # not Instinct series GPUs.
+                logger.info("flash_attn is not supported on NAVI GPUs.")
+        else:
+            logger.info(f"{selected_backend} is not supported in AMD GPUs.")
         return _Backend.ROCM_FLASH
 
-    # NVIDIA GPUs.
-    if torch.cuda.get_device_capability()[0] < 8:
-        # Volta and Turing NVIDIA GPUs.
-        logger.info("Cannot use FlashAttention backend for Volta and Turing "
-                    "GPUs.")
-        return _Backend.XFORMERS
-
-    if dtype not in (torch.float16, torch.bfloat16):
-        logger.info("Cannot use FlashAttention backend for dtype other than "
-                    "torch.float16 or torch.bfloat16.")
-        return _Backend.XFORMERS
-
-    try:
-        import flash_attn  # noqa: F401
-    except ImportError:
-        logger.info(
-            "Cannot use FlashAttention backend because the flash_attn package "
-            "is not found. Please install it for better performance.")
-        return _Backend.XFORMERS
-    return _Backend.FLASH_ATTN
+    # FlashAttn on NVIDIA GPUs.
+    if selected_backend == _Backend.FLASH_ATTN:
+        if current_platform.get_device_capability()[0] < 8:
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            selected_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            selected_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            selected_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            selected_backend = _Backend.XFORMERS
+        elif sliding_window is not None:
+            logger.info(
+                "Cannot use FlashAttention-2 backend due to sliding window.")
+            selected_backend = _Backend.XFORMERS
+
+    # FlashAttn is valid for the model, checking if the package is installed.
+    if selected_backend == _Backend.FLASH_ATTN:
+        try:
+            import aphrodite_flash_attn  # noqa: F401
+
+            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
+                FlashAttentionBackend)
+
+            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
+            if head_size not in supported_sizes:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend for head size "
+                    f"{head_size}")
+                selected_backend = _Backend.XFORMERS
+        except ImportError:
+            logger.info(
+                "Cannot use FlashAttention-2 backend because the "
+                "aphrodite_flash_attn package is not found. "
+                "`pip install aphrodite-flash-attn` for better performance.")
+            selected_backend = _Backend.XFORMERS
+
+    return selected_backend
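
A hedged usage sketch for the selector above (not part of the diff): the APHRODITE_ATTENTION_BACKEND environment variable overrides the default choice, and the new signature takes the model and cache details. The argument values are placeholders and assume a CUDA machine with the XFormers backend available:

# Sketch only: force the XFORMERS backend via the env var read above
# (it is consulted at call time inside which_attn_to_use).
import os

import torch

from aphrodite.attention.selector import get_attn_backend

os.environ["APHRODITE_ATTENTION_BACKEND"] = "XFORMERS"
backend_cls = get_attn_backend(num_heads=32, head_size=128, num_kv_heads=8,
                               sliding_window=None, dtype=torch.float16,
                               kv_cache_dtype="auto", block_size=16)
print(backend_cls.__name__)  # e.g. XFormersBackend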

+ 0 - 41
aphrodite/common/block.py

@@ -3,50 +3,9 @@ from typing import List
 
 from aphrodite.common.utils import Device
 
-_BLANK_TOKEN_ID = -1
 DEFAULT_LAST_ACCESSED_TIME = -1
 
 
-class LogicalTokenBlock:
-    """A block that stores a contiguous chunk of tokens from left to right.
-    Logical blocks are used to represent the states of the corresponding
-    physical blocks in the KV cache.
-    """
-
-    def __init__(
-        self,
-        block_number: int,
-        block_size: int,
-    ) -> None:
-        self.block_number = block_number
-        self.block_size = block_size
-
-        self.token_ids = [_BLANK_TOKEN_ID] * block_size
-        self.num_tokens = 0
-
-    def is_empty(self) -> bool:
-        return self.num_tokens == 0
-
-    def get_num_empty_slots(self) -> int:
-        return self.block_size - self.num_tokens
-
-    def is_full(self) -> bool:
-        return self.num_tokens == self.block_size
-
-    def append_tokens(self, token_ids: List[int]) -> None:
-        assert len(token_ids) <= self.get_num_empty_slots()
-        curr_idx = self.num_tokens
-        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
-        self.num_tokens += len(token_ids)
-
-    def get_token_ids(self) -> List[int]:
-        return self.token_ids[:self.num_tokens]
-
-    def get_last_token_id(self) -> int:
-        assert self.num_tokens > 0
-        return self.token_ids[self.num_tokens - 1]
-
-
 class PhysicalTokenBlock:
     """Represents the state of a block in the KV cache."""
 

Diff file suppressed because it is too large
+ 487 - 271
aphrodite/common/config.py


+ 167 - 0
aphrodite/common/connections.py

@@ -0,0 +1,167 @@
+from pathlib import Path
+from typing import Any, Mapping, MutableMapping, Optional
+from urllib.parse import urlparse
+
+import aiohttp
+import requests
+
+from aphrodite.version import __version__ as APHRODITE_VERSION
+
+
+class HTTPConnection:
+    """Helper class to send HTTP requests."""
+
+    def __init__(self, *, reuse_client: bool = True) -> None:
+        super().__init__()
+
+        self.reuse_client = reuse_client
+
+        self._sync_client: Optional[requests.Session] = None
+        self._async_client: Optional[aiohttp.ClientSession] = None
+
+    def get_sync_client(self) -> requests.Session:
+        if self._sync_client is None or not self.reuse_client:
+            self._sync_client = requests.Session()
+
+        return self._sync_client
+
+    # NOTE: We intentionally use an async function even though it is not
+    # required, so that the client is only accessible inside an async
+    # event loop
+    async def get_async_client(self) -> aiohttp.ClientSession:
+        if self._async_client is None or not self.reuse_client:
+            self._async_client = aiohttp.ClientSession()
+
+        return self._async_client
+
+    def _validate_http_url(self, url: str):
+        parsed_url = urlparse(url)
+
+        if parsed_url.scheme not in ("http", "https"):
+            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
+                             "must have scheme 'http' or 'https'.")
+
+    def _headers(self, **extras: str) -> MutableMapping[str, str]:
+        return {"User-Agent": f"Aphrodite/{APHRODITE_VERSION}", **extras}
+
+    def get_response(
+        self,
+        url: str,
+        *,
+        stream: bool = False,
+        timeout: Optional[float] = None,
+        extra_headers: Optional[Mapping[str, str]] = None,
+    ):
+        self._validate_http_url(url)
+
+        client = self.get_sync_client()
+        extra_headers = extra_headers or {}
+
+        return client.get(url,
+                          headers=self._headers(**extra_headers),
+                          stream=stream,
+                          timeout=timeout)
+
+    async def get_async_response(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+        extra_headers: Optional[Mapping[str, str]] = None,
+    ):
+        self._validate_http_url(url)
+
+        client = await self.get_async_client()
+        extra_headers = extra_headers or {}
+
+        return client.get(url,
+                          headers=self._headers(**extra_headers),
+                          timeout=timeout)
+
+    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return r.content
+
+    async def async_get_bytes(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+    ) -> bytes:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return await r.read()
+
+    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return r.text
+
+    async def async_get_text(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+    ) -> str:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return await r.text()
+
+    def get_json(self, url: str, *, timeout: Optional[float] = None) -> Any:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return r.json()
+
+    async def async_get_json(
+        self,
+        url: str,
+        *,
+        timeout: Optional[float] = None,
+    ) -> Any:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            return await r.json()
+
+    def download_file(
+        self,
+        url: str,
+        save_path: Path,
+        *,
+        timeout: Optional[float] = None,
+        chunk_size: int = 128,
+    ) -> Path:
+        with self.get_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            with save_path.open("wb") as f:
+                for chunk in r.iter_content(chunk_size):
+                    f.write(chunk)
+
+        return save_path
+
+    async def async_download_file(
+        self,
+        url: str,
+        save_path: Path,
+        *,
+        timeout: Optional[float] = None,
+        chunk_size: int = 128,
+    ) -> Path:
+        async with await self.get_async_response(url, timeout=timeout) as r:
+            r.raise_for_status()
+
+            with save_path.open("wb") as f:
+                async for chunk in r.content.iter_chunked(chunk_size):
+                    f.write(chunk)
+
+        return save_path
+
+
+global_http_connection = HTTPConnection()
+"""The global :class:`HTTPConnection` instance used by Aphrodite."""

+ 71 - 9
aphrodite/common/logger.py

@@ -2,24 +2,25 @@
 Internal logging utility.
 """
 
+import datetime
 import logging
 import os
+import sys
+from functools import partial
+from typing import Optional
 
 from loguru import logger
 from rich.console import Console
 from rich.markup import escape
-from rich.progress import (
-    Progress,
-    TextColumn,
-    BarColumn,
-    TimeRemainingColumn,
-    TaskProgressColumn,
-    MofNCompleteColumn,
-)
+from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
+                           TaskProgressColumn, TextColumn, TimeRemainingColumn)
 
 RICH_CONSOLE = Console()
 LOG_LEVEL = os.getenv("APHRODITE_LOG_LEVEL", "INFO").upper()
 
+APHRODITE_CONFIGURE_LOGGING = int(os.getenv("APHRODITE_CONFIGURE_LOGGING",
+                                            "1"))
+
 
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
@@ -130,4 +131,65 @@ def setup_logger():
     logger.log_once = log_once
 
 
-setup_logger()
+setup_logger()
+
+
+def _trace_calls(log_path, root_dir, frame, event, arg=None):
+    if event in ['call', 'return']:
+        # Extract the filename, line number, function name, and the code object
+        filename = frame.f_code.co_filename
+        lineno = frame.f_lineno
+        func_name = frame.f_code.co_name
+        if not filename.startswith(root_dir):
+            # only log the functions in the aphrodite root_dir
+            return
+        # Log every function call or return
+        try:
+            last_frame = frame.f_back
+            if last_frame is not None:
+                last_filename = last_frame.f_code.co_filename
+                last_lineno = last_frame.f_lineno
+                last_func_name = last_frame.f_code.co_name
+            else:
+                # initial frame
+                last_filename = ""
+                last_lineno = 0
+                last_func_name = ""
+            with open(log_path, 'a') as f:
+                if event == 'call':
+                    f.write(f"{datetime.datetime.now()} Call to"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" from {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+                else:
+                    f.write(f"{datetime.datetime.now()} Return from"
+                            f" {func_name} in {filename}:{lineno}"
+                            f" to {last_func_name} in {last_filename}:"
+                            f"{last_lineno}\n")
+        except NameError:
+            # modules are deleted during shutdown
+            pass
+    return partial(_trace_calls, log_path, root_dir)
+
+
+def enable_trace_function_call(log_file_path: str,
+                               root_dir: Optional[str] = None):
+    """
+    Enable tracing of every function call in code under `root_dir`.
+    This is useful for debugging hangs or crashes.
+    `log_file_path` is the path to the log file.
+    `root_dir` is the root directory of the code to trace. If None, it is the
+    aphrodite root directory.
+
+    Note that this call is thread-level, any threads calling this function
+    will have the trace enabled. Other threads will not be affected.
+    """
+    logger.warning(
+        "APHRODITE_TRACE_FUNCTION is enabled. It will record every"
+        " function executed by Python. This will slow down the code. It "
+        "is suggested to be used for debugging hang or crashes only.")
+    logger.info(f"Trace frame log is saved to {log_file_path}")
+    if root_dir is None:
+        # by default, this is the aphrodite root directory
+        root_dir = os.path.dirname(os.path.dirname(__file__))
+    sys.settrace(partial(_trace_calls, log_file_path, root_dir))
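
A hedged usage sketch for the tracing hook above (not part of the diff; the log path is a placeholder, and wiring it to an APHRODITE_TRACE_FUNCTION environment variable happens elsewhere):

# Sketch only: trace every call/return under the aphrodite root directory for
# the current thread, appending one line per event to the log file.
from aphrodite.common.logger import enable_trace_function_call

enable_trace_function_call("/tmp/aphrodite_trace.log")
# ... run the code being debugged, then inspect /tmp/aphrodite_trace.log.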

+ 97 - 36
aphrodite/common/outputs.py

@@ -1,16 +1,14 @@
-from typing import List, Optional, Union
 import time
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
 
-from aphrodite.common.sequence import (
-    PromptLogprobs,
-    SampleLogprobs,
-    SequenceGroup,
-    SequenceStatus,
-    RequestMetrics,
-)
+from aphrodite.common.sequence import (PromptLogprobs, RequestMetrics,
+                                       SampleLogprobs, SequenceGroup,
+                                       SequenceStatus)
 from aphrodite.lora.request import LoRARequest
 
 
+@dataclass
 class CompletionOutput:
     """The output data of one completion output of a request.
 
@@ -29,25 +27,14 @@ class CompletionOutput:
         lora_request: The LoRA request that was used to generate the output.
     """
 
-    def __init__(
-        self,
-        index: int,
-        text: str,
-        token_ids: List[int],
-        cumulative_logprob: float,
-        logprobs: Optional[SampleLogprobs],
-        finish_reason: Optional[str] = None,
-        stop_reason: Union[int, str, None] = None,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> None:
-        self.index = index
-        self.text = text
-        self.token_ids = token_ids
-        self.cumulative_logprob = cumulative_logprob
-        self.logprobs = logprobs
-        self.finish_reason = finish_reason
-        self.stop_reason = stop_reason
-        self.lora_request = lora_request
+    index: int
+    text: str
+    token_ids: Tuple[int, ...]
+    cumulative_logprob: Optional[float]
+    logprobs: Optional[SampleLogprobs]
+    finish_reason: Optional[str] = None
+    stop_reason: Union[int, str, None] = None
+    lora_request: Optional[LoRARequest] = None
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -62,8 +49,23 @@ class CompletionOutput:
                 f"stop_reason={self.stop_reason})")
 
 
+@dataclass
+class EmbeddingOutput:
+    """The output data of one completion output of a request.
+    Args:
+        embedding: The embedding vector, which is a list of floats. The
+        length of vector depends on the model as listed in the embedding guide.
+    """
+
+    embedding: List[float]
+
+    def __repr__(self) -> str:
+        return (f"EmbeddingOutput("
+                f"embedding={len(self.embedding)})")
+
+
 class RequestOutput:
-    """The output data of a request to the LLM.
+    """The output data of a completion request to the LLM.
 
     Args:
         request_id: The unique ID of the request.
@@ -79,7 +81,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: str,
+        prompt: Optional[str],
         prompt_token_ids: List[int],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
@@ -98,6 +100,9 @@ class RequestOutput:
 
     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
+        if seq_group.sampling_params is None:
+            raise ValueError(
+                "Sampling parameters are missing for a CompletionRequest.")
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
@@ -119,13 +124,14 @@ class RequestOutput:
         include_logprobs = seq_group.sampling_params.logprobs is not None
         text_buffer_length = seq_group.sampling_params.output_text_buffer_length
         outputs = [
-            CompletionOutput(seqs.index(seq),
-                             seq.get_output_text_to_return(text_buffer_length),
-                             seq.get_output_token_ids(),
-                             seq.get_cumulative_logprob(),
-                             seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status),
-                             seq.stop_reason) for seq in top_n_seqs
+            CompletionOutput(
+                seqs.index(seq),
+                seq.get_output_text_to_return(text_buffer_length),
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob() if include_logprobs else None,
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+                seq.stop_reason) for seq in top_n_seqs
         ]
 
         # Every sequence in the sequence group should have the same prompt.
@@ -155,3 +161,58 @@ class RequestOutput:
                 f"finished={self.finished}, "
                 f"metrics={self.metrics}, "
                 f"lora_request={self.lora_request})")
+
+
+class EmbeddingRequestOutput:
+    """
+    The output data of an embedding request to the LLM.
+    Args:
+        request_id (str): A unique identifier for the embedding request.
+        outputs (EmbeddingOutput): The embedding results for the given input.
+        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
+        finished (bool): A flag indicating whether the embedding is completed.
+    """
+
+    def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
+                 prompt_token_ids: List[int], finished: bool):
+        self.request_id = request_id
+        self.prompt_token_ids = prompt_token_ids
+        self.finished = finished
+        self.outputs = outputs
+
+    @classmethod
+    def from_seq_group(cls,
+                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
+        if seq_group.embeddings is None:
+            raise ValueError(
+                "Embeddings are missing in seq_group for EmbeddingRequest.")
+        output = EmbeddingOutput(seq_group.embeddings)
+        prompt_token_ids = seq_group.prompt_token_ids
+        finished = seq_group.is_finished()
+
+        return cls(seq_group.request_id, output, prompt_token_ids, finished)
+
+    def __repr__(self):
+        """
+        Returns a string representation of an EmbeddingRequestOutput instance.
+        The representation includes the request_id and the number of outputs,
+        providing a quick overview of the embedding request's results.
+        Returns:
+            str: A string representation of the EmbeddingRequestOutput instance.
+        """
+        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
+                f"outputs={repr(self.outputs)}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"finished={self.finished})")
+
+
+class RequestOutputFactory:
+
+    @staticmethod
+    def create(seq_group):
+        # Dispatch based on whether the sequence group produced embeddings.
+        if hasattr(seq_group,
+                   'embeddings') and seq_group.embeddings is not None:
+            return EmbeddingRequestOutput.from_seq_group(seq_group)
+        else:
+            return RequestOutput.from_seq_group(seq_group)
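
A hedged sketch of how the factory above is meant to be used (not part of the diff); `seq_group` is assumed to be a finished SequenceGroup produced by the engine:

# Sketch only: dispatch a finished sequence group to the matching output type.
from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
                                      RequestOutputFactory)


def summarize(seq_group):
    output = RequestOutputFactory.create(seq_group)
    if isinstance(output, EmbeddingRequestOutput):
        return len(output.outputs.embedding)   # embedding dimensionality
    assert isinstance(output, RequestOutput)
    return output.outputs[0].text              # top completion text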

+ 19 - 0
aphrodite/common/pooling_params.py

@@ -0,0 +1,19 @@
+from typing import Any, Optional
+
+
+class PoolingParams:
+    """Pooling parameters for pooling.
+    Attributes:
+        additional_data: Any additional data needed for pooling.
+    """
+
+    def __init__(self, additional_data: Optional[Any] = None):
+        self.additional_data = additional_data
+
+    def clone(self) -> "PoolingParams":
+        """Returns a copy of the PoolingParams instance."""
+        return PoolingParams(additional_data=self.additional_data)
+
+    def __repr__(self) -> str:
+        return (f"PoolingParams("
+                f"additional_data={self.additional_data})")

+ 66 - 68
aphrodite/common/sampling_params.py

@@ -1,14 +1,20 @@
 """Sampling parameters for text generation."""
 import copy
+import os
 from enum import IntEnum
 from functools import cached_property
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from pydantic import conint
+from loguru import logger
+from pydantic import Field
+from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
 
+APHRODITE_NO_DEPRECATION_WARNING = bool(
+    int(os.environ.get("APHRODITE_NO_DEPRECATION_WARNING", "0")))
+
 
 class SamplingType(IntEnum):
     GREEDY = 0
@@ -17,9 +23,14 @@ class SamplingType(IntEnum):
     BEAM = 3
 
 
-LogitsProcessorFunc = Callable[[torch.Tensor, List[List[int]]], None]
-"""LogitsProcessorFunc takes a logits tensor and corresponding lists of
-previously generated output tokens, and modifies the logits tensor."""
+LogitsProcessorFunc = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
+                            Callable[[List[int], List[int], torch.Tensor],
+                                     torch.Tensor]]
+"""LogitsProcessor is a function that takes a list
+of previously generated tokens, the logits tensor
+for the next token and, optionally, prompt tokens as a
+first argument, and returns a modified tensor of logits
+to sample from."""
 
 
 class SamplingParams:
@@ -110,11 +121,12 @@ class SamplingParams:
         min_tokens: Minimum number of tokens to generate per output sequence
             before EOS or stop tokens are generated.
         logprobs: Number of log probabilities to return per output token.
-            Note that the implementation follows the OpenAI API: The return
-            result includes the log probabilities on the `logprobs` most likely
-            tokens, as well the chosen tokens. The API will always return the
-            log probability of the sampled token, so there  may be up to
-            `logprobs+1` elements in the response.
+            When set to None, no probability is returned. If set to a non-None
+            value, the result includes the log probabilities of the specified
+            number of most likely tokens, as well as the chosen tokens.
+            Note that the implementation follows the OpenAI API: The API will
+            always return the log probability of the sampled token, so there
+            may be up to `logprobs+1` elements in the response.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         detokenize: Whether to detokenize the output. Defaults to True.
         custom_token_bans: List of token IDs to ban from generating
@@ -122,8 +134,9 @@ class SamplingParams:
             defaults to true.
         spaces_between_special_tokens: Whether to add spaces between special
             tokens in the output. Defaults to True.
-        logits_processors: List of LogitsProcessors to change the probability
-            of token prediction at runtime.
+        logits_processors: List of functions that modify logits based on
+            previously generated tokens, and optionally prompt tokens as
+            a first argument.
         truncate_prompt_tokens: If set to an integer k, will use only the last
             k tokens from the prompt (i.e. left-truncation). Defaults to None
             (i.e. no truncation).
@@ -145,12 +158,6 @@ class SamplingParams:
         eta_cutoff: float = 0.0,
         epsilon_cutoff: float = 0.0,
         typical_p: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 0,
-        mirostat_eta: float = 0,
-        dynatemp_min: float = 0,
-        dynatemp_max: float = 0,
-        dynatemp_exponent: float = 1,
         smoothing_factor: float = 0.0,
         smoothing_curve: float = 1.0,
         seed: Optional[int] = None,
@@ -170,7 +177,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessorFunc]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -186,15 +193,12 @@ class SamplingParams:
         self.eta_cutoff = eta_cutoff
         self.epsilon_cutoff = epsilon_cutoff
         self.typical_p = typical_p
-        self.mirostat_mode = mirostat_mode
-        self.mirostat_tau = mirostat_tau
-        self.mirostat_eta = mirostat_eta
-        self.dynatemp_min = dynatemp_min
-        self.dynatemp_max = dynatemp_max
-        self.dynatemp_exponent = dynatemp_exponent
         self.smoothing_factor = smoothing_factor
         self.smoothing_curve = smoothing_curve
-        self.seed = seed
+        if seed == -1:
+            self.seed = None
+        else:
+            self.seed = seed
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -208,8 +212,8 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.min_tokens = min_tokens
-        self.logprobs = logprobs
-        self.prompt_logprobs = prompt_logprobs
+        self.logprobs = 1 if logprobs is True else logprobs
+        self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
         # NOTE: This parameter is only exposed at the engine level for now.
         # It is not exposed in the OpenAI API server, as the OpenAI API does
         # not support returning only a list of token IDs.
@@ -220,6 +224,12 @@ class SamplingParams:
         self.logits_processors = logits_processors or []
         self.include_stop_str_in_output = include_stop_str_in_output
         self.truncate_prompt_tokens = truncate_prompt_tokens
+        # Number of characters to hold back for stop string evaluation
+        # until sequence is finished.
+        if self.stop and not include_stop_str_in_output:
+            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+        else:
+            self.output_text_buffer_length = 0
 
         self.default_values = {
             "n": 1,
@@ -236,12 +246,6 @@ class SamplingParams:
             "eta_cutoff": 0.0,
             "epsilon_cutoff": 0.0,
             "typical_p": 1.0,
-            "mirostat_mode": 0,
-            "mirostat_tau": 0,
-            "mirostat_eta": 0,
-            "dynatemp_min": 0,
-            "dynatemp_max": 0,
-            "dynatemp_exponent": 1,
             "smoothing_factor": 0.0,
             "smoothing_curve": 1.0,
             "seed": None,
@@ -272,6 +276,12 @@ class SamplingParams:
 
         self._verify_args()
         if self.use_beam_search:
+            if not APHRODITE_NO_DEPRECATION_WARNING:
+                logger.warning(
+                    "[IMPORTANT] We plan to discontinue the support for beam "
+                    "search in the next major release. Set "
+                    "APHRODITE_NO_DEPRECATION_WARNING=1 to "
+                    "suppress this warning.")
             self._verify_beam_search()
         else:
             self._verify_non_beam_search()
@@ -282,8 +292,8 @@ class SamplingParams:
                 self.min_p = 0.0
                 self.top_a = 0.0
                 self._verify_greedy_sampling()
-        # injected by the engine
-        self.eos_token_id = None
+        # eos_token_id is added to this by the engine
+        self.all_stop_token_ids = set(self.stop_token_ids)
 
     def _verify_args(self) -> None:
         if self.n < 1:
@@ -324,32 +334,6 @@ class SamplingParams:
         if not 0.0 <= self.typical_p <= 1.0:
             raise ValueError(
                 f"typical_p must be in (0, 1], got {self.typical_p}.")
-        if not self.dynatemp_min >= 0:
-            raise ValueError(
-                f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
-        if not self.dynatemp_max >= 0:
-            raise ValueError(
-                f"dynatemp_max must be non negative, got {self.dynatemp_max}.")
-        if not self.dynatemp_exponent >= 0:
-            raise ValueError(f"dynatemp_exponent must be non negative, got "
-                             f"{self.dynatemp_exponent}.")
-        if not self.smoothing_factor >= 0:
-            raise ValueError(f"smoothing_factor must be non negative, got "
-                             f"{self.smoothing_factor}.")
-        if not self.smoothing_curve >= 1.0:
-            raise ValueError(f"smoothing_curve must larger than 1, got "
-                             f"{self.smoothing_curve}.")
-        if self.mirostat_mode:
-            if not self.mirostat_mode == 2:
-                raise ValueError(
-                    "Only Mirostat v2 (2) and disabled (0) supported, "
-                    f"got {self.mirostat_mode}")
-            if not self.mirostat_eta >= 0:
-                raise ValueError(
-                    f"mirostat_eta must be positive, got {self.mirostat_eta}")
-            if not self.mirostat_tau >= 0:
-                raise ValueError(
-                    f"mirostat_tau must be positive, got {self.mirostat_tau}")
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -412,16 +396,30 @@ class SamplingParams:
             raise ValueError("top_k must be -1 when using greedy sampling.")
 
     def update_from_generation_config(
-            self, generation_config: Dict[str, Any]) -> None:
+            self,
+            generation_config: Dict[str, Any],
+            model_eos_token_id: Optional[int] = None) -> None:
         """Update if there are non-default values from generation_config"""
+
+        if model_eos_token_id is not None:
+            # Add the eos token id into the sampling_params to support
+            # min_tokens processing.
+            self.all_stop_token_ids.add(model_eos_token_id)
+
         # Update eos_token_id for generation
-        if eos_ids := generation_config.get("eos_token_id"):
+        if (eos_ids := generation_config.get("eos_token_id")) is not None:
             # it can be either int or list of int
-            if isinstance(eos_ids, int):
-                eos_ids = [eos_ids]
-            original_stop_token_ids = set(self.stop_token_ids)
-            original_stop_token_ids.update(eos_ids)
-            self.stop_token_ids = list(original_stop_token_ids)
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+            if model_eos_token_id is not None:
+                # We don't need to include the primary eos_token_id in
+                # stop_token_ids since it's handled separately for stopping
+                # purposes.
+                eos_ids.discard(model_eos_token_id)
+            if eos_ids:
+                self.all_stop_token_ids.update(eos_ids)
+                if not self.ignore_eos:
+                    eos_ids.update(self.stop_token_ids)
+                    self.stop_token_ids = list(eos_ids)
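
A minimal sketch of the updated behavior in this file, assuming the default
empty `stop_token_ids`: both logits-processor call signatures are accepted,
`seed=-1` is normalized to None, `logprobs=True` to 1, and
update_from_generation_config() folds secondary EOS ids into
`all_stop_token_ids` while keeping the primary EOS out of `stop_token_ids`.

    from aphrodite.common.sampling_params import SamplingParams

    def ban_token_zero(output_ids, logits):            # two-argument form
        logits[0] = -float("inf")
        return logits

    def prompt_aware(prompt_ids, output_ids, logits):  # optional prompt ids
        return logits

    params = SamplingParams(seed=-1, logprobs=True,
                            logits_processors=[ban_token_zero, prompt_aware])
    assert params.seed is None      # -1 means "no seed"
    assert params.logprobs == 1     # True is normalized to 1

    params.update_from_generation_config({"eos_token_id": [2, 32000]},
                                         model_eos_token_id=2)
    assert params.all_stop_token_ids == {2, 32000}
    assert params.stop_token_ids == [32000]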
 
     @cached_property
     def sampling_type(self) -> SamplingType:

+ 461 - 182
aphrodite/common/sequence.py

@@ -1,21 +1,30 @@
 """Sequence and its related classes."""
 import copy
 import enum
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, TYPE_CHECKING
+import math
+from abc import ABC, abstractmethod
+from array import array
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
 
-from aphrodite.common.block import LogicalTokenBlock
+import torch
+
+from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.lora.request import LoRARequest
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
 
 if TYPE_CHECKING:
-    import torch
+    from aphrodite.inputs import LLMInputs
+    from aphrodite.multimodal import MultiModalDataDict
     from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 
 @dataclass
 class Logprob:
     """Infos for supporting OpenAI compatible logprobs and token ranks.
+
     Attributes:
         logprob: The logprob of chosen token
         rank: The vocab rank of chosen token (>=1)
@@ -26,29 +35,28 @@ class Logprob:
     decoded_token: Optional[str] = None
 
 
+# {token_id -> logprob} per prompt token position. None if the corresponding
+# sequence group doesn't require prompt logprobs.
 PromptLogprobs = List[Optional[Dict[int, Logprob]]]
+# {token_id -> logprob} per generated token.
 SampleLogprobs = List[Dict[int, Logprob]]
 
 
-class SequenceStatus(enum.Enum):
+class SequenceStatus(enum.IntEnum):
     """Status of a sequence."""
-
-    WAITING = enum.auto()
-    RUNNING = enum.auto()
-    SWAPPED = enum.auto()
-    FINISHED_STOPPED = enum.auto()
-    FINISHED_LENGTH_CAPPED = enum.auto()
-    FINISHED_ABORTED = enum.auto()
-    FINISHED_IGNORED = enum.auto()
+    WAITING = 0
+    RUNNING = 1
+    SWAPPED = 2
+    # Note: anything after SWAPPED (2) will be considered
+    # as a finished status.
+    FINISHED_STOPPED = 3
+    FINISHED_LENGTH_CAPPED = 4
+    FINISHED_ABORTED = 5
+    FINISHED_IGNORED = 6
 
     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
-        return status in [
-            SequenceStatus.FINISHED_STOPPED,
-            SequenceStatus.FINISHED_LENGTH_CAPPED,
-            SequenceStatus.FINISHED_ABORTED,
-            SequenceStatus.FINISHED_IGNORED,
-        ]
+        return status > SequenceStatus.SWAPPED
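
With the switch to IntEnum, "finished" reduces to a single integer comparison
against SWAPPED; a quick illustration of the ordering this relies on:

    assert SequenceStatus.is_finished(SequenceStatus.FINISHED_ABORTED)
    assert not SequenceStatus.is_finished(SequenceStatus.SWAPPED)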
 
     @staticmethod
     def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
@@ -77,14 +85,13 @@ class SequenceStage(enum.Enum):
 class RequestMetrics:
     """Metrics associated with a request.
 
-    Args:
+    Attributes:
         arrival_time: The time when the request arrived.
         first_scheduled_time: The time when the request was first scheduled.
         first_token_time: The time when the first token was generated.
         time_in_queue: The time the request spent in the queue.
         finished_time: The time when the request was finished.
     """
-
     arrival_time: float
     last_token_time: float
     first_scheduled_time: Optional[float]
@@ -112,31 +119,76 @@ class SequenceData:
         prompt_token_ids: List[int],
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
-        if output_token_ids is None:
-            output_token_ids = []
+        self._prompt_token_ids = array('l', prompt_token_ids)
+        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
+        self._output_token_ids = array(
+            'l', output_token_ids if output_token_ids is not None else [])
 
-        self.prompt_token_ids = prompt_token_ids
-        self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
 
+        self._update_cached_all_tokens()
+
+    def _update_cached_all_tokens(self):
+        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
+                                                     self._output_token_ids)
+
+    @property
+    def prompt_token_ids(self) -> Tuple[int, ...]:
+        return self._prompt_token_ids_tuple
+
+    @prompt_token_ids.setter
+    def prompt_token_ids(self, new_prompt_token_ids) -> None:
+        self._prompt_token_ids = array('l', new_prompt_token_ids)
+        self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
+        self._update_cached_all_tokens()
+
+    @property
+    def prompt_token_ids_array(self) -> array:
+        return self._prompt_token_ids
+
+    @property
+    def output_token_ids(self) -> Tuple[int, ...]:
+        return tuple(self._output_token_ids)
+
+    @output_token_ids.setter
+    def output_token_ids(self, new_output_token_ids) -> None:
+        self._output_token_ids = array('l', new_output_token_ids)
+        self._update_cached_all_tokens()
+
+    @property
+    def output_token_ids_array(self) -> array:
+        return self._output_token_ids
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
-        self.output_token_ids.append(token_id)
+        self._output_token_ids.append(token_id)
+        self._cached_all_token_ids.append(token_id)
         self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
-        return len(self.output_token_ids) + len(self.prompt_token_ids)
+        return len(self._output_token_ids) + len(self._prompt_token_ids)
 
     def get_prompt_len(self) -> int:
-        return len(self.prompt_token_ids)
+        return len(self._prompt_token_ids)
 
     def get_output_len(self) -> int:
-        return len(self.output_token_ids)
+        return len(self._output_token_ids)
 
     def get_token_ids(self) -> List[int]:
-        return self.prompt_token_ids + self.output_token_ids
+        return self._cached_all_token_ids
+
+    def get_prefix_token_ids(
+            self, num_tokens: int
+    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
+        """Get prefix tokens, and make the return value hashable"""
+        prompt_length = self.get_prompt_len()
+        if num_tokens > prompt_length:
+            return (self._prompt_token_ids_tuple,
+                    tuple(self._output_token_ids[:num_tokens - prompt_length]))
+        else:
+            return (self._prompt_token_ids_tuple[:num_tokens], None)
 
     def get_num_computed_tokens(self) -> int:
         """Return the number of prefill tokens that are already computed."""
@@ -160,21 +212,21 @@ class SequenceData:
         self._stage = SequenceStage.PREFILL
 
     def get_num_uncomputed_tokens(self) -> int:
-        """Return the number of prefil tokens that are not computed."""
+        """Return the number of prefill tokens that are not computed."""
         # we use `get_len()` which includes prompt_len + output_len instead
         # of prompt_len here. This is because during recompute we need to
         # prefill for both prompt and output.
         return self.get_len() - self.get_num_computed_tokens()
 
     def get_last_token_id(self) -> int:
-        if not self.output_token_ids:
-            return self.prompt_token_ids[-1]
-        return self.output_token_ids[-1]
+        if not self._output_token_ids:
+            return self._prompt_token_ids[-1]
+        return self._output_token_ids[-1]
 
-    def get_prompt_token_ids(self) -> int:
+    def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.prompt_token_ids
 
-    def get_output_token_ids(self) -> int:
+    def get_output_token_ids(self) -> Tuple[int, ...]:
         return self.output_token_ids
 
     @property
@@ -183,8 +235,8 @@ class SequenceData:
 
     def __repr__(self) -> str:
         return (f"SequenceData("
-                f"prompt_token_ids={self.prompt_token_ids}, "
-                f"output_token_ids={self.output_token_ids}, "
+                f"prompt_token_ids={self._prompt_token_ids}, "
+                f"output_token_ids={self._output_token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob})")
 
 
@@ -193,35 +245,33 @@ class Sequence:
 
     Args:
         seq_id: The ID of the sequence.
-        prompt: The prompt of the sequence.
-        prompt_token_ids: The token IDs of the prompt.
+        inputs: The inputs of the sequence.
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
         lora_request: LoRA request.
+        prompt_adapter_request: Prompt adapter request.
     """
 
     def __init__(
-        self,
-        seq_id: int,
-        prompt: str,
-        prompt_token_ids: List[int],
-        block_size: int,
-        eos_token_id: Optional[int] = None,
-        lora_request: Optional[LoRARequest] = None,
+            self,
+            seq_id: int,
+            inputs: "LLMInputs",
+            block_size: int,
+            eos_token_id: Optional[int] = None,
+            lora_request: Optional[LoRARequest] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> None:
         self.seq_id = seq_id
-        self.prompt = prompt
+        self.inputs = inputs
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
+        self.prompt_adapter_request = prompt_adapter_request
 
-        self.data = SequenceData(prompt_token_ids)
+        self.data = SequenceData(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
-        self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None
 
@@ -230,12 +280,32 @@ class Sequence:
         self.read_offset = 0
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
-        self.persistent_data = {}
+
+    @property
+    def n_blocks(self) -> int:
+        return math.ceil(self.get_len() / self.block_size)
+
+    @property
+    def prompt(self) -> Optional[str]:
+        return self.inputs.get("prompt")
+
+    @property
+    def prompt_token_ids(self) -> List[int]:
+        return self.inputs["prompt_token_ids"]
+
+    @property
+    def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
+        return self.inputs.get("multi_modal_data")
 
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+                        if self.prompt_adapter_request else 0
+
     def get_output_text_to_return(self, buffer_length: int):
         # We return the full output text if the sequence is finished.
         truncate = buffer_length and not self.is_finished()
@@ -243,12 +313,14 @@ class Sequence:
             self.output_text)
 
     def hash_of_block(self, logical_idx: int) -> int:
+        # TODO This can produce incorrect hash when block size > prompt size
+
         # Compute the number of tokens in the sequence
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(
-            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
+        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
+        return hash((hashed_tokens, self.lora_int_id))
 
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
@@ -257,36 +329,12 @@ class Sequence:
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()
 
-    def _append_logical_block(self) -> None:
-        block = LogicalTokenBlock(
-            block_number=len(self.logical_token_blocks),
-            block_size=self.block_size,
-        )
-        self.logical_token_blocks.append(block)
-
-    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
-        cursor = 0
-        while cursor < len(token_ids):
-            if not self.logical_token_blocks:
-                self._append_logical_block()
-
-            last_block = self.logical_token_blocks[-1]
-            if last_block.is_full():
-                self._append_logical_block()
-                last_block = self.logical_token_blocks[-1]
-
-            num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append_tokens(token_ids[cursor:cursor +
-                                               num_empty_slots])
-            cursor += num_empty_slots
-
     def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
     ) -> None:
         assert token_id in logprobs
-        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
         self.data.append_token_id(token_id, logprobs[token_id].logprob)
 
@@ -302,24 +350,22 @@ class Sequence:
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
 
-    def get_prompt_token_ids(self) -> List[int]:
+    def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.data.get_prompt_token_ids()
 
     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()
 
-    def get_output_token_ids(self) -> List[int]:
-        return self.data.output_token_ids
+    def get_output_token_ids(self) -> Tuple[int, ...]:
+        return self.data.get_output_token_ids()
 
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob
 
-    def get_beam_search_score(
-        self,
-        length_penalty: float = 1.0,
-        seq_len: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-    ) -> float:
+    def get_beam_search_score(self,
+                              length_penalty: float = 1.0,
+                              seq_len: Optional[int] = None,
+                              eos_token_id: Optional[int] = None) -> float:
         """Calculate the beam search score with length penalty.
 
         Adapted from
@@ -345,12 +391,10 @@ class Sequence:
 
     def get_num_new_tokens(self) -> int:
         """Get the number of new tokens to be computed.
-        Args:
-            remainig_token_budget: The remaining token budgets.
+
         Returns:
-            The new number of tokens to be computed. I.e., 1 for decode, prompt
-            size for prefill. If there's not enough remainig_token_budget, it
-            can return the chunked number of new tokens.
+            The new number of tokens to be computed. I.e., 1 for decode, or
+            the remaining prompt size for prefill.
         """
         if self.data.stage == SequenceStage.DECODE:
             return 1
@@ -362,34 +406,7 @@ class Sequence:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={len(self.logical_token_blocks)})")
-
-
-@dataclass
-class SequenceGroupState:
-    """Mutable state tied to a specific sequence group"""
-
-    # torch.Generator used in seeded sampling
-    generator: Optional = None
-
-
-class MultiModalData:
-    """Multi modal request.
-    
-    Args:
-        type: The data type.
-        data: The actual data.
-        The required shape and semantic meaning of it depends on the vision
-        language config of the hosted model. 
-        See `VisionLanguageConfig` in `config.py`.
-    """
-
-    class Type(enum.Enum):
-        IMAGE = enum.auto()
-
-    def __init__(self, type: Type, data: "torch.Tensor"):
-        self.type = type
-        self.data = data
+                f"num_blocks={self.n_blocks}, ")
 
 
 class SequenceGroup:
@@ -401,63 +418,101 @@ class SequenceGroup:
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
         lora_request: LoRA request.
-        multi_modal_requst: Multi modal data for the request.
+        embeddings: The embedding vectors of the prompt of the sequence group
+            for an embedding model.
+        pooling_params: The pooling parameters used to compute the pooled
+            output for an embedding model.
+        encoder_seq: Optional, the single encoder sequence. Should be None
+                     unless you are working with an encoder/decoder model.
+        prompt_adapter_request: Prompt adapter request.
     """
 
     def __init__(
         self,
         request_id: str,
         seqs: List[Sequence],
-        sampling_params: SamplingParams,
         arrival_time: float,
+        sampling_params: Optional[SamplingParams] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        embeddings: Optional[List[float]] = None,
+        pooling_params: Optional[PoolingParams] = None,
+        encoder_seq: Optional[Sequence] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
+        self.seqs = seqs
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
-        self.metrics = RequestMetrics(
-            arrival_time=arrival_time,
-            last_token_time=arrival_time,
-            first_scheduled_time=None,
-            first_token_time=None,
-            time_in_queue=None,
-        )
+        self.metrics = RequestMetrics(arrival_time=arrival_time,
+                                      last_token_time=arrival_time,
+                                      first_scheduled_time=None,
+                                      first_token_time=None,
+                                      time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
-        self.state = SequenceGroupState()
-        self.multi_modal_data = multi_modal_data
+        self.embeddings = embeddings
+        self.pooling_params = pooling_params
+        self.prompt_adapter_request = prompt_adapter_request
+        self.encoder_seq = encoder_seq
 
     @property
-    def prompt(self) -> str:
+    def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return next(iter(self.seqs_dict.values())).prompt
+        return self.seqs[0].prompt
 
     @property
     def prompt_token_ids(self) -> List[int]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return next(iter(self.seqs_dict.values())).data.prompt_token_ids
+        return self.seqs[0].prompt_token_ids
+
+    @property
+    def multi_modal_data(self) -> "MultiModalDataDict":
+        # All sequences in the group should have the same multi-modal data.
+        # We use the multi-modal data of an arbitrary sequence.
+        return self.seqs[0].multi_modal_data
 
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
-    def get_last_latency(self, now: float) -> float:
-        """Gets last token latency for Request level timings."""
+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+                        if self.prompt_adapter_request else 0
+
+    @property
+    def prompt_adapter_num_virtual_tokens(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+                         if self.prompt_adapter_request else 0
+
+    def get_last_latency(self, now: float) -> Optional[float]:
+        """Sets the last token time for Request level timings."""
+        # If still in prefill phase, raise Error.
+        if self.is_prefill():
+            raise ValueError(
+                "seq_group.get_last_latency() should not be called "
+                "if the seq_group is in prefill phase.")
+
+        # Otherwise return token latency.
         latency = now - self.metrics.last_token_time
         self.metrics.last_token_time = now
         return latency
 
     def maybe_set_first_token_time(self, time: float) -> None:
         """Sets the first token time for Request level timings."""
-        if self.metrics.first_token_time is None:
+        # NOTE: in a case where a sequence_group is swapped and
+        #   recomputed, the time between iterations is counted
+        #   in TPOT, rather than recalculating TTFT (since from the
+        #   POV of the user, there is simply a long generation delay).
+        if (self.metrics.first_token_time is None
+                and self.seqs[0].get_output_len() == 1):
             self.metrics.first_token_time = time
 
     def maybe_set_first_scheduled_time(self, time: float) -> None:
-        """Sets the first scheduled time and time in queue for Request level
-        timings."""
+        """Sets the first scheduled time and time in queue for Request
+        level timings."""
         if self.metrics.first_scheduled_time is None:
             self.metrics.first_scheduled_time = time
             self.metrics.time_in_queue = time - self.metrics.arrival_time
@@ -469,12 +524,13 @@ class SequenceGroup:
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
-        if self.sampling_params.use_beam_search:
+        if self.sampling_params and self.sampling_params.use_beam_search:
             # For beam search, maximally there will always be `best_of` beam
             # candidates running in the future.
             return self.sampling_params.best_of
         else:
-            if self.sampling_params.best_of > self.num_seqs():
+            if (self.sampling_params
+                    and self.sampling_params.best_of > self.num_seqs()):
                 # At prompt stage, the sequence group is not yet filled up
                 # and only have one sequence running. However, in the
                 # generation stage, we will have `best_of` sequences running.
@@ -487,27 +543,31 @@ class SequenceGroup:
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        return (list(self.seqs_dict.values()) if status is None else [
-            seq for seq in self.seqs_dict.values() if seq.status == status
-        ])
+        if status is None:
+            return self.seqs
+        return [seq for seq in self.seqs if seq.status == status]
+
+    def is_encoder_decoder(self) -> bool:
+        return self.encoder_seq is not None
+
+    def get_encoder_seq(self) -> Optional[Sequence]:
+        return self.encoder_seq
 
     def get_unfinished_seqs(self) -> List[Sequence]:
-        return [
-            seq for seq in self.seqs_dict.values() if not seq.is_finished()
-        ]
+        return [seq for seq in self.seqs if not seq.is_finished()]
 
     def get_finished_seqs(self) -> List[Sequence]:
-        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
+        return [seq for seq in self.seqs if seq.is_finished()]
 
     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        for seq in self.seqs_dict.values():
+        for seq in self.seqs:
             if not seq.is_finished():
                 seq.data.update_num_computed_tokens(num_new_computed_tokens)
 
     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        for seq in self.get_seqs():
+        for seq in self.seqs:
             if not seq.is_finished():
                 num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
@@ -516,7 +576,7 @@ class SequenceGroup:
         # Optimization. We don't need to call get_seqs if we don't need to
         # filter by states.
         if status is None:
-            return len(self.seqs_dict)
+            return len(self.seqs)
 
         return len(self.get_seqs(status))
 
@@ -535,23 +595,25 @@ class SequenceGroup:
         if seq.seq_id in self.seqs_dict:
             raise ValueError(f"Sequence {seq.seq_id} already exists.")
         self.seqs_dict[seq.seq_id] = seq
+        self.seqs.append(seq)
 
     def remove(self, seq_id: int) -> None:
-        if seq_id not in self.seqs_dict:
+        seq = self.seqs_dict.pop(seq_id, None)
+        if seq is None:
             raise ValueError(f"Sequence {seq_id} not found.")
-        del self.seqs_dict[seq_id]
+        self.seqs.remove(seq)
 
     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.get_seqs())
+        return all(seq.is_finished() for seq in self.seqs)
 
     def is_prefill(self) -> bool:
-        # Every sequences should be in the same stage.
-        return self.get_seqs()[0].is_prefill()
+        # Every sequence should be in the same stage.
+        return self.seqs[0].is_prefill()
 
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs_dict)})")
+                f"num_seqs={len(self.seqs)})")
 
 
 class SequenceGroupMetadata:
@@ -564,12 +626,25 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        do_sample: True if sampling is required. Sampling is not required
+            when, e.g., prefill is chunked and the current iteration only
+            computes query tokens for the prompt.
         token_chunk_size: The number of tokens to be processed (per sequence).
             None if chunking is not required.
-        state: Internal state tied to this sequence group.
         lora_request: LoRA request.
-        multi_modal_data: Multi modal data for the request.
-        persistent_data: The persistent data of the sequence group.
+        computed_block_nums: The block numbers that are already computed,
+            used in prefix caching.
+        multi_modal_data: Multi modal data.
+        encoder_seq_data: Optional sequence data for encoder prompt
+                          (SequenceGroup.encoder_seq). Should be None 
+                          unless you are working with an encoder/decoder
+                          model.
+        cross_block_table: Optional cross-attention block table associated
+                           with the encoder prompt
+                           (SequenceGroup.encoder_seq). Should be None
+                           unless you are working with an encoder/decoder
+                           model.
+        prompt_adapter_request: Prompt Adapter request.
     """
 
     def __init__(
@@ -579,24 +654,36 @@ class SequenceGroupMetadata:
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
-        persistent_data: Dict[int, dict],
+        do_sample: bool = True,
+        pooling_params: Optional[PoolingParams] = None,
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
-        state: Optional[SequenceGroupState] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        multi_modal_data: Optional["MultiModalDataDict"] = None,
+        encoder_seq_data: Optional[SequenceData] = None,
+        cross_block_table: Optional[List[int]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
         self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
-        self.persistent_data = persistent_data
+        self.pooling_params = pooling_params
         self.lora_request = lora_request
+        self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
-        self.state = SequenceGroupState() if state is None else state
         self.multi_modal_data = multi_modal_data
+        self.encoder_seq_data = encoder_seq_data
+        self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
+        self.do_sample = do_sample
+
+        # The number of speculative tokens adopted in this request.
+        # None means speculative decoding is not used.
+        # Zero means speculative decoding is disabled for some reason.
+        # TODO: We should maintain this state outside of the sequence group.
+        self.num_speculative_tokens = None
 
         if self._token_chunk_size is None:
             if is_prompt:
@@ -608,9 +695,20 @@ class SequenceGroupMetadata:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+                        if self.prompt_adapter_request else 0
+
+    @property
+    def prompt_adapter_num_virtual_tokens(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
+                        if self.prompt_adapter_request else 0
+
     @property
     def token_chunk_size(self) -> int:
         """Return the number of tokens to be processed (chunk size)."""
+        assert self._token_chunk_size is not None
         return self._token_chunk_size
 
 
@@ -623,7 +721,6 @@ class SequenceOutput:
         output_token: The output token ID.
         logprobs: The logprobs of the output token.
             (Token id -> logP(x_i+1 | x_0, ..., x_i))
-        persistent_data: The persistent data of the sequence.
     """
 
     def __init__(
@@ -631,18 +728,15 @@ class SequenceOutput:
         parent_seq_id: int,
         output_token: int,
         logprobs: Dict[int, Logprob],
-        persistent_data: dict,
     ) -> None:
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs
-        self.persistent_data = persistent_data
 
     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"logprobs={self.logprobs}, "
-                f"persistent_data={self.persistent_data})")
+                f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
@@ -653,8 +747,20 @@ class SequenceOutput:
         return equal and log_probs_equal
 
 
-class SequenceGroupOutput:
-    """The model output associated with a sequence group."""
+class SequenceGroupOutput(ABC):
+    """The base class for model outputs associated with a sequence group."""
+
+    @abstractmethod
+    def __repr__(self) -> str:
+        pass
+
+    @abstractmethod
+    def __eq__(self, other: object) -> bool:
+        pass
+
+
+class CompletionSequenceGroupOutput(SequenceGroupOutput):
+    """The model output associated with a completion sequence group."""
 
     def __init__(
         self,
@@ -662,39 +768,93 @@ class SequenceGroupOutput:
         prompt_logprobs: Optional[PromptLogprobs],
     ) -> None:
         self.samples = samples
+        # Prompt logprob for each prompt query token.
         self.prompt_logprobs = prompt_logprobs
 
     def __repr__(self) -> str:
-        return (f"SequenceGroupOutput(samples={self.samples}, "
+        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                 f"prompt_logprobs={self.prompt_logprobs})")
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, SequenceGroupOutput):
+        if not isinstance(other, CompletionSequenceGroupOutput):
             raise NotImplementedError()
         return (self.samples == other.samples
                 and self.prompt_logprobs == other.prompt_logprobs)
 
 
+class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
+    """The model output associated with an embedding sequence group."""
+
+    def __init__(
+        self,
+        embeddings: List[float],
+    ) -> None:
+        self.embeddings = embeddings
+
+    def __repr__(self) -> str:
+        return (f"EmbeddingSequenceGroupOutput("
+                f"embeddings_shape={len(self.embeddings)})")
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, EmbeddingSequenceGroupOutput):
+            raise NotImplementedError()
+        return self.embeddings == other.embeddings
+
+
+@dataclass
+class IntermediateTensors:
+    """For all pipeline stages except the last, we need to return the hidden
+    states and residuals to be sent to the next stage. This data structure
+    contains the hidden states and residuals for a request.
+    """
+
+    tensors: Dict[str, torch.Tensor]
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return (isinstance(other, self.__class__)
+                and self.tensors.keys() == other.tensors.keys()
+                and all(torch.equal(v, other.tensors[k])
+                        for k, v in self.tensors.items()))
+
+    def __repr__(self) -> str:
+        return f"IntermediateTensors(tensors={self.tensors})"
+
+
 @dataclass
 class SamplerOutput:
     """For each sequence group, we generate a list of SequenceOutput object,
     each of which contains one possible candidate for the next token.
 
-    This datastructure implements methods so it can be used like a list, but
+    This data structure implements methods so it can be used like a list, but
     also has optional fields for device tensors.
     """
 
-    outputs: List[SequenceGroupOutput]
+    outputs: List[CompletionSequenceGroupOutput]
 
     # On-device tensor containing probabilities of each token.
-    sampled_token_probs: Optional["torch.Tensor"] = None
+    sampled_token_probs: Optional[torch.Tensor] = None
+
+    # On-device tensor containing the logprobs of each token.
+    logprobs: Optional["torch.Tensor"] = None
 
     # On-device tensor containing the sampled token ids.
-    sampled_token_ids: Optional["torch.Tensor"] = None
+    sampled_token_ids: Optional[torch.Tensor] = None
 
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
+    # Optional last hidden states from the model.
+    hidden_states: Optional[torch.Tensor] = None
+
     def __getitem__(self, idx: int):
         return self.outputs[idx]
 
@@ -705,8 +865,8 @@ class SamplerOutput:
         return len(self.outputs)
 
     def __eq__(self, other: object):
-        return (isinstance(other, self.__class__)
-                and self.outputs == other.outputs)
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
 
     def __repr__(self) -> str:
         """Show the shape of a tensor instead of its values to reduce noise.
@@ -720,3 +880,122 @@ class SamplerOutput:
             f"sampled_token_probs={sampled_token_probs_repr}, "
             f"sampled_token_ids={sampled_token_ids_repr}, "
             f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
+
+@dataclass
+class PoolerOutput:
+    """The output from a pooling operation in the embedding model."""
+    outputs: List[EmbeddingSequenceGroupOutput]
+
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+
+def get_all_seq_ids(
+        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
+
+
+def get_all_seq_ids_and_request_ids(
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+) -> Tuple[List[int], Dict[str, Set[int]]]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    seq_ids: List[int] = []
+    request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+    for sg in seq_group_metadata_list:
+        for seq_id in sg.seq_data:
+            seq_ids.append(seq_id)
+            request_id_seq_ids_mapping[sg.request_id].add(seq_id)
+    return seq_ids, request_id_seq_ids_mapping
+
+
+class HiddenStates:
+    """Hidden states corresponding to in-progress sequences.
+    Used in speculative decoding to pass hidden states from
+    the target model to the proposer model in the subsequent step.
+
+    seq_ids are the sequence ids of each entry of the batch
+    dimension of the hidden_states tensor."""
+
+    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+                 hidden_states: torch.Tensor):
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
+        self.hidden_states: torch.Tensor = hidden_states
+
+    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
+               hidden_states: torch.Tensor) -> None:
+        """Update hidden states from target model invocation."""
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        self.hidden_states = torch.cat([self.hidden_states, hidden_states])
+
+    def prune(self,
+              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+        """Prune to provided list of sequence ids."""
+        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        if seq_ids != self.seq_ids:
+            # Batch contents changed - prune removed sequences.
+            index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
+            self.hidden_states = self.hidden_states[index]
+            self.seq_ids = seq_ids
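
A sketch of the HiddenStates bookkeeping used by spec decode, with minimal
hand-built SequenceGroupMetadata objects (the block tables and sampling
params below are dummies):

    import torch

    from aphrodite.common.sampling_params import SamplingParams
    from aphrodite.common.sequence import (HiddenStates, SequenceData,
                                           SequenceGroupMetadata)

    def make_meta(seq_id):
        return SequenceGroupMetadata(request_id=f"r{seq_id}", is_prompt=True,
                                     seq_data={seq_id: SequenceData([1])},
                                     sampling_params=SamplingParams(),
                                     block_tables={seq_id: []})

    meta0, meta1 = make_meta(0), make_meta(1)
    hs = HiddenStates([meta0, meta1], torch.zeros(2, 8))
    hs.prune([meta0])                      # sequence 1 left the batch
    assert hs.hidden_states.shape == (1, 8)
    assert hs.seq_ids == [0]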
+
+
+@dataclass
+class ExecuteModelRequest:
+    """The model execution request, containing CPU metadata only. The LLM
+    engine should create an instance of this class for each request batch."""
+    # The sequence group metadata list.
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
+    # Blocks to copy. Source to dest block.
+    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
+    # Virtual engine ID for pipeline parallel.
+    virtual_engine: int = 0
+    # The number of slots for lookahead decoding.
+    num_lookahead_slots: int = 0
+    # The number of requests in the running queue.
+    running_queue_size: int = 0
+    # Optional hidden states from prior step.
+    previous_hidden_states: Optional[HiddenStates] = None
+    # The number of forward steps to run.
+    num_steps: int = 1
+    # Finished request ids since last step.
+    finished_requests_ids: List[str] = field(default_factory=list)
+
+    def clone(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> "ExecuteModelRequest":
+        """Clone the request with a new sequence group metadata list."""
+        return ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
+            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
+            blocks_to_copy=self.blocks_to_copy.copy(),
+            virtual_engine=self.virtual_engine,
+            num_lookahead_slots=self.num_lookahead_slots,
+            running_queue_size=self.running_queue_size,
+            previous_hidden_states=self.previous_hidden_states,
+            num_steps=self.num_steps,
+            finished_requests_ids=self.finished_requests_ids,
+        )
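
A minimal sketch of how an execution step is described and re-batched; the
empty metadata list stands in for whatever the scheduler produced.

    from aphrodite.common.sequence import ExecuteModelRequest

    request = ExecuteModelRequest(
        seq_group_metadata_list=[],      # normally filled by the scheduler
        blocks_to_swap_in=[(0, 5)],      # CPU block 0 -> GPU block 5
        num_lookahead_slots=4)
    clone = request.clone(seq_group_metadata_list=[])
    assert clone.num_lookahead_slots == 4
    assert clone.blocks_to_swap_in == [(0, 5)]
    assert clone.blocks_to_swap_in is not request.blocks_to_swap_in  # copied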

+ 636 - 97
aphrodite/common/utils.py

@@ -1,30 +1,64 @@
+import argparse
 import asyncio
+import datetime
 import enum
 import gc
 import os
 import socket
 import subprocess
+import sys
+import tempfile
+import threading
 import uuid
-from collections import OrderedDict, defaultdict
-from functools import lru_cache, partial
+import warnings
+from collections import defaultdict
+from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Optional, Tuple, TypeVar, Union)
+                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
+                    Union, overload)
 
+import numpy as np
+import numpy.typing as npt
 import psutil
 import torch
+import torch.types
 from loguru import logger
-from packaging.version import Version, parse
+from typing_extensions import ParamSpec
 
-T = TypeVar("T")
+from aphrodite import _custom_ops as ops
+from aphrodite.common.logger import enable_trace_function_call
 
 STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.half,
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "fp8": torch.uint8,
+    "fp8_e4m3": torch.uint8,
+    "fp8_e5m2": torch.uint8,
+}
+
+TORCH_DTYPE_TO_NUMPY_DTYPE = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+    torch.uint8: np.uint8,
+    torch.int32: np.int32,
+    torch.int64: np.int64,
 }
 
+P = ParamSpec('P')
+K = TypeVar("K")
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+class _Sentinel:
+    ...
+
+
+ALL_PINNED_SENTINEL = _Sentinel()
+
 
 class Device(enum.Enum):
     GPU = enum.auto()
@@ -48,7 +82,8 @@ class Counter:
 class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
-        self.cache = OrderedDict[Hashable, T]()
+        self.cache: OrderedDict[Hashable, T] = OrderedDict()
+        self.pinned_items: Set[Hashable] = set()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
@@ -58,7 +93,9 @@ class LRUCache(Generic[T]):
         return len(self.cache)
 
     def __getitem__(self, key: Hashable) -> T:
-        return self.get(key)
+        value = self.cache[key]  # Raise KeyError if not exists
+        self.cache.move_to_end(key)
+        return value
 
     def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)
@@ -72,6 +109,7 @@ class LRUCache(Generic[T]):
     def get(self,
             key: Hashable,
             default_value: Optional[T] = None) -> Optional[T]:
+        value: Optional[T]
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
@@ -84,29 +122,56 @@ class LRUCache(Generic[T]):
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: T):
+    def pin(self, key: Hashable) -> None:
+        """
+        Pins a key in the cache preventing it from being
+        evicted in the LRU order.
+        """
+        if key not in self.cache:
+            raise ValueError(f"Cannot pin key: {key} not in cache.")
+        self.pinned_items.add(key)
+
+    def _unpin(self, key: Hashable) -> None:
+        self.pinned_items.remove(key)
+
+    def _on_remove(self, key: Hashable, value: Optional[T]):
         pass
 
-    def remove_oldest(self):
+    def remove_oldest(self, remove_pinned: bool = False):
         if not self.cache:
             return
-        key, value = self.cache.popitem(last=False)
-        self._on_remove(key, value)
+
+        if not remove_pinned:
+            # pop the oldest item in the cache that is not pinned
+            lru_key = next(
+                (key for key in self.cache if key not in self.pinned_items),
+                ALL_PINNED_SENTINEL)
+            if lru_key is ALL_PINNED_SENTINEL:
+                raise RuntimeError("All items are pinned, "
+                                   "cannot remove oldest from the cache.")
+        else:
+            lru_key = next(iter(self.cache))
+        self.pop(lru_key)
 
     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
+    def pop(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
-        value = self.cache.pop(key, default_value)
+        value: Optional[T] = self.cache.pop(key, default_value)
+        # remove from pinned items
+        if key in self.pinned_items:
+            self._unpin(key)
         if run_on_remove:
             self._on_remove(key, value)
         return value
 
     def clear(self):
         while len(self.cache) > 0:
-            self.remove_oldest()
+            self.remove_oldest(remove_pinned=True)
         self.cache.clear()
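
A minimal usage sketch of the pinning-aware LRUCache above (illustrative keys and values only; assumes the class exactly as shown in this diff):

cache: LRUCache[int] = LRUCache(capacity=2)
cache["a"] = 1
cache["b"] = 2
cache.pin("a")            # "a" is now exempt from LRU eviction
cache["c"] = 3            # evicts "b", the oldest unpinned entry
assert "a" in cache and "b" not in cache
cache.clear()             # clear() removes all entries, including pinned ones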
 
 
@@ -123,6 +188,15 @@ def is_cpu() -> bool:
         return False
 
 
+@lru_cache(maxsize=None)
+def is_openvino() -> bool:
+    from importlib.metadata import PackageNotFoundError, version
+    try:
+        return "openvino" in version("aphrodite-engine")
+    except PackageNotFoundError:
+        return False
+
+
 @lru_cache(maxsize=None)
 def is_neuron() -> bool:
     try:
@@ -132,15 +206,40 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None
 
 
+@lru_cache(maxsize=None)
+def is_tpu() -> bool:
+    try:
+        import libtpu
+    except ImportError:
+        libtpu = None
+    return libtpu is not None
+
+
+@lru_cache(maxsize=None)
+def is_xpu() -> bool:
+    from importlib.metadata import version
+    is_xpu_flag = "xpu" in version("aphrodite-engine")
+    # aphrodite was not built with XPU support
+    if not is_xpu_flag:
+        return False
+    try:
+        import intel_extension_for_pytorch as ipex  # noqa: F401
+        _import_ipex = True
+    except ImportError as e:
+        logger.warning(f"Import Error for IPEX: {e.msg}")
+        _import_ipex = False
+    # the ipex dependency is not available
+    if not _import_ipex:
+        logger.warning("IPEX library not found.")
+        return False
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
-    # NOTE: This import statement should be executed lazily since
-    # the Neuron-X backend does not have the `cuda_utils` module.
-    from aphrodite._C import cuda_utils
-
     max_shared_mem = (
-        cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
+        ops.get_max_shared_memory_per_block_device_attribute(gpu))
     # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
     # will fail
     assert max_shared_mem > 0, "max_shared_mem can not be zero"
@@ -156,13 +255,25 @@ def random_uuid() -> str:
     return str(uuid.uuid4().hex)
 
 
+@lru_cache(maxsize=None)
+def get_aphrodite_instance_id():
+    """
+    If the environment variable APHRODITE_INSTANCE_ID is set, return it.
+    Otherwise, return a random UUID.
+    The instance id represents an instance of the Aphrodite engine. All
+    processes in the same instance should share the same instance id.
+    """
+    return os.environ.get("APHRODITE_INSTANCE_ID",
+                          f"aphrodite-instance-{random_uuid()}")
+
+
 @lru_cache(maxsize=None)
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
     return "microsoft" in " ".join(uname()).lower()
 
 
-def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
+def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
     """Take a blocking function, and run it on in an executor thread.
 
     This function prevents the blocking function from blocking the
@@ -170,7 +281,7 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
     The code in this function needs to be thread safe.
     """
 
-    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
+    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
         loop = asyncio.get_event_loop()
         p_func = partial(func, *args, **kwargs)
         return loop.run_in_executor(executor=None, func=p_func)
@@ -178,16 +289,22 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
     return _async_wrapper
 
 
+class ProducerFinished:
+    pass
+
+
 def merge_async_iterators(
         *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
     """Merge multiple asynchronous iterators into a single iterator.
+
     This method handles the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
     """
-    queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
+    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
+                               Exception]] = asyncio.Queue()
 
-    finished = [False] * len(iterators)
+    producers = len(iterators)
 
     async def producer(i: int, iterator: AsyncIterator[T]):
         try:
@@ -195,7 +312,8 @@ def merge_async_iterators(
                 await queue.put((i, item))
         except Exception as e:
             await queue.put(e)
-        finished[i] = True
+        # Signal to the consumer that we've finished
+        await queue.put(ProducerFinished())
 
     _tasks = [
         asyncio.create_task(producer(i, iterator))
@@ -203,17 +321,27 @@ def merge_async_iterators(
     ]
 
     async def consumer():
+        remaining = producers
         try:
-            while not all(finished) or not queue.empty():
+            while remaining or not queue.empty():
+                # NOTE: there may be a race condition here between the
+                # loop-condition check above and this await
                 item = await queue.get()
+
+                if isinstance(item, ProducerFinished):
+                    # Signal that a producer finished, not a real item
+                    remaining -= 1
+                    continue
+
                 if isinstance(item, Exception):
                     raise item
                 yield item
         except (Exception, asyncio.CancelledError) as e:
             for task in _tasks:
-                # NOTE: Pass the error msg in cancel()
-                # when only Python 3.9+ is supported.
-                task.cancel()
+                if sys.version_info >= (3, 9):
+                    # msg parameter only supported in Python 3.9+
+                    task.cancel(e)
+                else:
+                    task.cancel()
             raise e
         await asyncio.gather(*_tasks)
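
A minimal sketch of consuming the merged stream; the toy generators below are hypothetical and only illustrate the (index, item) yield pattern:

import asyncio

async def _gen(prefix: str, n: int):
    for i in range(n):
        await asyncio.sleep(0)
        yield f"{prefix}{i}"

async def _demo():
    # each yielded pair is (source iterator index, item)
    async for idx, item in merge_async_iterators(_gen("a", 2), _gen("b", 3)):
        print(idx, item)

asyncio.run(_demo())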
 
@@ -221,16 +349,35 @@ def merge_async_iterators(
 
 
 def get_ip() -> str:
+    host_ip = os.environ.get("HOST_IP")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
         s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
-    except OSError:
-        # try ipv6
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
         s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        s.connect(("dns.google", 80))
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default. "
+        "The value can be set via the HOST_IP environment variable.",
+        stacklevel=2)
+    return "0.0.0.0"
 
 
 def get_distributed_init_method(ip: str, port: int) -> str:
@@ -239,26 +386,47 @@ def get_distributed_init_method(ip: str, port: int) -> str:
     return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
 
 
-def get_open_port() -> int:
+def get_open_port(port: Optional[int] = None) -> int:
+    if port is None:
+        # Default behavior here is to return a port for multi-gpu communication
+        port = int(os.getenv("APHRODITE_PORT", 2242))
+    if port is not None:
+        while True:
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                    s.bind(("", port))
+                    return port
+            except OSError:
+                port += 1  # Increment port number if already in use
+                logger.info(f"Port {port - 1} is already in use, trying port "
+                            f"{port}")
     # try ipv4
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             s.bind(("", 0))
             return s.getsockname()[1]
     except OSError:
         # try ipv6
         with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             s.bind(("", 0))
             return s.getsockname()[1]
 
 
-def set_cuda_visible_devices(device_ids: List[int]) -> None:
-    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+def update_environment_variables(envs: Dict[str, str]):
+    for k, v in envs.items():
+        if k in os.environ and os.environ[k] != v:
+            logger.warning(f"Overwriting environment variable {k} "
+                           f"from '{os.environ[k]}' to '{v}'")
+        os.environ[k] = v
 
 
-def chunk_list(lst, chunk_size):
+def chunk_list(lst: List[T], chunk_size: int):
     """Yield successive chunk_size chunks from lst."""
-    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
+    for i in range(0, len(lst), chunk_size):
+        yield lst[i:i + chunk_size]
 
 
 def cdiv(a: int, b: int) -> int:
@@ -266,62 +434,29 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
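
For illustration, `chunk_list` and `cdiv` compose naturally when splitting work into fixed-size batches (toy values):

items = list(range(10))
batch_size = 4
num_batches = cdiv(len(items), batch_size)      # ceil(10 / 4) == 3
batches = list(chunk_list(items, batch_size))   # chunk_list now yields chunks
assert len(batches) == num_batches
assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]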
 
 
-@lru_cache(maxsize=None)
-def get_nvcc_cuda_version() -> Optional[Version]:
-    cuda_home = os.environ.get('CUDA_HOME')
-    if not cuda_home:
-        cuda_home = '/usr/local/cuda'
-        if os.path.isfile(cuda_home + '/bin/nvcc'):
-            logger.info(
-                f'CUDA_HOME is not found in the environment. Using {cuda_home} '
-                'as CUDA_HOME.')
-        else:
-            logger.warning(
-                f'Not found nvcc in {cuda_home}. Skip cuda version check!')
-            return None
-    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
-                                          universal_newlines=True)
-    output = nvcc_output.split()
-    release_idx = output.index("release") + 1
-    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
-    return nvcc_cuda_version
-
-
 def _generate_random_fp8(
-    tensor: torch.tensor,
+    tensor: torch.Tensor,
     low: float,
     high: float,
 ) -> None:
     # NOTE: Due to NaN and Inf representation for fp8 data type,
-    # we may get Inf or NaN if we directly use torch.randint
+    # Inf or NaN may occur if we directly use torch.randint
     # to generate random data for fp8 data.
     # For example, s.11111.00 in fp8e5m2 format represents Inf.
     #     | E4M3        | E5M2
     #-----|-------------|-------------------
     # Inf | N/A         | s.11111.00
     # NaN | s.1111.111  | s.11111.{01,10,11}
-    from aphrodite._C import cache_ops
+    from aphrodite import _custom_ops as ops
     tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
     tensor_tmp.uniform_(low, high)
-    cache_ops.convert_fp8(tensor_tmp, tensor)
+    ops.convert_fp8(tensor, tensor_tmp)
     del tensor_tmp
 
 
-def create_kv_caches_with_random(
-    num_blocks: int,
-    block_size: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    cache_dtype: Optional[Union[str, torch.dtype]],
-    model_dtype: Optional[Union[str, torch.dtype]] = None,
-    seed: Optional[int] = 0,
-    device: Optional[str] = "cuda",
-) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-
+def get_kv_cache_torch_dtype(
+        cache_dtype: Optional[Union[str, torch.dtype]],
+        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
     if isinstance(cache_dtype, str):
         if cache_dtype == "auto":
             if isinstance(model_dtype, str):
@@ -340,11 +475,74 @@ def create_kv_caches_with_random(
         torch_dtype = cache_dtype
     else:
         raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
+    return torch_dtype
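
A small sketch of the dtype resolution above, based on the STR_DTYPE_TO_TORCH_DTYPE table at the top of the file (illustrative values; part of the "auto" branch is elided in this hunk):

assert get_kv_cache_torch_dtype("auto", torch.bfloat16) is torch.bfloat16
assert get_kv_cache_torch_dtype("half", None) is torch.half
# fp8 variants are stored as uint8 in the kv cache
assert get_kv_cache_torch_dtype("fp8", torch.float16) is torch.uint8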
+
+
+def create_kv_caches_with_random_flash(
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    cache_dtype: Optional[Union[str, torch.dtype]],
+    model_dtype: Optional[Union[str, torch.dtype]] = None,
+    seed: int = 0,
+    device: Optional[str] = "cuda",
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+
+    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
+    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
+    scale = head_size**-0.5
+
+    key_caches: List[torch.Tensor] = []
+    value_caches: List[torch.Tensor] = []
+
+    for _ in range(num_layers):
+        key_value_cache = torch.empty(size=key_value_cache_shape,
+                                      dtype=torch_dtype,
+                                      device=device)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8':
+            _generate_random_fp8(key_value_cache, -scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
+        key_caches.append(key_value_cache[:, 0])
+        value_caches.append(key_value_cache[:, 1])
+    return key_caches, value_caches
+
+
+def create_kv_caches_with_random(
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    cache_dtype: Optional[Union[str, torch.dtype]],
+    model_dtype: Optional[Union[str, torch.dtype]] = None,
+    seed: int = 0,
+    device: Optional[str] = "cuda",
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+
+    if cache_dtype == "fp8" and head_size % 16:
+        raise ValueError(
+            f"Does not support key cache of type fp8 with head_size "
+            f"{head_size}")
+
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+
+    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
 
     scale = head_size**-0.5
     x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
     key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
+    key_caches: List[torch.Tensor] = []
     for _ in range(num_layers):
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
@@ -359,7 +557,7 @@ def create_kv_caches_with_random(
         key_caches.append(key_cache)
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
+    value_caches: List[torch.Tensor] = []
     for _ in range(num_layers):
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
@@ -389,23 +587,30 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                            "This may slow down the performance.")
         return False
+    elif is_xpu():
+        print_warning_once("Pin memory is not supported on XPU.")
+        return False
     elif is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif is_cpu():
+    elif is_cpu() or is_openvino():
         return False
     return True
 
 
 class CudaMemoryProfiler:
 
-    def __init__(self, device=None):
+    def __init__(self, device: Optional[torch.types.Device] = None):
         self.device = device
 
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
-        torch.cuda.reset_peak_memory_stats(self.device)
-        mem = torch.cuda.max_memory_allocated(self.device)
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats(self.device)
+            mem = torch.cuda.max_memory_allocated(self.device)
+        elif is_xpu():
+            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
+            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
         return mem
 
     def __enter__(self):
@@ -421,7 +626,7 @@ class CudaMemoryProfiler:
         gc.collect()
 
 
-def str_to_int_tuple(s: str) -> Tuple[int]:
+def str_to_int_tuple(s: str) -> Tuple[int, ...]:
     """Convert a string to a tuple of integers."""
     try:
         return tuple(map(int, s.split(",")))
@@ -431,24 +636,54 @@ def str_to_int_tuple(s: str) -> Tuple[int]:
             f"(e.g., 1, 2, 3). Given input: {s}") from e
 
 
-def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
-    assert len(x) <= max_len
-    return x + [pad] * (max_len - len(x))
+def make_ndarray_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: npt.DTypeLike,
+    *,
+    max_len: Optional[int] = None,
+) -> npt.NDArray:
+    """
+    Make a padded array from 2D inputs.
+
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    if max_len is None:
+        # Unlike for most functions, map is faster than a genexpr over `len`
+        max_len = max(map(len, x), default=0)
+
+    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
+    for ind, blocktb in enumerate(x):
+        assert len(blocktb) <= max_len
+        padded_x[ind, :len(blocktb)] = blocktb
+
+    return padded_x
 
 
 def make_tensor_with_pad(
-    x: List[List[int]],
-    max_len: int,
-    pad: int,
+    x: List[List[T]],
+    pad: T,
     dtype: torch.dtype,
-    device: Optional[Union[str, torch.device]],
+    *,
+    max_len: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    pin_memory: bool = False,
 ) -> torch.Tensor:
-    """Make a padded tensor of a 2D inputs.
+    """
+    Make a padded tensor from 2D inputs.
+
     The padding is applied to the end of each inner list until it reaches
     `max_len`.
     """
-    padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
+    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
+
+    tensor = torch.from_numpy(padded_x).to(device)
+    if pin_memory:
+        tensor = tensor.pin_memory()
+
+    return tensor
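
A short sketch of padding variable-length token lists with the numpy-backed helper above (toy inputs):

token_ids = [[1, 2, 3], [4, 5], [6]]
padded = make_tensor_with_pad(token_ids, pad=0, dtype=torch.int64)
assert padded.shape == (3, 3)            # each row padded to the longest list
assert padded[1].tolist() == [4, 5, 0]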
 
 
 def async_tensor_h2d(
@@ -471,13 +706,18 @@ def maybe_expand_dim(tensor: torch.Tensor,
     return tensor
 
 
-def merge_dicts(dict1: Dict[Any, List[Any]],
-                dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
+def get_dtype_size(dtype: torch.dtype) -> int:
+    """Get the size of the data type in bytes."""
+    return torch.tensor([], dtype=dtype).element_size()
+
+
+def merge_dicts(dict1: Dict[K, List[T]],
+                dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
     """Merge 2 dicts that have key -> List of items.
-    
+
     When a key conflicts, the values in dict1 are prioritized.
     """
-    merged_dict = defaultdict(list)
+    merged_dict: Dict[K, List[T]] = defaultdict(list)
 
     for key, value in dict1.items():
         merged_dict[key].extend(value)
@@ -486,3 +726,302 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
         merged_dict[key].extend(value)
 
     return dict(merged_dict)
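
For example, value lists are concatenated and `dict1` entries come first on key conflicts:

merged = merge_dicts({"a": [1], "b": [2]}, {"a": [3], "c": [4]})
assert merged == {"a": [1, 3], "b": [2], "c": [4]}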
+
+
+JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
+                 Tuple["JSONTree[T]", ...], T]
+"""A nested JSON structure where the leaves need not be JSON-serializable."""
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[T], U],
+    value: Dict[str, JSONTree[T]],
+) -> Dict[str, JSONTree[U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[T], U],
+    value: List[JSONTree[T]],
+) -> List[JSONTree[U]]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[T], U],
+    value: Tuple[JSONTree[T], ...],
+) -> Tuple[JSONTree[U], ...]:
+    ...
+
+
+@overload
+def json_map_leaves(
+    func: Callable[[T], U],
+    value: JSONTree[T],
+) -> JSONTree[U]:
+    ...
+
+
+def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
+    if isinstance(value, dict):
+        return {k: json_map_leaves(func, v) for k, v in value.items()}
+    elif isinstance(value, list):
+        return [json_map_leaves(func, v) for v in value]
+    elif isinstance(value, tuple):
+        return tuple(json_map_leaves(func, v) for v in value)
+    else:
+        return func(value)
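
A tiny sketch of mapping a function over the leaves of a nested structure (toy data):

tree = {"scores": [1, 2, (3, 4)], "best": 5}
doubled = json_map_leaves(lambda x: x * 2, tree)
assert doubled == {"scores": [2, 4, (6, 8)], "best": 10}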
+
+
+def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
+    """Flatten a list of lists to a single list."""
+    return [item for sublist in lists for item in sublist]
+
+
+def init_cached_hf_modules() -> None:
+    """
+    Lazy initialization of the Hugging Face modules.
+    """
+    from transformers.dynamic_module_utils import init_hf_modules
+    init_hf_modules()
+
+
+@lru_cache(maxsize=None)
+def find_library(lib_name: str) -> str:
+    """
+    Find the library file in the system.
+    `lib_name` is the full filename, including both prefix and suffix.
+    This function resolves `lib_name` to the full path of the library.
+    """
+    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
+    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
+    # `/sbin/ldconfig` should exist in all Linux systems.
+    # `/sbin/ldconfig` searches for the library in the system paths
+    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
+    # each line looks like the following:
+    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
+    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
+    # `LD_LIBRARY_PATH` provides additional user-defined search paths
+    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
+    if not locs and env_ld_library_path:
+        locs = [
+            os.path.join(dir, lib_name)
+            for dir in env_ld_library_path.split(":")
+            if os.path.exists(os.path.join(dir, lib_name))
+        ]
+    if not locs:
+        raise ValueError(f"Cannot find {lib_name} in the system.")
+    return locs[0]
+
+
+def find_nccl_library() -> str:
+    """
+    We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
+    environment variable, or we find the library file bundled with PyTorch.
+    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
+    found by `ctypes` automatically.
+    """
+    so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
+
+    # manually load the nccl library
+    if so_file:
+        logger.debug("Found nccl from environment variable "
+                     f"APHRODITE_NCCL_SO_PATH={so_file}")
+    else:
+        if torch.version.cuda is not None:
+            so_file = "libnccl.so.2"
+        elif torch.version.hip is not None:
+            so_file = "librccl.so.1"
+        else:
+            raise ValueError("NCCL only supports CUDA and ROCm backends.")
+        logger.debug(f"Found nccl from library {so_file}")
+    return so_file
+
+
+def enable_trace_function_call_for_thread() -> None:
+    if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
+        tmp_dir = tempfile.gettempdir()
+        filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
+                    f"_thread_{threading.get_ident()}_"
+                    f"at_{datetime.datetime.now()}.log").replace(" ", "_")
+        log_path = os.path.join(tmp_dir, "aphrodite",
+                                get_aphrodite_instance_id(), filename)
+        os.makedirs(os.path.dirname(log_path), exist_ok=True)
+        enable_trace_function_call(log_path)
+
+
+def identity(value: T) -> T:
+    return value
+
+
+F = TypeVar('F', bound=Callable[..., Any])
+
+
+def deprecate_kwargs(
+        *kws: str,
+        is_deprecated: Union[bool, Callable[[], bool]] = True,
+        additional_message: Optional[str] = None) -> Callable[[F], F]:
+    deprecated_kws = set(kws)
+
+    if not callable(is_deprecated):
+        is_deprecated = partial(identity, is_deprecated)
+
+    def wrapper(fn: F) -> F:
+
+        @wraps(fn)
+        def inner(*args, **kwargs):
+            if is_deprecated():
+                deprecated_kwargs = kwargs.keys() & deprecated_kws
+                if deprecated_kwargs:
+                    msg = (
+                        f"The keyword arguments {deprecated_kwargs} are "
+                        "deprecated and will be removed in a future update.")
+                    if additional_message is not None:
+                        msg += f" {additional_message}"
+
+                    warnings.warn(
+                        DeprecationWarning(msg),
+                        stacklevel=3,  # The inner function takes up one level
+                    )
+
+            return fn(*args, **kwargs)
+
+        return inner  # type: ignore
+
+    return wrapper
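
A usage sketch of the decorator above; `load_model` and its `legacy_mode` keyword are hypothetical:

@deprecate_kwargs("legacy_mode", additional_message="Use `mode` instead.")
def load_model(name: str, mode: str = "auto", legacy_mode: bool = False):
    return name, mode

# Emits a DeprecationWarning because `legacy_mode` is passed by keyword.
load_model("my-model", legacy_mode=True)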
+
+
+@lru_cache(maxsize=8)
+def _cuda_device_count_stateless(
+        cuda_visible_devices: Optional[str] = None) -> int:
+    # Note: cuda_visible_devices is not used, but we keep it as an argument for
+    # LRU Cache purposes.
+
+    # Code below is based on
+    # https://github.com/pytorch/pytorch/blob/
+    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
+    # torch/cuda/__init__.py#L831C1-L831C17
+    import torch.cuda
+    import torch.version
+
+    if not torch.cuda._is_compiled():
+        return 0
+    if is_hip():
+        # ROCm uses amdsmi instead of nvml for stateless device count
+        # This requires a sufficiently modern version of torch (2.4.0 or later)
+        raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
+            torch.cuda, "_device_count_amdsmi")) else -1
+    else:
+        raw_count = torch.cuda._device_count_nvml()
+    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
+    return r
+
+
+def cuda_device_count_stateless() -> int:
+    """Get number of CUDA devices, caching based on the value of
+    CUDA_VISIBLE_DEVICES at the time of call.
+    
+    This should be used instead of torch.cuda.device_count()
+    unless CUDA_VISIBLE_DEVICES has already been set to the desired
+    value."""
+
+    # This can be removed and simply replaced with torch.cuda.device_count
+    # after https://github.com/pytorch/pytorch/pull/122815 is released.
+
+    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES"))
+
+
+# NVML utils
+# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, so all the
+# related functions work on real physical device ids. The major benefit
+# of using NVML is that it does not initialize CUDA.
+
+try:
+    import pynvml
+except ImportError:
+    # For non-NV devices
+    pynvml = None
+
+
+def with_nvml_context(fn):
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if pynvml is not None:
+            pynvml.nvmlInit()
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            if pynvml is not None:
+                pynvml.nvmlShutdown()
+
+    return wrapper
+
+
+@with_nvml_context
+def is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    Query whether the given set of GPUs is fully connected by NVLink (1 hop).
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine is not equipped with NVLink.",
+                        exc_info=error)
+                    return False
+    return True
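
Illustrative call only; this requires pynvml and NVIDIA hardware, and takes physical device ids (not CUDA_VISIBLE_DEVICES indices):

if pynvml is not None:
    # True only if every pair of the listed GPUs is one NVLink hop apart
    print(is_full_nvlink([0, 1]))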
+
+
+# From: https://stackoverflow.com/a/4104188/2749989
+def run_once(f):
+
+    def wrapper(*args, **kwargs) -> Any:
+        if not wrapper.has_run:  # type: ignore[attr-defined]
+            wrapper.has_run = True  # type: ignore[attr-defined]
+            return f(*args, **kwargs)
+
+    wrapper.has_run = False  # type: ignore[attr-defined]
+    return wrapper
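
A quick sketch of `run_once` with a hypothetical one-time warning:

@run_once
def _warn_experimental() -> None:
    print("experimental feature enabled")

_warn_experimental()   # runs and prints
_warn_experimental()   # silently skipped on every later call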
+
+
+class FlexibleArgumentParser(argparse.ArgumentParser):
+    """ArgumentParser that allows both underscore and dash in names."""
+
+    def parse_args(self, args=None, namespace=None):
+        if args is None:
+            args = sys.argv[1:]
+
+        # Convert underscores to dashes and vice versa in argument names
+        processed_args = []
+        for arg in args:
+            if arg.startswith('--'):
+                if '=' in arg:
+                    key, value = arg.split('=', 1)
+                    key = '--' + key[len('--'):].replace('_', '-')
+                    processed_args.append(f'{key}={value}')
+                else:
+                    processed_args.append('--' +
+                                          arg[len('--'):].replace('_', '-'))
+            else:
+                processed_args.append(arg)
+
+        return super().parse_args(processed_args, namespace)
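
A sketch of the underscore/dash normalization performed by the parser above (hypothetical flags):

parser = FlexibleArgumentParser()
parser.add_argument("--model-name")
parser.add_argument("--tensor-parallel-size", type=int, default=1)

# Both spellings resolve to the same destination attributes.
ns = parser.parse_args(["--model_name=llama", "--tensor_parallel_size", "2"])
assert ns.model_name == "llama" and ns.tensor_parallel_size == 2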
+
+
+async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
+                              **kwargs):
+    """Utility function to run async task in a lock"""
+    async with lock:
+        return await task(*args, **kwargs)

aphrodite/distributed/communication_op.py (+13, -191)

@@ -1,210 +1,32 @@
-from collections import namedtuple
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import torch
-from torch.distributed import ProcessGroup
+import torch.distributed
 
-from .parallel_state import (
-    get_tensor_model_parallel_group,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    is_pynccl_enabled_for_all_reduce,
-)
+from .parallel_state import get_tp_group
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
-    """All-reduce the input tensor across model parallel group.
-
-    NOTE: This operation will be applied in-place on the input tensor if
-    disable_custom_all_reduce is set to True. Otherwise, this operation may or
-    may not be applied in place depending on whether custom all reduce is
-    invoked for a particular tensor, which further depends on the tensor size
-    and GPU topology.
-    TLDR: always assume this function modifies its input, but use the return
-    value as the output.
-    """
-    from aphrodite.distributed.device_communicators import pynccl_utils
-    from aphrodite.distributed.device_communicators.custom_all_reduce import (
-        custom_all_reduce)
-    # Bypass the function if we are using only 1 GPU.
-    if get_tensor_model_parallel_world_size() == 1:
-        return input_
-    out = custom_all_reduce(input_)
-    if out is not None:
-        return out
-    if is_pynccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
-        pynccl_utils.all_reduce(input_)
-    else:
-        torch.distributed.all_reduce(input_,
-                                     group=get_tensor_model_parallel_group())
-    return input_
+    """All-reduce the input tensor across model parallel group."""
+    return get_tp_group().all_reduce(input_)
 
 
 def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                      dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
-    world_size = get_tensor_model_parallel_world_size()
-    # Bypass the function if we are using only 1 GPU.
-    if world_size == 1:
-        return input_
-    assert -input_.dim() <= dim < input_.dim(), (
-        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
-    if dim < 0:
-        # Convert negative dim to positive.
-        dim += input_.dim()
-    input_size = input_.size()
-    # Allocate output tensor.
-    output_tensor = torch.empty((world_size, ) + input_size,
-                                dtype=input_.dtype,
-                                device=input_.device)
-    # All-gather.
-    torch.distributed.all_gather_into_tensor(
-        output_tensor, input_, group=get_tensor_model_parallel_group())
-    # Reshape
-    output_tensor = output_tensor.movedim(0, dim)
-    output_tensor = output_tensor.reshape(input_size[:dim] +
-                                          (world_size * input_size[dim], ) +
-                                          input_size[dim + 1:])
-    return output_tensor
+    return get_tp_group().all_gather(input_, dim)
 
 
 def tensor_model_parallel_gather(input_: torch.Tensor,
                                  dst: int = 0,
                                  dim: int = -1) -> torch.Tensor:
-    """Gather the input tensor across model parallel group.
-
-    NOTE: We assume that the input tensor is on the same device across
-    all the ranks.
-    """
-    world_size = get_tensor_model_parallel_world_size()
-    # Bypass the function if we are using only 1 GPU.
-    if world_size == 1:
-        return input_
-    assert -input_.dim() <= dim < input_.dim(), (
-        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
-    if dim < 0:
-        # Convert negative dim to positive.
-        dim += input_.dim()
-    # Allocate output tensor.
-    if get_tensor_model_parallel_rank() == dst:
-        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
-    else:
-        gather_list = None
-    # Gather.
-    torch.distributed.gather(input_,
-                             gather_list,
-                             dst=dst,
-                             group=get_tensor_model_parallel_group())
-    if get_tensor_model_parallel_rank() == dst:
-        output_tensor = torch.cat(gather_list, dim=dim)
-    else:
-        output_tensor = None
-    return output_tensor
-
-
-def broadcast(input_: torch.Tensor,
-              src: int = 0,
-              group: Optional[ProcessGroup] = None):
-    """Broadcast the input tensor."""
-    group = group or torch.distributed.group.WORLD
-    ranks = torch.distributed.get_process_group_ranks(group)
-    assert src in ranks, f"Invalid src rank ({src})"
-
-    # Bypass the function if we are using only 1 GPU.
-    world_size = torch.distributed.get_world_size(group=group)
-    if world_size == 1:
-        return input_
-    # Broadcast.
-    torch.distributed.broadcast(input_, src=src, group=group)
-    return input_
-
+    """Gather the input tensor across model parallel group."""
+    return get_tp_group().gather(input_, dst, dim)
 
-def broadcast_object_list(obj_list: List[Any],
-                          src: int = 0,
-                          group: Optional[ProcessGroup] = None):
-    """Broadcast the input object list."""
-    group = group or torch.distributed.group.WORLD
-    ranks = torch.distributed.get_process_group_ranks(group)
-    assert src in ranks, f"Invalid src rank ({src})"
 
-    # Bypass the function if we are using only 1 GPU.
-    world_size = torch.distributed.get_world_size(group=group)
-    if world_size == 1:
-        return obj_list
-    # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
-    return obj_list
-
-
-TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
-
-
-def broadcast_tensor_dict(
-    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
-    src: int = 0,
-    group: Optional[ProcessGroup] = None,
-) -> Dict[Any, Union[torch.Tensor, Any]]:
-    """Broadcast the input tensor dictionary."""
-    group = group or torch.distributed.group.WORLD
-    ranks = torch.distributed.get_process_group_ranks(group)
-    assert src in ranks, f"Invalid src rank ({src})"
-
-    # Bypass the function if we are using only 1 GPU.
-    world_size = torch.distributed.get_world_size(group=group)
-    if world_size == 1:
+def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
+                                                                Any]]] = None,
+                          src: int = 0):
+    if not torch.distributed.is_initialized():
         return tensor_dict
-
-    rank = torch.distributed.get_rank()
-    if rank == src:
-        assert isinstance(
-            tensor_dict,
-            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list = []
-        for key, value in tensor_dict.items():
-            if isinstance(value, torch.Tensor):
-                assert value.is_cuda, (
-                    f"Tensor {key}: {value} is not on cuda. Currently we only "
-                    f"support broadcasting tensors on cuda.")
-                metadata_list.append(
-                    (key, TensorMetadata(value.dtype, value.size())))
-            else:
-                metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list],
-                                                src=src,
-                                                group=group)
-        async_handles = []
-        for key, value in metadata_list:
-            if isinstance(value, TensorMetadata):
-                tensor = tensor_dict[key]
-                async_handles.append(
-                    torch.distributed.broadcast(tensor,
-                                                src=src,
-                                                group=group,
-                                                async_op=True))
-        for async_handle in async_handles:
-            async_handle.wait()
-    else:
-        recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list,
-                                                src=src,
-                                                group=group)
-        metadata_list = recv_metadata_list[0]
-        tensor_dict = {}
-        async_handles = []
-        for key, value in metadata_list:  # pylint: disable=not-an-iterable
-            if isinstance(value, TensorMetadata):
-                tensor = torch.empty(value.size,
-                                     dtype=value.dtype,
-                                     device="cuda")
-                async_handle = torch.distributed.broadcast(tensor,
-                                                           src=src,
-                                                           async_op=True,
-                                                           group=group)
-                async_handles.append(async_handle)
-                tensor_dict[key] = tensor
-            else:
-                tensor_dict[key] = value
-        for async_handle in async_handles:
-            async_handle.wait()
-    return tensor_dict
+    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
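
For orientation: after this refactor the module-level wrappers simply delegate to the tensor-parallel group coordinator, so a call site inside a TP-sharded layer looks roughly like the sketch below (assumes the distributed groups have already been initialized by the engine; shapes are hypothetical):

# partial result computed on this tensor-parallel rank
partial_out = torch.randn(8, 1024, device="cuda")

# sum partial results across the TP group
full_out = tensor_model_parallel_all_reduce(partial_out)

# or gather shards along a dimension, e.g. for vocab-parallel logits
gathered = tensor_model_parallel_all_gather(partial_out, dim=-1)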

aphrodite/distributed/device_communicators/cuda_wrapper.py (+163, -0)

@@ -0,0 +1,163 @@
+"""This file is a pure Python wrapper for the cudart library.
+It avoids the need to compile a separate shared library, and is
+convenient for use when we just need to call a few functions.
+"""
+
+import ctypes
+import glob
+import os
+import sys
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+# this line makes it possible to directly load `libcudart.so` using `ctypes`
+import torch  # noqa
+
+# === export types and functions from cudart to Python ===
+# for the original cudart definition, please check
+# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
+
+cudaError_t = ctypes.c_int
+cudaMemcpyKind = ctypes.c_int
+
+
+class cudaIpcMemHandle_t(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+@dataclass
+class Function:
+    name: str
+    restype: Any
+    argtypes: List[Any]
+
+
+def get_pytorch_default_cudart_library_path() -> str:
+    # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
+    lib_folder = "cuda_runtime"
+    lib_name = "libcudart.so.*[0-9]"
+    lib_path = None
+    for path in sys.path:
+        nvidia_path = os.path.join(path, "nvidia")
+        if not os.path.exists(nvidia_path):
+            continue
+        candidate_lib_paths = glob.glob(
+            os.path.join(nvidia_path, lib_folder, "lib", lib_name))
+        if candidate_lib_paths and not lib_path:
+            lib_path = candidate_lib_paths[0]
+        if lib_path:
+            break
+    if not lib_path:
+        raise ValueError(f"{lib_name} not found in the system path {sys.path}")
+    return lib_path
+
+
+class CudaRTLibrary:
+    exported_functions = [
+        # cudaError_t cudaSetDevice ( int device )
+        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
+        # cudaError_t cudaDeviceSynchronize ( void )
+        Function("cudaDeviceSynchronize", cudaError_t, []),
+        # cudaError_t cudaDeviceReset ( void )
+        Function("cudaDeviceReset", cudaError_t, []),
+
+        # const char* cudaGetErrorString ( cudaError_t error )
+        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
+
+        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
+        Function("cudaMalloc", cudaError_t,
+                 [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
+        # cudaError_t cudaFree ( void* devPtr )
+        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
+        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
+        Function("cudaMemset", cudaError_t,
+                 [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
+        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
+        Function("cudaMemcpy", cudaError_t, [
+            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
+        ]),
+
+        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
+        Function("cudaIpcGetMemHandle", cudaError_t,
+                 [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
+        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
+        Function("cudaIpcOpenMemHandle", cudaError_t, [
+            ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
+        ]),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    #  to the corresponding dictionary
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+        if so_file is None:
+            so_file = get_pytorch_default_cudart_library_path()
+        if so_file not in CudaRTLibrary.path_to_library_cache:
+            lib = ctypes.CDLL(so_file)
+            CudaRTLibrary.path_to_library_cache[so_file] = lib
+        self.lib = CudaRTLibrary.path_to_library_cache[so_file]
+
+        if so_file not in CudaRTLibrary.path_to_dict_mapping:
+            _funcs = {}
+            for func in CudaRTLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
+        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
+
+    def CUDART_CHECK(self, result: cudaError_t) -> None:
+        if result != 0:
+            error_str = self.cudaGetErrorString(result)
+            raise RuntimeError(f"CUDART error: {error_str}")
+
+    def cudaGetErrorString(self, error: cudaError_t) -> str:
+        return self.funcs["cudaGetErrorString"](error).decode("utf-8")
+
+    def cudaSetDevice(self, device: int) -> None:
+        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
+
+    def cudaDeviceSynchronize(self) -> None:
+        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
+
+    def cudaDeviceReset(self) -> None:
+        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
+
+    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
+        devPtr = ctypes.c_void_p()
+        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
+        return devPtr
+
+    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
+        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
+
+    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
+                   count: int) -> None:
+        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
+
+    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
+                   count: int) -> None:
+        cudaMemcpyDefault = 4
+        kind = cudaMemcpyDefault
+        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
+
+    def cudaIpcGetMemHandle(self,
+                            devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
+        handle = cudaIpcMemHandle_t()
+        self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
+            ctypes.byref(handle), devPtr))
+        return handle
+
+    def cudaIpcOpenMemHandle(self,
+                             handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
+        cudaIpcMemLazyEnablePeerAccess = 1
+        devPtr = ctypes.c_void_p()
+        self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
+            ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
+        return devPtr
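
A minimal sketch of driving the wrapper above, assuming an NVIDIA GPU and the cudart shared library bundled with PyTorch:

lib = CudaRTLibrary()            # resolves libcudart via PyTorch by default
lib.cudaSetDevice(0)

nbytes = 1024
dev_ptr = lib.cudaMalloc(nbytes)
lib.cudaMemset(dev_ptr, 0, nbytes)
lib.cudaDeviceSynchronize()

# an IPC handle lets another process map the same allocation
handle = lib.cudaIpcGetMemHandle(dev_ptr)

lib.cudaFree(dev_ptr)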

aphrodite/distributed/device_communicators/custom_all_reduce.py (+187, -170)

@@ -1,155 +1,25 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
-from loguru import logger
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as dist
+from loguru import logger
+from torch.distributed import ProcessGroup
 
-from aphrodite.distributed import gpu_p2p_access_check
+from aphrodite import _custom_ops as ops
+from aphrodite.common.utils import cuda_device_count_stateless, is_full_nvlink
+from aphrodite.distributed.device_communicators.custom_all_reduce_utils import (
+    gpu_p2p_access_check)
+from aphrodite.distributed.parallel_state import in_the_same_node_as
 
 try:
-    from aphrodite._C import custom_ar
-    import pynvml
-except ImportError:
-    # For AMD GPUs
-    custom_ar = None
-    pynvml = None
-
-_CA_HANDLE = None
-_IS_CAPTURING = False
-_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
-
-
-def init_custom_ar() -> None:
-    from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                       get_tensor_model_parallel_world_size)
-    global _CA_HANDLE
-    if _CA_HANDLE is not None:
-        return
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    if world_size == 1:
-        return
-    if world_size not in _SUPPORTED_WORLD_SIZES:
-        logger.warning(
-            "Custom allreduce is disabled due to an unsupported world size: "
-            "%d. Supported world sizes: %s. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.", world_size,
-            str(_SUPPORTED_WORLD_SIZES))
-        return
-    num_dev = torch.cuda.device_count()
-    # note: num dev can be larger than world_size if we're only using
-    # first few GPUs
-    if num_dev < world_size:
-        logger.warning(
-            "Cannot test GPU P2P because not all GPUs are visible to the "
-            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
-            " is set.")
-        return False
-    # test nvlink first, this will filter out most of the cases
-    # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
-    if world_size > 2 and not full_nvlink:
-        logger.warning(
-            "Custom allreduce is disabled because it's not supported on more"
-            " than two PCIe-only GPUs. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.")
-        return
-    # test P2P capability
-    # this is expensive to compute at the first time
-    # then we cache the result
-    if not _can_p2p(rank, world_size):
-        logger.warning(
-            "Custom allreduce is disabled because your platform lacks GPU P2P"
-            " capability or P2P test failed. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.")
-        return
-    _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
-
-
-def begin_capture() -> None:
-    global _IS_CAPTURING
-    _IS_CAPTURING = True
-
-
-def end_capture() -> None:
-    global _IS_CAPTURING
-    _IS_CAPTURING = False
-
-
-def is_capturing() -> bool:
-    return _IS_CAPTURING and _CA_HANDLE is not None
-
-
-def get_handle() -> Optional["CustomAllreduce"]:
-    return _CA_HANDLE
-
-
-def is_initialized() -> bool:
-    return _CA_HANDLE is not None
-
-
-@contextmanager
-def capture():
-    try:
-        begin_capture()
-        yield
-    finally:
-        end_capture()
-        handle = get_handle()
-        if handle is not None:
-            handle.register_graph_buffers()
-
-
-# pylint: disable=redefined-builtin
-def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
-    ca_handle = get_handle()
-    # when custom allreduce is disabled, this will be None
-    if ca_handle is None:
-        return
-    if is_capturing():
-        if torch.cuda.is_current_stream_capturing():
-            if ca_handle.should_custom_ar(input):
-                return ca_handle.all_reduce_reg(input)
-        else:
-            if ca_handle.should_custom_ar(input):
-                # if warm up, mimic the allocation pattern
-                # since custom allreduce is out-of-place
-                return torch.empty_like(input)
-    else:
-        # NOTE: outside of cuda graph context,
-        # custom allreduce incurs a cost of cudaMemcpy, which should
-        # be small(<=1% of overall latency) compared to the performance
-        # gains of using custom kernels
-        if ca_handle.should_custom_ar(input):
-            return ca_handle.all_reduce_unreg(input)
-
-
-@contextmanager
-def _nvml():
-    try:
-        pynvml.nvmlInit()
-        yield
-    finally:
-        pynvml.nvmlShutdown()
-
-
-# query if the set of gpus are fully connected by nvlink (1 hop)
-@_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
-                if not link_state:
-                    return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
-    return True
+    assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
+
+    custom_ar = True
+except Exception:
+    # For AMD GPUs and CPUs
+    custom_ar = False
 
 
 def _can_p2p(rank: int, world_size: int) -> bool:
@@ -163,22 +33,116 @@ def _can_p2p(rank: int, world_size: int) -> bool:
 
 class CustomAllreduce:
 
+    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+
     # max_size: max supported allreduce size
     def __init__(self,
-                 rank,
-                 world_size,
-                 full_nvlink,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
                  max_size=8192 * 1024) -> None:
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bound to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bound to a unique device, and that all communicators in this group
+        are on the same node.
+        """
+        self._IS_CAPTURING = False
+        self.disabled = True
+
+        if not custom_ar:
+            # disable because of missing custom allreduce library
+            # e.g. in a non-cuda environment
+            return
+
+        self.group = group
+
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "CustomAllreduce should be attached to a non-NCCL group.")
+
+        if not all(in_the_same_node_as(group, source_rank=0)):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
+        rank = dist.get_rank(group=self.group)
+        world_size = dist.get_world_size(group=self.group)
+        if world_size == 1:
+            # No need to initialize custom allreduce for single GPU case.
+            return
+
+        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom allreduce is disabled due to an unsupported world"
+                f" size: {world_size}. Supported world sizes:"
+                f"{str(CustomAllreduce._SUPPORTED_WORLD_SIZES)}. To silence "
+                "this warning, specify disable_custom_all_reduce=True "
+                "explicitly.")
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(cuda_device_count_stateless()))
+
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id],
+                              dtype=torch.int,
+                              device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        full_nvlink = is_full_nvlink(physical_device_ids)
+        if world_size > 2 and not full_nvlink:
+            logger.warning(
+                "Custom allreduce is disabled because it's not supported on"
+                " more than two PCIe-only GPUs. To silence this warning, "
+                "specify disable_custom_all_reduce=True explicitly.")
+            return
+        # test P2P capability, this checks software/cudaruntime support
+        # this is expensive to compute at the first time
+        # then we cache the result
+        if not _can_p2p(rank, world_size):
+            logger.warning(
+                "Custom allreduce is disabled because your platform lacks "
+                "GPU P2P capability or P2P test failed. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.")
+            return
+
+        self.disabled = False
         # buffer memory is owned by this Python class and passed to C++
         # metadata is composed of two parts: metadata for synchronization
         # (256 bytes) and a temporary buffer for storing intermediate
         # allreduce results.
-        self.meta = torch.zeros(custom_ar.meta_size() + max_size,
+        self.meta = torch.zeros(ops.meta_size() + max_size,
                                 dtype=torch.uint8,
-                                device="cuda")
+                                device=self.device)
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
+        self.buffer = torch.empty(max_size,
+                                  dtype=torch.uint8,
+                                  device=self.device)
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -186,18 +150,32 @@ class CustomAllreduce:
         # needs less than 10000 of registered tuples.
         self.rank_data = torch.empty(8 * 1024 * 1024,
                                      dtype=torch.uint8,
-                                     device="cuda")
+                                     device=self.device)
         self.max_size = max_size
+        self.rank = rank
         self.world_size = world_size
         handles, offsets = self._get_ipc_meta(self.meta)
         self.full_nvlink = full_nvlink
-        self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
-                                             handles, offsets, rank,
-                                             self.full_nvlink)
+        self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
+                                       offsets, rank, self.full_nvlink)
         self.register_buffer(self.buffer)
 
+    @contextmanager
+    def capture(self):
+        """
+        The main responsibility of this context manager is the 
+        `register_graph_buffers` call at the end of the context.
+        It records all the buffer addresses used in the CUDA graph.
+        """
+        try:
+            self._IS_CAPTURING = True
+            yield
+        finally:
+            self._IS_CAPTURING = False
+            if not self.disabled:
+                self.register_graph_buffers()
+
     def _get_ipc_meta(self, inp: torch.Tensor):
-        # pylint: disable=protected-access
         data = inp.untyped_storage()._share_cuda_()
         shard_data = (
             data[1],  # ipc handle to base ptr
@@ -206,48 +184,87 @@ class CustomAllreduce:
         return self._gather_ipc_meta(shard_data)
 
     def _gather_ipc_meta(self, shard_data):
-        all_data = [None] * self.world_size
-        dist.all_gather_object(all_data, shard_data)
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: List[Optional[Any]] = [[None]
+                                         for i in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(all_data[i],
+                                       src=rank,
+                                       group=self.group,
+                                       device="cpu")
+
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
 
         handles = []
         offsets = []
         for i in range(len(all_data)):
-            handles.append(all_data[i][0])
-            offsets.append(all_data[i][1])
+            handles.append(all_data[i][0][0])  # type: ignore
+            offsets.append(all_data[i][0][1])  # type: ignore
         return handles, offsets
 
     def register_buffer(self, inp: torch.Tensor):
         handles, offsets = self._get_ipc_meta(inp)
-        custom_ar.register_buffer(self._ptr, inp, handles, offsets)
+        ops.register_buffer(self._ptr, inp, handles, offsets)
 
     def register_graph_buffers(self):
-        handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
+        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
         handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        custom_ar.register_graph_buffers(self._ptr, handles, offsets)
+        if self.rank == 0:
+            logger.info(f"Registering {len(offset)} cuda graph addresses")
+        ops.register_graph_buffers(self._ptr, handles, offsets)
 
     def should_custom_ar(self, inp: torch.Tensor):
-        return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
-                                          self.full_nvlink)
+        return ops.should_custom_ar(inp, self.max_size, self.world_size,
+                                    self.full_nvlink)
 
     # all reduce, assuming inp tensor is IPC registered with register_buffer,
     # or, in the context of cuda graphs, register_graph_buffers
     def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
             out = torch.empty_like(inp)
-        custom_ar.all_reduce_reg(self._ptr, inp, out)
+        ops.all_reduce_reg(self._ptr, inp, out)
         return out
 
     # all reduce, assuming inp tensor is NOT IPC registered
     def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
             out = torch.empty_like(inp)
-        custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
+        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out
 
+    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+        # when custom allreduce is disabled, this will be None
+        if self.disabled:
+            return None
+        if self._IS_CAPTURING:
+            if torch.cuda.is_current_stream_capturing():
+                if self.should_custom_ar(input):
+                    return self.all_reduce_reg(input)
+            else:
+                if self.should_custom_ar(input):
+                    # during warm up, mimic the allocation pattern
+                    # since custom allreduce is out-of-place
+                    return torch.empty_like(input)
+        else:
+            # note: outside of the cuda graph context,
+            # custom allreduce incurs the cost of a cudaMemcpy, which should
+            # be small (<= 1% of overall latency) compared to the performance
+            # gains of using custom kernels
+            if self.should_custom_ar(input):
+                return self.all_reduce_unreg(input)
+
+        return None
+
     def close(self):
-        if self._ptr:
-            custom_ar.dispose(self._ptr)
+        if not self.disabled and self._ptr:
+            ops.dispose(self._ptr)
             self._ptr = 0
 
     def __del__(self):

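A minimal usage sketch of the refactored `CustomAllreduce` entry points for reviewers; the helper name, the fallback path, and the graph-capture scaffolding are illustrative assumptions, not code from this diff:

```python
import torch
import torch.distributed as dist


def all_reduce_with_fallback(ca, tensor: torch.Tensor) -> torch.Tensor:
    # `ca` is assumed to be an initialized CustomAllreduce instance.
    out = ca.custom_all_reduce(tensor)
    if out is not None:       # the custom kernel handled it (out-of-place)
        return out
    dist.all_reduce(tensor)   # otherwise fall back to the regular path
    return tensor


# During CUDA graph capture, wrap the capture in `ca.capture()` so that the
# IPC buffer addresses used inside the graph are registered on context exit:
#
#   with ca.capture():
#       graph = torch.cuda.CUDAGraph()
#       with torch.cuda.graph(graph):
#           static_out = all_reduce_with_fallback(ca, static_input)
```
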
+ 241 - 0
aphrodite/distributed/device_communicators/custom_all_reduce_utils.py

@@ -0,0 +1,241 @@
+import ctypes
+import json
+import os
+import pickle
+import subprocess
+import sys
+from itertools import product
+from typing import Dict, Optional, Sequence
+
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from loguru import logger
+
+from aphrodite.common.utils import (cuda_device_count_stateless,
+                                    update_environment_variables)
+from aphrodite.distributed.device_communicators.cuda_wrapper import (
+    CudaRTLibrary)
+
+
+def producer(batch_src: Sequence[int],
+             producer_queue,
+             consumer_queue,
+             result_queue,
+             cuda_visible_devices: Optional[str] = None):
+    if cuda_visible_devices is not None:
+        update_environment_variables(
+            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
+
+    lib = CudaRTLibrary()
+    for i in batch_src:
+        lib.cudaSetDevice(i)
+        pointer = lib.cudaMalloc(1024)
+        lib.cudaMemset(pointer, 1, 1024)
+        lib.cudaDeviceSynchronize()
+        handle = lib.cudaIpcGetMemHandle(pointer)
+        producer_queue.put(handle)
+        open_success = consumer_queue.get()
+        if open_success:
+            # use two queues to simulate barrier
+            producer_queue.put(0)
+            consumer_queue.get()
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()
+
+
+def consumer(batch_tgt: Sequence[int],
+             producer_queue,
+             consumer_queue,
+             result_queue,
+             cuda_visible_devices: Optional[str] = None):
+    if cuda_visible_devices is not None:
+        update_environment_variables(
+            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
+
+    lib = CudaRTLibrary()
+    for j in batch_tgt:
+        lib.cudaSetDevice(j)
+        handle = producer_queue.get()
+        open_success = False
+        try:
+            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
+            open_success = True
+        except RuntimeError:
+            # cannot error out here, because the producer process
+            # is still waiting for the response.
+            pass
+        consumer_queue.put(open_success)
+        if open_success:
+            # modify the memory
+            lib.cudaMemset(pointer, 2, 1024)
+            lib.cudaDeviceSynchronize()
+            # use two queues to simulate barrier
+            producer_queue.get()
+            consumer_queue.put(0)
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()
+
+
+def can_actually_p2p(
+    batch_src: Sequence[int],
+    batch_tgt: Sequence[int],
+):
+    """
+    Usually, checking if P2P access is enabled can be done by
+    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
+    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
+    returns `True` even if P2P access is not actually possible.
+    See
+    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
+    Therefore, we have to perform a real P2P access to check if it is actually
+    possible.
+
+    Note on p2p and cuda IPC:
+    Usually, one process uses one GPU:
+    GPU src --> cuda context src --> tensor src --> process src
+
+    We need to combine p2p and cuda IPC, so that:
+    GPU src --> cuda context src --> tensor src --> process src
+                                      |shared|
+    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
+    That is to say, process src creates a tensor in GPU src, passes IPC handle to
+    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
+    tensor in process tgt will be reflected in the tensor in process src, because
+    they are the same memory segment.
+    It is important to note that process tgt accesses the tensor in GPU tgt, not
+    GPU src. That's why we need p2p access.
+
+    The most time-consuming part is the process creation. To avoid creating
+    processes for every pair of GPUs, we use batched testing. We create two
+    processes for testing all pairs of GPUs in batch. The trick is to reset
+    the device after each test (which is not available in PyTorch).
+    """  # noqa
+    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+    # pass the CUDA_VISIBLE_DEVICES to the child process
+    # to make sure they see the same set of GPUs
+
+    # make sure the processes are spawned
+    smp = mp.get_context("spawn")
+    producer_queue = smp.Queue()
+    consumer_queue = smp.Queue()
+    result_queue = smp.Queue()
+    p_src = smp.Process(target=producer,
+                        args=(batch_src, producer_queue, consumer_queue,
+                              result_queue, cuda_visible_devices))
+    p_tgt = smp.Process(target=consumer,
+                        args=(batch_tgt, producer_queue, consumer_queue,
+                              result_queue, cuda_visible_devices))
+    p_src.start()
+    p_tgt.start()
+    p_src.join()
+    p_tgt.join()
+    result = []
+    for src, tgt in zip(batch_src, batch_tgt):
+        a = result_queue.get()
+        b = result_queue.get()
+        if a != b:
+            logger.warning("Two processes do not agree on the P2P access"
+                           f" status on {src} -> {tgt}, treat as disabled.")
+            result.append(False)
+        else:
+            result.append(a)
+    return result
+
+
+# Why do we need this cache?
+# We are testing peer-to-peer (p2p) access between GPUs, across processes.
+# Testing it on every startup would be very slow, because it requires creating
+#  N * N * 2 processes, where N is the world size.
+# To reduce the time, we use a cache file to store the p2p access status.
+# The cache file is generated by the master process if it does not exist;
+#  all processes then read the cache file to check the p2p access status.
+# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
+#  can have different cache files for different CUDA_VISIBLE_DEVICES settings,
+#  e.g. used by different aphrodite engines. The device id in the cache file is
+#  a **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
+#  of visible devices in the aphrodite engine.
+_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+
+
+def gpu_p2p_access_check(src: int, tgt: int) -> bool:
+    """Check if GPU src can access GPU tgt."""
+
+    # if the cache variable is already calculated,
+    # read from the cache instead of checking it again
+    global _gpu_p2p_access_cache
+    if _gpu_p2p_access_cache is not None:
+        return _gpu_p2p_access_cache[f"{src}->{tgt}"]
+
+    is_distributed = dist.is_initialized()
+
+    num_dev = cuda_device_count_stateless()
+    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+    if cuda_visible_devices is None:
+        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+    APHRODITE_CONFIG_ROOT = os.getenv("APHRODITE_CONFIG_ROOT", "~/.config")
+    path = os.path.expanduser(
+        f"{APHRODITE_CONFIG_ROOT}/aphrodite/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
+    )
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    from aphrodite.distributed.parallel_state import get_world_group
+    if ((not is_distributed or get_world_group().local_rank == 0)
+            and (not os.path.exists(path))):
+        # only the local master process (with local_rank == 0) can
+        #  enter this block to calculate the cache
+        logger.info(f"generating GPU P2P access cache in {path}")
+        cache = {}
+        ids = list(range(num_dev))
+        # batch of all pairs of GPUs
+        batch_src, batch_tgt = zip(*list(product(ids, ids)))
+        # NOTE: we use `subprocess` rather than `multiprocessing` here
+        # because the caller might not have an `if __name__ == "__main__":`
+        # guard, in which case the spawn method of multiprocessing cannot
+        # be used, while `can_actually_p2p` requires the spawn method.
+        # The fix is to call the function via `subprocess`, since this file
+        # does have an `if __name__ == "__main__":` guard.
+        input_bytes = pickle.dumps((batch_src, batch_tgt))
+        returned = subprocess.run([sys.executable, __file__],
+                                  input=input_bytes,
+                                  capture_output=True)
+        # check if the subprocess is successful
+        try:
+            returned.check_returncode()
+        except Exception as e:
+            # wrap raised exception to provide more information
+            raise RuntimeError(
+                f"Error happened when batch testing "
+                f"peer-to-peer access from {batch_src} to {batch_tgt}") from e
+        result = pickle.loads(returned.stdout)
+        for _i, _j, r in zip(batch_src, batch_tgt, result):
+            cache[f"{_i}->{_j}"] = r
+        with open(path, "w") as f:
+            json.dump(cache, f, indent=4)
+    if is_distributed:
+        get_world_group().barrier()
+    logger.debug(f"reading GPU P2P access cache from {path}")
+    with open(path, "r") as f:
+        cache = json.load(f)
+    _gpu_p2p_access_cache = cache
+    return _gpu_p2p_access_cache[f"{src}->{tgt}"]
+
+
+__all__ = ["gpu_p2p_access_check"]
+
+if __name__ == "__main__":
+    batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read())
+    result = can_actually_p2p(batch_src, batch_tgt)
+    sys.stdout.buffer.write(pickle.dumps(result))

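A sketch of how the cache-backed check is consumed and what the generated JSON roughly looks like; the path and device ids below are illustrative examples:

```python
from aphrodite.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)

# On the local master (local_rank == 0), the first call runs the batched
# producer/consumer test and writes something like
# ~/.config/aphrodite/gpu_p2p_access_cache_for_0,1.json:
# {
#     "0->0": true, "0->1": true,
#     "1->0": true, "1->1": true
# }
# Later calls (any rank, any engine restart) simply read the cached entry.
if gpu_p2p_access_check(src=0, tgt=1):
    print("GPU 0 -> GPU 1 peer access verified")
```
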
+ 145 - 246
aphrodite/distributed/device_communicators/pynccl.py

@@ -1,269 +1,168 @@
-# This file is a pure Python wrapper for the NCCL library.
-# The main purpose is to use NCCL combined with CUDA graph.
-# Before writing this script, we tried the following approach:
-# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
-#  often gets stuck when initializing the NCCL communicator.
-# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
-#  contains many other potential cuda APIs, that are not allowed during
-#  capturing the CUDA graph. For further details, please check
-# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
-#
-# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
-# doable, but we often encounter issues related with nccl versions, and need
-# to switch between different versions of NCCL. See
-# https://github.com/NVIDIA/nccl/issues/1234 for more details.
-# A C/C++ binding is not flexible enough to handle this. It requires
-# recompilation of the code every time we want to switch between different
-# versions. This current implementation, with a **pure** Python wrapper, is
-# more flexible. We can easily switch between different versions of NCCL by
-# changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the `so_file`
-# variable in the code.
-
-import ctypes
-import datetime
-import logging
-import os
+from contextlib import contextmanager
+from typing import Optional, Union
 
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
-
-logger = logging.getLogger(__name__)
-
-so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
-
-# manually load the nccl library
-if so_file:
-    logger.info(
-        f"Loading nccl from env variable APHRODITE_NCCL_SO_PATH={so_file}")
-else:
-    if torch.version.cuda is not None:
-        so_file = "libnccl.so.2"
-    elif torch.version.hip is not None:
-        so_file = "librccl.so.1"
-    else:
-        raise ValueError("NCCL only supports CUDA and ROCm backends.")
-    logger.debug(f"Loading nccl from library {so_file}")
-
-try:
-    nccl = ctypes.CDLL(so_file)
-except Exception as e:
-    logger.error(
-        f"Failed to load NCCL library from {so_file} ."
-        "It is expected if you are not running on NVIDIA/AMD GPUs."
-        "Otherwise please set the environment variable APHRODITE_NCCL_SO_PATH"
-        " to point to the correct nccl library path. You can install nccl"
-        " with `conda install nccl` or `pip install nvidia-nccl-cu12`")
-    raise e
-
-# === export types and functions from nccl to Python ===
-# for the original nccl definition, please check
-# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
-
-ncclResult_t = ctypes.c_int
-
-# equivalent to c declaration:
-# ncclResult_t  ncclGetVersion(int *version);
-_c_ncclGetVersion = nccl.ncclGetVersion
-_c_ncclGetVersion.restype = ctypes.c_int
-_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
-
-
-def ncclGetVersion() -> str:
-    version = ctypes.c_int()
-    result = _c_ncclGetVersion(ctypes.byref(version))
-    assert result == 0
-    # something like 21903 --> "2.19.3"
-    version_str = str(version.value)
-    major = version_str[0].lstrip("0")
-    minor = version_str[1:3].lstrip("0")
-    patch = version_str[3:].lstrip("0")
-    return f"{major}.{minor}.{patch}"
-
-
-class NcclUniqueId(ctypes.Structure):
-    _fields_ = [("internal", ctypes.c_byte * 128)]
-
-
-# equivalent to c declaration:
-# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
-_c_ncclGetUniqueId = nccl.ncclGetUniqueId
-_c_ncclGetUniqueId.restype = ctypes.c_int
-_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
-
-
-def ncclGetUniqueId() -> NcclUniqueId:
-    unique_id = NcclUniqueId()
-    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-    assert result == 0
-    return unique_id
-
-
-# equivalent to c declaration:
-# ncclResult_t  ncclCommInitRank(
-#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
-# note that ncclComm_t is a pointer type, so the first argument
-# is a pointer to a pointer
-_c_ncclCommInitRank = nccl.ncclCommInitRank
-_c_ncclCommInitRank.restype = ctypes.c_int
-_c_ncclCommInitRank.argtypes = [
-    ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
-]
+from loguru import logger
+from torch.distributed import ProcessGroup, ReduceOp
 
+from aphrodite.distributed.device_communicators.pynccl_wrapper import (
+    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
+    ncclRedOpTypeEnum, ncclUniqueId)
 
-# enums
-class ncclDataType_t(ctypes.c_int):
-    ncclInt8 = 0
-    ncclChar = 0
-    ncclUint8 = 1
-    ncclInt32 = 2
-    ncclInt = 2
-    ncclUint32 = 3
-    ncclInt64 = 4
-    ncclUint64 = 5
-    ncclFloat16 = 6
-    ncclHalf = 6
-    ncclFloat32 = 7
-    ncclFloat = 7
-    ncclFloat64 = 8
-    ncclDouble = 8
-    ncclBfloat16 = 9
-    ncclNumTypes = 10
 
-    @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
-        if dtype == torch.int8:
-            return cls.ncclInt8
-        if dtype == torch.uint8:
-            return cls.ncclUint8
-        if dtype == torch.int32:
-            return cls.ncclInt32
-        if dtype == torch.int64:
-            return cls.ncclInt64
-        if dtype == torch.float16:
-            return cls.ncclFloat16
-        if dtype == torch.float32:
-            return cls.ncclFloat32
-        if dtype == torch.float64:
-            return cls.ncclFloat64
-        if dtype == torch.bfloat16:
-            return cls.ncclBfloat16
-        raise ValueError(f"Unsupported dtype: {dtype}")
-
-
-class ncclRedOp_t(ctypes.c_int):
-    ncclSum = 0
-    ncclProd = 1
-    ncclMax = 2
-    ncclMin = 3
-    ncclAvg = 4
-    ncclNumOps = 5
-
-    @classmethod
-    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
-        if op == ReduceOp.SUM:
-            return cls.ncclSum
-        if op == ReduceOp.PRODUCT:
-            return cls.ncclProd
-        if op == ReduceOp.MAX:
-            return cls.ncclMax
-        if op == ReduceOp.MIN:
-            return cls.ncclMin
-        if op == ReduceOp.AVG:
-            return cls.ncclAvg
-        raise ValueError(f"Unsupported op: {op}")
-
-
-# equivalent to c declaration:
-# ncclResult_t  ncclAllReduce(
-#   const void* sendbuff, void* recvbuff, size_t count,
-#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
-#   udaStream_t stream);
-# note that cudaStream_t is a pointer type, so the last argument is a pointer
-_c_ncclAllReduce = nccl.ncclAllReduce
-_c_ncclAllReduce.restype = ctypes.c_int
-_c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
-    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
-]
-
-# equivalent to c declaration:
-# ncclResult_t  ncclCommDestroy(ncclComm_t comm);
-_c_ncclCommDestroy = nccl.ncclCommDestroy
-_c_ncclCommDestroy.restype = ctypes.c_int
-_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
-
-
-class NCCLCommunicator:
+class PyNcclCommunicator:
 
     def __init__(
         self,
-        backend=None,
-        init_method=None,
-        timeout=datetime.timedelta(seconds=10),
-        world_size: int = -1,
-        rank: int = -1,
-        store=None,
-        group_name: str = "",
-        pg_options=None,
-        local_rank: int = -1,
+        group: ProcessGroup,
+        device: Union[int, str, torch.device],
+        library_path: Optional[str] = None,
     ):
-        if not dist.is_initialized():
-            backend = backend or "nccl"
-            assert backend == 'nccl', (
-                "only use nccl backend for starting the NCCL communicator")
-            dist.init_process_group(backend=backend,
-                                    init_method=init_method,
-                                    timeout=timeout,
-                                    world_size=world_size,
-                                    rank=rank,
-                                    store=store,
-                                    group_name=group_name,
-                                    pg_options=pg_options)
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        if local_rank == -1:
-            local_rank = self.rank
-        self.local_rank = local_rank
-        # don't use these args, as they can be -1
-        # use `self.rank`, `self.local_rank` and `self.world_size` instead
-        del world_size, rank, local_rank
-        torch.cuda.set_device(self.local_rank)
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the PyNcclCommunicator to. If None,
+                it will be bound to f"cuda:{local_rank}".
+            library_path: the path to the NCCL library. If None, it will
+                use the default library path.
+        It is the caller's responsibility to make sure each communicator
+        is bound to a unique device.
+        """
+        assert dist.is_initialized()
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "PyNcclCommunicator should be attached to a non-NCCL group.")
+        self.group = group
+        # note: this rank is the rank in the group
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(group)
+
+        # if world_size == 1, no need to create communicator
+        if self.world_size == 1:
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+        try:
+            self.nccl = NCCLLibrary(library_path)
+        except Exception:
+            # disable because of missing NCCL library
+            # e.g. in a non-GPU environment
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+
+        self.available = True
+        self.disabled = False
+
+        logger.debug(f"Aphrodite is using nccl=={self.nccl.ncclGetVersion()}")
+
         if self.rank == 0:
-            self.unique_id = ncclGetUniqueId()
+            # get the unique id from NCCL
+            self.unique_id = self.nccl.ncclGetUniqueId()
         else:
-            self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.local_rank)
-        dist.broadcast(tensor, src=0)
-        byte_list = tensor.cpu().tolist()
+            # construct an empty unique id
+            self.unique_id = ncclUniqueId()
+        tensor = torch.ByteTensor(list(self.unique_id.internal))
+        ranks = dist.get_process_group_ranks(group)
+        # arg `src` in `broadcast` is the global rank
+        dist.broadcast(tensor, src=ranks[0], group=group)
+        byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
-        self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
-        assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        # nccl communicator and stream will use this device
+        # `torch.cuda.device` is a context manager that changes the
+        # current cuda device to the specified one
+        with torch.cuda.device(device):
+            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()
+
+            # A small all_reduce for warmup.
+            data = torch.zeros(1, device=device)
+            self.all_reduce(data)
+            self.stream.synchronize()
+            del data
+
+        # By default the communicator is disabled, e.g. when profiling models
+        # or during the prefill phase. To use it, wrap calls in
+        # `with obj.change_state(enable=True)`, usually when running under
+        # CUDA graphs.
+        self.disabled = True
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                  ctypes.c_void_p(tensor.data_ptr()),
-                                  tensor.numel(),
-                                  ncclDataType_t.from_torch(tensor.dtype),
-                                  ncclRedOp_t.from_torch(op), self.comm,
-                                  ctypes.c_void_p(stream.cuda_stream))
-        assert result == 0
+        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
+                                buffer_type(tensor.data_ptr()), tensor.numel(),
+                                ncclDataTypeEnum.from_torch(tensor.dtype),
+                                ncclRedOpTypeEnum.from_torch(op), self.comm,
+                                cudaStream_t(stream.cuda_stream))
+
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
+                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
+                           self.comm, cudaStream_t(stream.cuda_stream))
+
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
+                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
+                           self.comm, cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        if stream is None:
+            stream = self.stream
+
+        old_disable = self.disabled
+        old_stream = self.stream
+
+        self.stream = stream
+        self.disabled = not enable
+        yield
 
-    def __del__(self):
-        # `dist` module might have been already destroyed
-        if hasattr(dist, 'destroy_process_group'):
-            dist.destroy_process_group()
-        # function might have been already destroyed
-        if _c_ncclCommDestroy is not None:
-            _c_ncclCommDestroy(self.comm)
+        self.disabled = old_disable
+        self.stream = old_stream

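A sketch of the intended call pattern for the new `PyNcclCommunicator`; the gloo group creation and tensor setup below are simplified assumptions (in Aphrodite the group and device come from the distributed parallel-state setup), assuming one process per GPU:

```python
import torch
import torch.distributed as dist

from aphrodite.distributed.device_communicators.pynccl import PyNcclCommunicator

# torch.distributed is assumed to be initialized already (e.g. via torchrun).
group = dist.new_group(backend="gloo")   # must NOT be an NCCL group
rank = dist.get_rank(group)
comm = PyNcclCommunicator(group=group, device=rank)

t = torch.ones(16, device=f"cuda:{rank}")
comm.all_reduce(t)                       # no-op: disabled by default
with comm.change_state(enable=True,
                       stream=torch.cuda.current_stream()):
    comm.all_reduce(t)                   # in-place NCCL all-reduce
```
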
+ 0 - 68
aphrodite/distributed/device_communicators/pynccl_utils.py

@@ -1,68 +0,0 @@
-import contextlib
-from typing import Optional
-
-import torch
-from loguru import logger
-from torch.distributed import ReduceOp
-
-try:
-    from aphrodite.distributed.device_communicators.pynccl import (
-        NCCLCommunicator,
-        ncclGetVersion,
-    )
-except Exception as e:
-    # in non-NVIDIA environments, we can't import the nccl module
-    # e.g. when running on machines with AMD GPUs
-    logger.info(f"Failed to import NCCL library: {e}")
-    logger.info("It is expected if you are not running on NVIDIA GPUs.")
-    pass
-
-comm: Optional["NCCLCommunicator"] = None
-
-
-def is_initialized() -> bool:
-    """Returns whether the NCCL backend is initialized."""
-    return comm is not None
-
-
-@contextlib.contextmanager
-def set_pynccl_stream(stream: torch.cuda.Stream):
-    """Set the cuda stream for communication"""
-    try:
-        comm.stream = stream
-        yield
-    finally:
-        pass
-
-
-def init_process_group(world_size: int,
-                       rank: int,
-                       init_method: str,
-                       local_rank: int = -1) -> None:
-    assert not is_initialized()
-    global comm
-    logger.info(f"Aphrodite is using nccl=={ncclGetVersion()}")
-    comm = NCCLCommunicator(init_method=init_method,
-                            world_size=world_size,
-                            local_rank=local_rank,
-                            rank=rank)
-
-
-def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
-    """All-reduces the input tensor across the process group."""
-    assert input_.is_cuda, f"{input_} should be a cuda tensor"
-    comm.all_reduce(input_, op)
-
-
-def destroy_process_group() -> None:
-    global comm
-    comm = None
-
-
-def get_world_size() -> int:
-    """Returns the world size."""
-    return comm.world_size
-
-
-def get_nccl_backend():
-    return comm

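Since `pynccl_utils` is deleted wholesale, a rough before/after mapping of its module-level helpers onto the new `PyNcclCommunicator` may help when reviewing call sites; the updated call sites live elsewhere in this PR and are not shown in this hunk:

```python
# Rough mapping (illustrative, not an exhaustive API contract):
#   pynccl_utils.init_process_group(...)    -> comm = PyNcclCommunicator(group, device)
#   pynccl_utils.all_reduce(tensor, op)     -> comm.all_reduce(tensor, op)
#   pynccl_utils.set_pynccl_stream(stream)  -> with comm.change_state(stream=stream): ...
#   pynccl_utils.get_world_size()           -> comm.world_size
#   pynccl_utils.is_initialized()           -> comm.available
#   pynccl_utils.get_nccl_backend()         -> comm
```
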
+ 275 - 0
aphrodite/distributed/device_communicators/pynccl_wrapper.py

@@ -0,0 +1,275 @@
+# This file is a pure Python wrapper for the NCCL library.
+# The main purpose is to use NCCL combined with CUDA graph.
+# Before writing this script, we tried the following approach:
+# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
+#  often gets stuck when initializing the NCCL communicator.
+# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+#  contains many other potential cuda APIs that are not allowed while
+#  capturing the CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related to nccl versions and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this. It requires
+# recompilation of the code every time we want to switch between different
+# versions. This current implementation, with a **pure** Python wrapper, is
+# more flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+
+import ctypes
+import platform
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from loguru import logger
+from torch.distributed import ReduceOp
+
+from aphrodite.common.utils import find_nccl_library
+
+# === export types and functions from nccl to Python ===
+# for the original nccl definition, please check
+# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
+
+ncclResult_t = ctypes.c_int
+ncclComm_t = ctypes.c_void_p
+
+
+class ncclUniqueId(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+cudaStream_t = ctypes.c_void_p
+buffer_type = ctypes.c_void_p
+
+ncclDataType_t = ctypes.c_int
+
+
+class ncclDataTypeEnum:
+    ncclInt8 = 0
+    ncclChar = 0
+    ncclUint8 = 1
+    ncclInt32 = 2
+    ncclInt = 2
+    ncclUint32 = 3
+    ncclInt64 = 4
+    ncclUint64 = 5
+    ncclFloat16 = 6
+    ncclHalf = 6
+    ncclFloat32 = 7
+    ncclFloat = 7
+    ncclFloat64 = 8
+    ncclDouble = 8
+    ncclBfloat16 = 9
+    ncclNumTypes = 10
+
+    @classmethod
+    def from_torch(cls, dtype: torch.dtype) -> int:
+        if dtype == torch.int8:
+            return cls.ncclInt8
+        if dtype == torch.uint8:
+            return cls.ncclUint8
+        if dtype == torch.int32:
+            return cls.ncclInt32
+        if dtype == torch.int64:
+            return cls.ncclInt64
+        if dtype == torch.float16:
+            return cls.ncclFloat16
+        if dtype == torch.float32:
+            return cls.ncclFloat32
+        if dtype == torch.float64:
+            return cls.ncclFloat64
+        if dtype == torch.bfloat16:
+            return cls.ncclBfloat16
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
+    ncclSum = 0
+    ncclProd = 1
+    ncclMax = 2
+    ncclMin = 3
+    ncclAvg = 4
+    ncclNumOps = 5
+
+    @classmethod
+    def from_torch(cls, op: ReduceOp) -> int:
+        if op == ReduceOp.SUM:
+            return cls.ncclSum
+        if op == ReduceOp.PRODUCT:
+            return cls.ncclProd
+        if op == ReduceOp.MAX:
+            return cls.ncclMax
+        if op == ReduceOp.MIN:
+            return cls.ncclMin
+        if op == ReduceOp.AVG:
+            return cls.ncclAvg
+        raise ValueError(f"Unsupported op: {op}")
+
+
+@dataclass
+class Function:
+    name: str
+    restype: Any
+    argtypes: List[Any]
+
+
+class NCCLLibrary:
+    exported_functions = [
+        # const char* ncclGetErrorString(ncclResult_t result)
+        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
+        # ncclResult_t  ncclGetVersion(int *version);
+        Function("ncclGetVersion", ncclResult_t,
+                 [ctypes.POINTER(ctypes.c_int)]),
+        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
+        Function("ncclGetUniqueId", ncclResult_t,
+                 [ctypes.POINTER(ncclUniqueId)]),
+        # ncclResult_t  ncclCommInitRank(
+        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
+        # note that ncclComm_t is a pointer type, so the first argument
+        # is a pointer to a pointer
+        Function("ncclCommInitRank", ncclResult_t, [
+            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
+            ctypes.c_int
+        ]),
+        # ncclResult_t  ncclAllReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+        #   cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclAllReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ncclComm_t, cudaStream_t
+        ]),
+
+        # ncclResult_t  ncclSend(
+        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
+        #   int dest, ncclComm_t comm, cudaStream_t stream);
+        Function("ncclSend", ncclResult_t, [
+            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
+            ncclComm_t, cudaStream_t
+        ]),
+
+        # ncclResult_t  ncclRecv(
+        #   void* recvbuff, size_t count, ncclDataType_t datatype,
+        #   int src, ncclComm_t comm, cudaStream_t stream);
+        Function("ncclRecv", ncclResult_t, [
+            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
+            ncclComm_t, cudaStream_t
+        ]),
+
+        # Be cautious! This is a collective call; it will block until all
+        # processes in the communicator have called this function.
+        # Because Python object destruction can happen in random order,
+        # it is better not to call it at all.
+        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
+        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    #  to the corresponding dictionary of bound functions
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+
+        so_file = so_file or find_nccl_library()
+
+        try:
+            if so_file not in NCCLLibrary.path_to_dict_mapping:
+                lib = ctypes.CDLL(so_file)
+                NCCLLibrary.path_to_library_cache[so_file] = lib
+            self.lib = NCCLLibrary.path_to_library_cache[so_file]
+        except Exception as e:
+            logger.error(
+                f"Failed to load NCCL library from {so_file} ."
+                "It is expected if you are not running on NVIDIA/AMD GPUs."
+                "Otherwise, the nccl library might not exist, be corrupted "
+                "or it does not support the current platform "
+                f"{platform.platform()}. If you already have the library, "
+                "please set the environment variable APHRODITE_NCCL_SO_PATH"
+                " to point to the correct nccl library path.")
+            raise e
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            _funcs = {}
+            for func in NCCLLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
+        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
+
+    def ncclGetErrorString(self, result: ncclResult_t) -> str:
+        return self._funcs["ncclGetErrorString"](result).decode("utf-8")
+
+    def NCCL_CHECK(self, result: ncclResult_t) -> None:
+        if result != 0:
+            error_str = self.ncclGetErrorString(result)
+            raise RuntimeError(f"NCCL error: {error_str}")
+
+    def ncclGetVersion(self) -> str:
+        version = ctypes.c_int()
+        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
+        version_str = str(version.value)
+        # something like 21903 --> "2.19.3"
+        major = version_str[0].lstrip("0")
+        minor = version_str[1:3].lstrip("0")
+        patch = version_str[3:].lstrip("0")
+        return f"{major}.{minor}.{patch}"
+
+    def ncclGetUniqueId(self) -> ncclUniqueId:
+        unique_id = ncclUniqueId()
+        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
+            ctypes.byref(unique_id)))
+        return unique_id
+
+    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
+                         rank: int) -> ncclComm_t:
+        comm = ncclComm_t()
+        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
+                                                        world_size, unique_id,
+                                                        rank))
+        return comm
+
+    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                      count: int, datatype: int, op: int, comm: ncclComm_t,
+                      stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
+                                                     datatype, op, comm,
+                                                     stream))
+
+    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
+                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
+                                                dest, comm, stream))
+
+    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
+                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
+                                                comm, stream))
+
+    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
+
+
+__all__ = [
+    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
+    "ncclComm_t", "cudaStream_t", "buffer_type"
+]

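A quick smoke test for the ctypes wrapper above, assuming an NCCL (or RCCL) shared library can be found via `find_nccl_library()` or `APHRODITE_NCCL_SO_PATH`:

```python
from aphrodite.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary)

lib = NCCLLibrary()                   # loads e.g. libnccl.so.2 / librccl.so.1
print("nccl version:", lib.ncclGetVersion())   # e.g. "2.20.5"
uid = lib.ncclGetUniqueId()           # 128-byte id to broadcast to other ranks
# ncclCommInitRank(world_size, uid, rank) would then create a communicator;
# it is a collective call across ranks, so it is omitted from this sketch.
```
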
+ 478 - 0
aphrodite/distributed/device_communicators/shm_broadcast.py

@@ -0,0 +1,478 @@
+import os
+import pickle
+import time
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from multiprocessing import shared_memory
+from typing import List, Optional
+from unittest.mock import patch
+
+import torch
+import torch.distributed as dist
+from loguru import logger
+from torch.distributed import ProcessGroup
+from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
+
+from aphrodite.common.utils import get_ip, get_open_port
+
+# os.getenv returns a string when the variable is set, so coerce to int
+APHRODITE_RINGBUFFER_WARNING_INTERVAL = int(
+    os.getenv("APHRODITE_RINGBUFFER_WARNING_INTERVAL", 60))
+
+# Time to wait if the queue is full or empty:
+# sleeping for too short a time consumes too much CPU,
+# while sleeping for too long slows down the writer/reader.
+# 0.1 us is a good balance.
+RINGBUFFER_SLEEP_INTERVAL = 1e-7
+
+
+class ShmRingBuffer:
+
+    def __init__(self,
+                 n_reader: int,
+                 max_chunk_bytes: int,
+                 max_chunks: int,
+                 name: Optional[str] = None):
+        """
+        A shared memory ring buffer implementation for broadcast communication.
+        Essentially, it is a queue where only one process will `enqueue` and
+        multiple processes will `dequeue`. The max size of each item and the
+        max number of items that can be stored in the buffer are known in
+        advance. In this case, we don't need to synchronize access to the
+        buffer.
+        
+        Buffer memory layout:
+                  data                                 metadata
+                    |                                      |
+                    | (current_idx)                        | (current_idx)
+                    v                                      v
+        +-------------------------------+----------------------------------------+
+        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
+        +-------------------------------+----------------------------------------+
+        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |
+
+        metadata memory layout: each byte is a flag, the first byte is the written
+        flag, and the rest are reader flags. The flags are set to 0 by default.
+        +--------------+--------------+--------------+-----+--------------+
+        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+        +--------------+--------------+--------------+-----+--------------+
+
+        The state of metadata is as follows:
+
+        (case 1) 0???...???: the block is not written yet, cannot read, can write
+        (case 2) 1000...000: the block is just written, can read, cannot write
+        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
+        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write
+
+        State transition for readers:
+
+        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
+        Only after the caller finishes reading the block, the reader can mark the block as read.
+        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
+
+        State transition for writer:
+
+        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
+        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
+        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
+        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
+
+        During creation, `name` is None and the buffer is created. We can pass the
+        created object to other processes by pickling it. The other processes will
+        get the name of the shared memory and open it, so that they can access the
+        same shared memory buffer.
+        """# noqa
+        self.n_reader = n_reader
+        self.metadata_size = 1 + n_reader
+        self.max_chunk_bytes = max_chunk_bytes
+        self.max_chunks = max_chunks
+        self.total_bytes_of_buffer = (self.max_chunk_bytes +
+                                      self.metadata_size) * self.max_chunks
+        self.data_offset = 0
+        self.metadata_offset = self.max_chunk_bytes * self.max_chunks
+
+        if name is None:
+            # we are creating a buffer
+            self.is_creator = True
+            self.shared_memory = shared_memory.SharedMemory(
+                create=True, size=self.total_bytes_of_buffer)
+            # initialize the metadata section to 0
+            with memoryview(self.shared_memory.buf[self.metadata_offset:]
+                            ) as metadata_buffer:
+                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
+        else:
+            # we are opening an existing buffer
+            self.is_creator = False
+            # fix to https://stackoverflow.com/q/62748654/9191338
+            # Python incorrectly tracks shared memory even if it is not
+            # created by the process. The following patch is a workaround.
+            with patch("multiprocessing.resource_tracker.register",
+                       lambda *args, **kwargs: None):
+                try:
+                    self.shared_memory = shared_memory.SharedMemory(name=name)
+                    assert self.shared_memory.size == self.total_bytes_of_buffer  # noqa
+                except FileNotFoundError:
+                    # we might deserialize the object in a different node
+                    # in this case, this object is not used,
+                    # and we should suppress the error
+                    pass
+
+    def __reduce__(self):
+        return (
+            self.__class__,
+            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
+             self.shared_memory.name),
+        )
+
+    def __del__(self):
+        if hasattr(self, "shared_memory"):
+            self.shared_memory.close()
+            if self.is_creator:
+                self.shared_memory.unlink()
+
+    @contextmanager
+    def get_data(self, current_idx: int):
+        start = self.data_offset + current_idx * self.max_chunk_bytes
+        end = start + self.max_chunk_bytes
+        with memoryview(self.shared_memory.buf[start:end]) as buf:
+            yield buf
+
+    @contextmanager
+    def get_metadata(self, current_idx: int):
+        start = self.metadata_offset + current_idx * self.metadata_size
+        end = start + self.metadata_size
+        with memoryview(self.shared_memory.buf[start:end]) as buf:
+            yield buf
+
+
+@dataclass
+class Handle:
+    connect_ip: str
+    local_reader_ranks: List[int] = field(default_factory=list)
+
+    buffer: Optional[ShmRingBuffer] = None
+    local_subscribe_port: Optional[int] = None
+    remote_subscribe_port: Optional[int] = None
+
+
+class MessageQueue:
+
+    def __init__(
+        self,
+        n_reader,  # number of all readers
+        n_local_reader,  # number of local readers through shared memory
+        local_reader_ranks: Optional[List[int]] = None,
+        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunks: int = 10,
+        connect_ip: Optional[str] = None,
+    ):
+        if local_reader_ranks is None:
+            local_reader_ranks = list(range(n_local_reader))
+        else:
+            assert len(local_reader_ranks) == n_local_reader
+        self.n_local_reader = n_local_reader
+        n_remote_reader = n_reader - n_local_reader
+        self.n_remote_reader = n_remote_reader
+
+        if connect_ip is None:
+            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
+
+        context = Context()
+
+        if n_local_reader > 0:
+            # for local readers, we will:
+            # 1. create a shared memory ring buffer to communicate small data
+            # 2. create a publish-subscribe socket to communicate large data
+            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
+                                        max_chunks)
+
+            # XPUB is very similar to PUB,
+            # except that it can receive subscription messages
+            # to confirm the number of subscribers
+            self.local_socket = context.socket(XPUB)
+            # set the verbose option so that we can receive every subscription
+            # message. otherwise, we will only receive the first subscription
+            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
+            self.local_socket.setsockopt(XPUB_VERBOSE, True)
+            local_subscribe_port = get_open_port()
+            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
+
+            self.current_idx = 0
+
+        else:
+            self.buffer = None  # type: ignore
+            local_subscribe_port = None
+            self.local_socket = None
+            self.current_idx = -1
+
+        if n_remote_reader > 0:
+            # for remote readers, we will:
+            # create a publish-subscribe socket to communicate large data
+            self.remote_socket = context.socket(XPUB)
+            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
+            remote_subscribe_port = get_open_port()
+            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
+
+        else:
+            remote_subscribe_port = None
+            self.remote_socket = None
+
+        self._is_writer = True
+        self._is_local_reader = False
+        self.local_reader_rank = -1
+        # rank does not matter for remote readers
+        self._is_remote_reader = False
+
+        self.handle = Handle(
+            connect_ip=connect_ip,
+            local_reader_ranks=local_reader_ranks,
+            buffer=self.buffer,
+            local_subscribe_port=local_subscribe_port,
+            remote_subscribe_port=remote_subscribe_port,
+        )
+
+        logger.debug("Aphrodite message queue communication handle: "
+                     f"{self.handle}")
+
+    def export_handle(self) -> Handle:
+        return self.handle
+
+    @staticmethod
+    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
+        self = MessageQueue.__new__(MessageQueue)
+        self.handle = handle
+        self._is_writer = False
+
+        context = Context()
+
+        if rank in handle.local_reader_ranks:
+            assert handle.buffer is not None
+            self.buffer = handle.buffer
+            self.current_idx = 0
+            self.local_reader_rank = handle.local_reader_ranks.index(rank)
+            self._is_local_reader = True
+            self._is_remote_reader = False
+
+            self.local_socket = context.socket(SUB)
+            self.local_socket.setsockopt_string(SUBSCRIBE, "")
+            self.local_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
+
+            self.remote_socket = None
+        else:
+            self.buffer = None  # type: ignore
+            self.current_idx = -1
+            self.local_reader_rank = -1
+            self._is_local_reader = False
+            self._is_remote_reader = True
+
+            self.local_socket = None
+
+            self.remote_socket = context.socket(SUB)
+            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            self.remote_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
+
+        return self
+
+    def wait_until_ready(self):
+        """This is a collective operation. All processes (including the
+        readers and the writer) should call this function.
+        """
+        if self._is_writer:
+            # wait for all readers to connect
+
+            # local readers
+            for i in range(self.n_local_reader):
+                # wait for subscription messages from all local readers
+                self.local_socket.recv()
+            if self.n_local_reader > 0:
+                # send a message to all local readers
+                # to make sure the publish channel is working
+                self.local_socket.send(b"READY")
+
+            # remote readers
+            for i in range(self.n_remote_reader):
+                # wait for subscription messages from all remote readers
+                self.remote_socket.recv()
+            if self.n_remote_reader > 0:
+                # send a message to all remote readers
+                # to make sure the publish channel is working
+                self.remote_socket.send(b"READY")
+        elif self._is_local_reader:
+            # wait for the writer to send a message
+            recv = self.local_socket.recv()
+            assert recv == b"READY"
+        elif self._is_remote_reader:
+            # wait for the writer to send a message
+            recv = self.remote_socket.recv()
+            assert recv == b"READY"
+
+    @contextmanager
+    def acquire_write(self):
+        assert self._is_writer, "Only writers can acquire write"
+        start_time = time.monotonic()
+        n_warning = 1
+        while True:
+            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
+                read_count = sum(metadata_buffer[1:])
+                written_flag = metadata_buffer[0]
+                if written_flag and read_count != self.buffer.n_reader:
+                    # this block is written and not read by all readers
+                    # for writers, `self.current_idx` is the next block to write
+                    # if this block is not ready to write,
+                    # we need to wait until it is read by all readers
+
+                    # wait for a while
+                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                    # if we wait for a long time, we should warn the user
+                    if time.monotonic(
+                    ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning:  # type: ignore # noqa
+                        logger.warning(
+                            "No available block found in "
+                            f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds")
+                        n_warning += 1
+                    continue
+                # found a block that is either
+                # (1) not written
+                # (2) read by all readers
+
+                # mark the block as not written
+                metadata_buffer[0] = 0
+                # let caller write to the buffer
+                with self.buffer.get_data(self.current_idx) as buf:
+                    yield buf
+
+                # caller has written to the buffer
+                # NOTE: order is important here
+                # first set the read flags to 0
+                # then set the written flag to 1
+                # otherwise, the readers may think they already read the block
+                for i in range(1, self.buffer.n_reader + 1):
+                    # set read flag to 0, meaning it is not read yet
+                    metadata_buffer[i] = 0
+                # mark the block as written
+                metadata_buffer[0] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
+                break
+
+    @contextmanager
+    def acquire_read(self):
+        assert self._is_local_reader, "Only readers can acquire read"
+        start_time = time.monotonic()
+        n_warning = 1
+        while True:
+            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
+                read_flag = metadata_buffer[self.local_reader_rank + 1]
+                written_flag = metadata_buffer[0]
+                if not written_flag or read_flag:
+                    # this block is either
+                    # (1) not written
+                    # (2) already read by this reader
+                    # for readers, `self.current_idx` is the next block to read
+                    # if this block is not ready,
+                    # we need to wait until it is written
+
+                    # wait for a while
+                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                    # if we wait for a long time, we should warn the user
+                    if time.monotonic(
+                    ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning:  # type: ignore # noqa
+                        logger.warning(
+                            "No available block found in "
+                            f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds."
+                        )
+                        n_warning += 1
+                    continue
+                # found a block that is not read by this reader
+                # let caller read from the buffer
+                with self.buffer.get_data(self.current_idx) as buf:
+                    yield buf
+
+                # caller has read from the buffer
+                # set the read flag
+                metadata_buffer[self.local_reader_rank + 1] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
+                break
+
+    def enqueue(self, obj):
+        assert self._is_writer, "Only writers can enqueue"
+        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+        if self.n_local_reader > 0:
+            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+                with self.acquire_write() as buf:
+                    buf[0] = 1  # overflow
+                self.local_socket.send(serialized_obj)
+            else:
+                with self.acquire_write() as buf:
+                    buf[0] = 0  # not overflow
+                    buf[1:len(serialized_obj) + 1] = serialized_obj
+        if self.n_remote_reader > 0:
+            self.remote_socket.send(serialized_obj)
+
+    def dequeue(self):
+        if self._is_local_reader:
+            with self.acquire_read() as buf:
+                overflow = buf[0] == 1
+                if not overflow:
+                    # no need to know the size of serialized object
+                    # pickle format contains the size information internally
+                    # see https://docs.python.org/3/library/pickle.html
+                    obj = pickle.loads(buf[1:])
+            if overflow:
+                recv = self.local_socket.recv()
+                obj = pickle.loads(recv)
+        elif self._is_remote_reader:
+            recv = self.remote_socket.recv()
+            obj = pickle.loads(recv)
+        else:
+            raise RuntimeError("Only readers can dequeue")
+        return obj
+
+    def broadcast_object(self, obj=None):
+        if self._is_writer:
+            self.enqueue(obj)
+            return obj
+        else:
+            return self.dequeue()
+
+    @staticmethod
+    def create_from_process_group(pg: ProcessGroup,
+                                  max_chunk_bytes,
+                                  max_chunks,
+                                  writer_rank=0) -> "MessageQueue":
+        group_rank = dist.get_rank(pg)
+        group_world_size = dist.get_world_size(pg)
+        global_ranks = dist.get_process_group_ranks(pg)
+
+        from aphrodite.distributed.parallel_state import in_the_same_node_as
+        status = in_the_same_node_as(pg, source_rank=writer_rank)
+        same_node_ranks = [i for i, s in enumerate(status) if s]
+        n_reader = group_world_size - 1
+        n_local_reader = len(same_node_ranks) - 1
+        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        buffer_io: MessageQueue
+        if group_rank == writer_rank:
+            buffer_io = MessageQueue(
+                n_reader=n_reader,
+                n_local_reader=n_local_reader,
+                local_reader_ranks=local_reader_ranks,
+                max_chunk_bytes=max_chunk_bytes,
+                max_chunks=max_chunks,
+            )
+            handle = buffer_io.export_handle()
+            dist.broadcast_object_list([handle],
+                                       src=global_ranks[writer_rank],
+                                       group=pg)
+        else:
+            recv = [None]
+            dist.broadcast_object_list(recv,
+                                       src=global_ranks[writer_rank],
+                                       group=pg)
+            handle = recv[0]  # type: ignore
+            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
+        buffer_io.wait_until_ready()
+        return buffer_io
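For orientation, here is a minimal single-process sketch (not part of the diff) of the writer/reader API added above. In Aphrodite the writer and readers live in separate processes and the handle is shipped through `create_from_process_group`; the toy values below are purely illustrative.

```python
# Illustrative sketch of the MessageQueue call sequence; assumes the import
# path used later in parallel_state.py.
from aphrodite.distributed.device_communicators.shm_broadcast import MessageQueue

writer = MessageQueue(n_reader=1, n_local_reader=1)       # binds XPUB + shm ring buffer
handle = writer.export_handle()                           # normally broadcast to readers

reader = MessageQueue.create_from_handle(handle, rank=0)  # local reader: shm + SUB socket

# collective: the writer waits for every subscription, readers wait for b"READY"
writer.wait_until_ready()
reader.wait_until_ready()

writer.enqueue({"step": 1})                  # small payload -> shared-memory ring buffer
assert reader.dequeue() == {"step": 1}

big = b"x" * (20 * 1024 * 1024)              # > max_chunk_bytes -> overflow path (XPUB/SUB)
writer.enqueue(big)
assert reader.dequeue() == big
```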

+ 31 - 0
aphrodite/distributed/device_communicators/tpu_communicator.py

@@ -0,0 +1,31 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from aphrodite.platforms import current_platform
+
+if current_platform.is_tpu():
+    import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
+    from torch_xla._internal import pjrt
+
+
+class TpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not current_platform.is_tpu():
+            self.disabled = True
+            return
+        self.disabled = False
+
+        local_rank = dist.get_rank(group)
+        world_size = dist.get_world_size(group)
+        pjrt.initialize_multiprocess(local_rank, world_size)
+        xr._init_world_size_ordinal()
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, x)
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        assert dim == -1, "TPUs only support dim=-1 for all-gather."
+        return xm.all_gather(x, dim=dim)
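As a hedged illustration (not from the diff): the communicator is safe to construct on any host, because it marks itself `disabled` off-TPU and callers such as `GroupCoordinator.all_reduce` below check that flag before dispatching. The gloo process group here exists only for the sketch.

```python
# Illustrative sketch: construct the communicator on a non-TPU host and
# observe that it disables itself instead of failing.
import torch
import torch.distributed as dist
from aphrodite.distributed.device_communicators.tpu_communicator import TpuCommunicator

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        world_size=1, rank=0)
comm = TpuCommunicator(group=dist.group.WORLD)

x = torch.ones(4)
if not comm.disabled:        # only taken on an actual TPU host
    x = comm.all_reduce(x)
print("TPU communicator disabled:", comm.disabled)
```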

+ 993 - 173
aphrodite/distributed/parallel_state.py

@@ -3,48 +3,834 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-"""Tensor and pipeline parallel groups."""
+"""Aphrodite distributed state.
+It takes over control of the distributed environment from PyTorch.
+The typical workflow is:
+
+- call `init_distributed_environment` to initialize the distributed environment.
+- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to 
+ initialize the model parallel groups.
+
+- run any code that uses the distributed environment
+
+- call `destroy_model_parallel` to destroy the model parallel groups.
+- call `destroy_distributed_environment` to destroy the distributed environment.
+
+If you only need to use the distributed environment without model/pipeline
+ parallelism, you can skip the model parallel initialization and destruction
+ steps.
+"""
 import contextlib
-from typing import Optional
+import os
+import pickle
+from collections import namedtuple
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from multiprocessing import shared_memory
+from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import patch
 
 import torch
+import torch.distributed
 from loguru import logger
+from torch.distributed import Backend, ProcessGroup
+
+
+@dataclass
+class GraphCaptureContext:
+    stream: torch.cuda.Stream
+
+
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+
+def _split_tensor_dict(
+    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
+) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    """Split the tensor dictionary into two parts:
+    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
+         by its metadata.
+    2. A list of tensors.
+    """
+    metadata_list: List[Tuple[str, Any]] = []
+    tensor_list: List[torch.Tensor] = []
+    for key, value in tensor_dict.items():
+        if isinstance(value, torch.Tensor):
+            # Note: we cannot use `value.device` here,
+            # because it contains not only the device type but also the device
+            # index (e.g. "cuda:0"). We only need the device type.
+            # receiving side will set the device index.
+            device = value.device.type
+            metadata_list.append(
+                (key, TensorMetadata(device, value.dtype, value.size())))
+            tensor_list.append(value)
+        else:
+            metadata_list.append((key, value))
+    return metadata_list, tensor_list
+
+
+class GroupCoordinator:
+    """
+    PyTorch ProcessGroup wrapper for a group of processes.
+    PyTorch ProcessGroup is bound to one specific communication backend,
+        e.g. NCCL, Gloo, MPI, etc.
+    GroupCoordinator takes charge of all the communication operations among
+        the processes in the group. It can route the communication to
+        a specific implementation (e.g. switch allreduce implementation
+        based on the tensor size and cuda graph mode).
+    """
+
+    # available attributes:
+    rank: int  # global rank
+    ranks: List[int]  # global ranks in the group
+    world_size: int  # size of the group
+    # difference between `local_rank` and `rank_in_group`:
+    # if we have a group of size 4 across two nodes:
+    # Process | Node | Rank | Local Rank | Rank in Group
+    #   0     |   0  |  0   |     0      |       0
+    #   1     |   0  |  1   |     1      |       1
+    #   2     |   1  |  2   |     0      |       2
+    #   3     |   1  |  3   |     1      |       3
+    local_rank: int  # local rank used to assign devices
+    rank_in_group: int  # rank inside the group
+    cpu_group: ProcessGroup  # group for CPU communication
+    device_group: ProcessGroup  # group for device communication
+    use_pynccl: bool  # a hint of whether to use PyNccl
+    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    # communicators are only created for world size > 1
+    pynccl_comm: Optional[Any]  # PyNccl communicator
+    ca_comm: Optional[Any]  # Custom allreduce communicator
+    mq_broadcaster: Optional[Any]  # shared memory broadcaster
+
+    def __init__(
+        self,
+        group_ranks: List[List[int]],
+        local_rank: int,
+        torch_distributed_backend: Union[str, Backend],
+        use_pynccl: bool,
+        use_custom_allreduce: bool,
+        use_tpu_communicator: bool,
+        use_message_queue_broadcaster: bool = False,
+    ):
+
+        self.rank = torch.distributed.get_rank()
+        self.local_rank = local_rank
+        self.device_group = None
+        self.cpu_group = None
+
+        for ranks in group_ranks:
+            device_group = torch.distributed.new_group(
+                ranks, backend=torch_distributed_backend)
+            # a group with `gloo` backend, to allow direct coordination between
+            # processes through the CPU.
+            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+            if self.rank in ranks:
+                self.ranks = ranks
+                self.world_size = len(ranks)
+                self.rank_in_group = ranks.index(self.rank)
+                self.device_group = device_group
+                self.cpu_group = cpu_group
+
+        assert self.cpu_group is not None
+        assert self.device_group is not None
+
+        if torch.cuda.is_available():
+            self.device = torch.device(f"cuda:{local_rank}")
+        else:
+            self.device = torch.device("cpu")
+
+        self.use_pynccl = use_pynccl
+        self.use_custom_allreduce = use_custom_allreduce
+        self.use_tpu_communicator = use_tpu_communicator
+
+        # lazy import to avoid documentation build error
+        from aphrodite.distributed.device_communicators.custom_all_reduce import (  # noqa: E501
+            CustomAllreduce)
+        from aphrodite.distributed.device_communicators.pynccl import (
+            PyNcclCommunicator)
+
+        self.pynccl_comm: Optional[PyNcclCommunicator]
+        if use_pynccl and self.world_size > 1:
+            self.pynccl_comm = PyNcclCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+        else:
+            self.pynccl_comm = None
+
+        self.ca_comm: Optional[CustomAllreduce]
+        if use_custom_allreduce and self.world_size > 1:
+            # Initialize a custom fast all-reduce implementation.
+            self.ca_comm = CustomAllreduce(
+                group=self.cpu_group,
+                device=self.device,
+            )
+        else:
+            self.ca_comm = None
+
+        from aphrodite.distributed.device_communicators.tpu_communicator import (  # noqa: E501
+            TpuCommunicator)
+        self.tpu_communicator: Optional[TpuCommunicator]
+        if use_tpu_communicator and self.world_size > 1:
+            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
+        else:
+            self.tpu_communicator = None
+
+        from aphrodite.distributed.device_communicators.shm_broadcast import (
+            MessageQueue)
+        self.mq_broadcaster: Optional[MessageQueue] = None
+        if use_message_queue_broadcaster and self.world_size > 1:
+            self.mq_broadcaster = MessageQueue.create_from_process_group(
+                self.cpu_group, 1 << 22, 6)
+
+    @property
+    def first_rank(self):
+        """Return the global rank of the first process in the group"""
+        return self.ranks[0]
+
+    @property
+    def last_rank(self):
+        """Return the global rank of the last process in the group"""
+        return self.ranks[-1]
+
+    @property
+    def is_first_rank(self):
+        """Return whether the caller is the first process in the group"""
+        return self.rank == self.first_rank
+
+    @property
+    def is_last_rank(self):
+        """Return whether the caller is the last process in the group"""
+        return self.rank == self.last_rank
+
+    @property
+    def next_rank(self):
+        """Return the global rank of the process that follows the caller"""
+        rank_in_group = self.rank_in_group
+        world_size = self.world_size
+        return self.ranks[(rank_in_group + 1) % world_size]
+
+    @property
+    def prev_rank(self):
+        """Return the global rank of the process that precedes the caller"""
+        rank_in_group = self.rank_in_group
+        world_size = self.world_size
+        return self.ranks[(rank_in_group - 1) % world_size]
+
+    @contextmanager
+    def graph_capture(
+            self, graph_capture_context: Optional[GraphCaptureContext] = None):
+        if graph_capture_context is None:
+            stream = torch.cuda.Stream()
+            graph_capture_context = GraphCaptureContext(stream)
+        else:
+            stream = graph_capture_context.stream
+
+        ca_comm = self.ca_comm
+        maybe_ca_context = nullcontext(
+        ) if ca_comm is None else ca_comm.capture()
+
+        # ensure all initialization operations complete before attempting to
+        # capture the graph on another stream
+        curr_stream = torch.cuda.current_stream()
+        if curr_stream != stream:
+            stream.wait_stream(curr_stream)
+
+        with torch.cuda.stream(stream), maybe_ca_context:
+            # In graph mode, we have to be very careful about the collective
+            # operations. The current status is:
+            #     allreduce \ Mode   |  Eager  |  Graph  |
+            # --------------------------------------------
+            # custom allreduce       | enabled | enabled |
+            # PyNccl                 | disabled| enabled |
+            # torch.distributed      | enabled | disabled|
+            #
+            # Note that custom allreduce will have a runtime check, if the
+            #  tensor size is too large, it will fallback to the next
+            #  available option.
+            # In summary: When using CUDA graph, we use
+            #  either custom all-reduce kernel or pynccl. When not using
+            #  CUDA graph, we use either custom all-reduce kernel or
+            #  PyTorch NCCL. We always prioritize using custom all-reduce
+            #  kernel but fall back to PyTorch or pynccl if it is
+            #  disabled or not supported.
+            pynccl_comm = self.pynccl_comm
+            maybe_pynccl_context: Any
+            if not pynccl_comm:
+                maybe_pynccl_context = nullcontext()
+            else:
+                maybe_pynccl_context = pynccl_comm.change_state(
+                    enable=True, stream=torch.cuda.current_stream())
+            with maybe_pynccl_context:
+                yield graph_capture_context
+
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        NOTE: This operation will be applied in-place or out-of-place. 
+        Always assume this function modifies its input, but use the return
+        value as the output.
+        """
+        ca_comm = self.ca_comm
+
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_reduce(input_)
+
+        if ca_comm is not None:
+            out = ca_comm.custom_all_reduce(input_)
+            if out is not None:
+                return out
+        pynccl_comm = self.pynccl_comm
+        if (pynccl_comm is not None and not pynccl_comm.disabled):
+            pynccl_comm.all_reduce(input_)
+        elif input_.is_cpu:
+            import intel_extension_for_pytorch as ipex
+            ipex.distributed.all_reduce(input_, group=self.device_group)
+        else:
+            torch.distributed.all_reduce(input_, group=self.device_group)
+        return input_
+
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        # Bypass the function if we are using only 1 GPU.
+        if world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_gather(input_, dim)
+
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty((world_size, ) + input_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        torch.distributed.all_gather_into_tensor(output_tensor,
+                                                 input_,
+                                                 group=self.device_group)
+        # Reshape
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(input_size[:dim] +
+                                              (world_size *
+                                               input_size[dim], ) +
+                                              input_size[dim + 1:])
+        return output_tensor
+
+    def gather(self,
+               input_: torch.Tensor,
+               dst: int = 0,
+               dim: int = -1) -> torch.Tensor:
+        """
+        NOTE: We assume that the input tensor is on the same device across
+        all the ranks.
+        NOTE: `dst` is the local rank of the destination rank.
+        """
+        world_size = self.world_size
+        # Bypass the function if we are using only 1 GPU.
+        if world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        # Allocate output tensor.
+        if self.rank_in_group == dst:
+            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
+        else:
+            gather_list = None
+        # Gather.
+        torch.distributed.gather(input_,
+                                 gather_list,
+                                 dst=self.ranks[dst],
+                                 group=self.device_group)
+        if self.rank_in_group == dst:
+            output_tensor = torch.cat(gather_list, dim=dim)
+        else:
+            output_tensor = None
+        return output_tensor
+
+    def broadcast(self, input_: torch.Tensor, src: int = 0):
+        """Broadcast the input tensor.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+        # Broadcast.
+        torch.distributed.broadcast(input_,
+                                    src=self.ranks[src],
+                                    group=self.device_group)
+        return input_
+
+    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
+        """Broadcast the input object.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return obj
+        if self.mq_broadcaster is not None:
+            assert src == 0, "Message queue broadcaster only supports src=0"
+            return self.mq_broadcaster.broadcast_object(obj)
+        if self.rank_in_group == src:
+            torch.distributed.broadcast_object_list([obj],
+                                                    src=self.ranks[src],
+                                                    group=self.cpu_group)
+            return obj
+        else:
+            recv = [None]
+            torch.distributed.broadcast_object_list(recv,
+                                                    src=self.ranks[src],
+                                                    group=self.cpu_group)
+            return recv[0]
+
+    def broadcast_object_list(self,
+                              obj_list: List[Any],
+                              src: int = 0,
+                              group: Optional[ProcessGroup] = None):
+        """Broadcast the input object list.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return obj_list
+        # Broadcast.
+        torch.distributed.broadcast_object_list(obj_list,
+                                                src=self.ranks[src],
+                                                group=self.device_group)
+        return obj_list
+
+    def send_object(self, obj: Any, dst: int) -> None:
+        """Send the input object list to the destination rank."""
+        """NOTE: `dst` is the local rank of the destination rank."""
+
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+
+        assert dst != self.rank_in_group, (
+            "Invalid destination rank. Destination rank is the same "
+            "as the current rank.")
+
+        # Serialize object to tensor and get the size as well
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+
+        size_tensor = torch.tensor([object_tensor.numel()],
+                                   dtype=torch.long,
+                                   device="cpu")
+
+        # Send object size
+
+        torch.distributed.send(size_tensor,
+                               dst=self.ranks[dst],
+                               group=self.cpu_group)
+
+        # Send object
+        torch.distributed.send(object_tensor,
+                               dst=self.ranks[dst],
+                               group=self.cpu_group)
+
+        return None
+
+    def recv_object(self, src: int) -> Any:
+        """Receive the input object list from the source rank."""
+        """NOTE: `src` is the local rank of the source rank."""
+
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        assert src != self.rank_in_group, (
+            "Invalid source rank. Source rank is the same as the current rank."
+        )
+
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+
+        # Receive object size
+        rank_size = torch.distributed.recv(size_tensor,
+                                           src=self.ranks[src],
+                                           group=self.cpu_group)
+
+        # Tensor to receive serialized objects into.
+        object_tensor = torch.empty(  # type: ignore[call-overload]
+            size_tensor.item(),  # type: ignore[arg-type]
+            dtype=torch.uint8,
+            device="cpu")
+
+        rank_object = torch.distributed.recv(object_tensor,
+                                             src=self.ranks[src],
+                                             group=self.cpu_group)
+
+        assert rank_object == rank_size, (
+            "Received object sender rank does not match the size sender rank.")
+
+        obj = pickle.loads(object_tensor.numpy().tobytes())
+
+        return obj
+
+    def broadcast_tensor_dict(
+        self,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
+        src: int = 0,
+        group: Optional[ProcessGroup] = None,
+        metadata_group: Optional[ProcessGroup] = None
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        """Broadcast the input tensor dictionary.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if (not torch.distributed.is_initialized() or self.world_size == 1):
+            return tensor_dict
+
+        group = self.device_group
+        metadata_group = self.cpu_group
+        assert src < self.world_size, f"Invalid src rank ({src})"
 
-# Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
-# Pipeline model parallel group that the current rank belongs to.
-_PIPELINE_MODEL_PARALLEL_GROUP = None
+        rank_in_group = self.rank_in_group
+        if rank_in_group == src:
+            metadata_list: List[Tuple[Any, Any]] = []
+            assert isinstance(
+                tensor_dict,
+                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
+            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+            # `metadata_list` lives in CPU memory.
+            # `broadcast_object_list` has serialization & deserialization,
+            # all happening on CPU. Therefore, we can use the CPU group.
+            self.broadcast_object(metadata_list, src=src)
+            async_handles = []
+            for tensor in tensor_list:
+                if tensor.numel() == 0:
+                    # Skip broadcasting empty tensors.
+                    continue
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=self.ranks[src],
+                                                         group=metadata_group,
+                                                         async_op=True)
+                else:
+                    # use group for GPU tensors
+                    handle = torch.distributed.broadcast(tensor,
+                                                         src=self.ranks[src],
+                                                         group=group,
+                                                         async_op=True)
+                async_handles.append(handle)
+            for async_handle in async_handles:
+                async_handle.wait()
 
-# when people blindly call `torch.distributed.all_reduce` etc,
-# it will use this group. It is initialized with the `backend`
-# parameter of `init_distributed_environment` below.
-# Essentially, this is `torch.distributed.group.WORLD`.
-# We leave a line here to note that this is device-specific.
-# Note that this variable is not safe to use, because when users
-# call `init_distributed_environment` first, and then destroy
-# the process group themselves, this variable will keep a reference to the
-# destroyed process group, which is not useful.
-_DEVICE_WORLD_GROUP = None
+        else:
+            metadata_list = self.broadcast_object(None, src=src)
+            tensor_dict = {}
+            async_handles = []
+            for key, value in metadata_list:
+                if isinstance(value, TensorMetadata):
+                    tensor = torch.empty(value.size,
+                                         dtype=value.dtype,
+                                         device=value.device)
+                    if tensor.numel() == 0:
+                        # Skip broadcasting empty tensors.
+                        tensor_dict[key] = tensor
+                        continue
+                    if tensor.is_cpu:
+                        # use metadata_group for CPU tensors
+                        handle = torch.distributed.broadcast(
+                            tensor,
+                            src=self.ranks[src],
+                            group=metadata_group,
+                            async_op=True)
+                    else:
+                        # use group for GPU tensors
+                        handle = torch.distributed.broadcast(
+                            tensor,
+                            src=self.ranks[src],
+                            group=group,
+                            async_op=True)
+                    async_handles.append(handle)
+                    tensor_dict[key] = tensor
+                else:
+                    tensor_dict[key] = value
+            for async_handle in async_handles:
+                async_handle.wait()
+        return tensor_dict
 
-# duing `init_distributed_environment`, we will also initialize a
-# group with `gloo` backend, to allow direct coordination between
-# processes through the CPU.
-_CPU_WORLD_GROUP = None
+    def send_tensor_dict(
+        self,
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
+        dst: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        """Send the input tensor dictionary.
+        NOTE: `dst` is the local rank of the destination rank.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if not torch.distributed.is_initialized() or self.world_size == 1:
+            return tensor_dict
 
-# In summary, after calling `init_distributed_environment`, we will
-# always have two groups: one for device-specific (and is the default)
-# and one for CPU. All processes will be part of both groups.
+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
 
-# A list of global ranks for each pipeline group to ease calculation of the
-# source rank when broadcasting from the first or last pipeline stage.
-_PIPELINE_GLOBAL_RANKS = None
+        group = self.device_group
+        metadata_group = self.cpu_group
 
-_LOCAL_RANK = -1
+        if dst is None:
+            dst = (self.rank_in_group + 1) % self.world_size
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
+        metadata_list: List[Tuple[Any, Any]] = []
+        assert isinstance(
+            tensor_dict,
+            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `send_object_list` has serialization & deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
+        self.send_object(metadata_list, dst=dst)
+        for tensor in tensor_list:
+            if tensor.numel() == 0:
+                # Skip sending empty tensors.
+                continue
 
-def get_local_rank():
-    global _LOCAL_RANK
-    return _LOCAL_RANK
+            # send-allgather: send only a slice, then do allgather.
+            if (all_gather_group is not None
+                    and tensor.numel() % all_gather_size == 0):
+                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
+
+            if tensor.is_cpu:
+                # use metadata_group for CPU tensors
+                torch.distributed.send(tensor,
+                                       dst=self.ranks[dst],
+                                       group=metadata_group)
+            else:
+                # use group for GPU tensors
+                torch.distributed.send(tensor,
+                                       dst=self.ranks[dst],
+                                       group=group)
+        return None
+
+    def recv_tensor_dict(
+        self,
+        src: Optional[int] = None,
+        all_gather_group: Optional["GroupCoordinator"] = None,
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        """Recv the input tensor dictionary.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if not torch.distributed.is_initialized() or self.world_size == 1:
+            return None
+
+        all_gather_size = (1 if all_gather_group is None else
+                           all_gather_group.world_size)
+        all_gather_rank = (0 if all_gather_group is None else
+                           all_gather_group.rank_in_group)
+
+        group = self.device_group
+        metadata_group = self.cpu_group
+
+        if src is None:
+            src = (self.rank_in_group - 1) % self.world_size
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        recv_metadata_list = self.recv_object(src=src)
+        tensor_dict: Dict[str, Any] = {}
+        for key, value in recv_metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device=value.device)
+                if tensor.numel() == 0:
+                    # Skip broadcasting empty tensors.
+                    tensor_dict[key] = tensor
+                    continue
+
+                # send-allgather: send only a slice, then do allgather.
+                use_all_gather = (all_gather_group is not None
+                                  and tensor.numel() % all_gather_size == 0)
+
+                if use_all_gather:
+                    orig_shape = tensor.shape
+                    tensor = tensor.reshape(all_gather_size,
+                                            -1)[all_gather_rank]
+
+                if tensor.is_cpu:
+                    # use metadata_group for CPU tensors
+                    torch.distributed.recv(tensor,
+                                           src=self.ranks[src],
+                                           group=metadata_group)
+                else:
+                    # use group for GPU tensors
+                    torch.distributed.recv(tensor,
+                                           src=self.ranks[src],
+                                           group=group)
+                if use_all_gather:
+                    # do the allgather
+                    tensor = all_gather_group.all_gather(  # type: ignore
+                        tensor, dim=0)
+                    tensor = tensor.reshape(orig_shape)
+
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        return tensor_dict
+
+    def barrier(self):
+        """Barrier synchronization among the group.
+        NOTE: don't use `device_group` here! `barrier` in NCCL is
+        terrible because it is internally a broadcast operation with
+        secretly created GPU tensors. It is easy to mess up the current
+        device. Use the CPU group instead.
+        """
+        torch.distributed.barrier(group=self.cpu_group)
+
+    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
+        """Sends a tensor to the destination rank in a non-blocking way"""
+        """NOTE: `dst` is the local rank of the destination rank."""
+        if dst is None:
+            dst = (self.rank_in_group + 1) % self.world_size
+
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.send(tensor, dst)
+        else:
+            torch.distributed.send(tensor, self.ranks[dst], self.device_group)
+
+    def recv(self,
+             size: torch.Size,
+             dtype: torch.dtype,
+             src: Optional[int] = None) -> torch.Tensor:
+        """Receives a tensor from the src rank."""
+        """NOTE: `src` is the local rank of the destination rank."""
+        if src is None:
+            src = (self.rank_in_group - 1) % self.world_size
+
+        tensor = torch.empty(size, dtype=dtype, device=self.device)
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.recv(tensor, src)
+        else:
+            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
+        return tensor
+
+    def destroy(self):
+        if self.device_group is not None:
+            torch.distributed.destroy_process_group(self.device_group)
+            self.device_group = None
+        if self.cpu_group is not None:
+            torch.distributed.destroy_process_group(self.cpu_group)
+            self.cpu_group = None
+        if self.pynccl_comm is not None:
+            self.pynccl_comm = None
+        if self.ca_comm is not None:
+            self.ca_comm = None
+        if self.mq_broadcaster is not None:
+            self.mq_broadcaster = None
+
+
+_WORLD: Optional[GroupCoordinator] = None
+
+
+def get_world_group() -> GroupCoordinator:
+    assert _WORLD is not None, ("world group is not initialized")
+    return _WORLD
+
+
+def init_world_group(ranks: List[int], local_rank: int,
+                     backend: str) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=[ranks],
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=False,
+        use_custom_allreduce=False,
+        use_tpu_communicator=False,
+    )
+
+
+def init_model_parallel_group(
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_custom_allreduce: Optional[bool] = None,
+    use_message_queue_broadcaster: bool = False,
+) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+    return GroupCoordinator(
+        group_ranks=group_ranks,
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=True,
+        use_custom_allreduce=use_custom_allreduce,
+        use_tpu_communicator=True,
+        use_message_queue_broadcaster=use_message_queue_broadcaster,
+    )
+
+
+_TP: Optional[GroupCoordinator] = None
+
+
+def get_tp_group() -> GroupCoordinator:
+    assert _TP is not None, ("tensor model parallel group is not initialized")
+    return _TP
+
+
+# kept for backward compatibility
+get_tensor_model_parallel_group = get_tp_group
+
+_PP: Optional[GroupCoordinator] = None
+
+
+def get_pp_group() -> GroupCoordinator:
+    assert _PP is not None, (
+        "pipeline model parallel group is not initialized")
+    return _PP
+
+
+# kept for backward compatibility
+get_pipeline_model_parallel_group = get_pp_group
+
+
+@contextmanager
+def graph_capture():
+    """
+    `graph_capture` is a context manager which should surround the code that
+    is capturing the CUDA graph. Its main purpose is to ensure that
+    some operations will be run after the graph is captured, before the graph
+    is replayed. It returns a `GraphCaptureContext` object which contains the
+    necessary data for the graph capture. Currently, it only contains the
+    stream that the graph capture is running on. This stream is set to the
+    current CUDA stream when the context manager is entered and reset to the
+    default stream when the context manager is exited. This is to ensure that
+    the graph capture is running on a separate stream from the default stream,
+    in order to explicitly distinguish the kernels to capture
+    from other kernels possibly launched in the background on the default stream.
+    """
+    with get_tp_group().graph_capture() as context, get_pp_group(
+    ).graph_capture(context):
+        yield context
+
+
+_ENABLE_CUSTOM_ALL_REDUCE = True
+
+
+def set_custom_all_reduce(enable: bool):
+    global _ENABLE_CUSTOM_ALL_REDUCE
+    _ENABLE_CUSTOM_ALL_REDUCE = enable
 
 
 def init_distributed_environment(
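Since `GroupCoordinator` is now the public face of this module, here is a hedged usage sketch (not part of the diff) of `broadcast_tensor_dict`, the path that the message-queue broadcaster above accelerates; the dictionary keys and tensor shapes are made up.

```python
# Illustrative sketch: the TP driver rank ships a mix of tensors and plain
# Python metadata; the other ranks of the group receive an equivalent dict.
import torch
from aphrodite.distributed.parallel_state import get_tp_group

tp = get_tp_group()                          # assumes model parallelism is initialized
if tp.rank_in_group == 0:
    payload = {
        "input_ids": torch.randint(0, 32000, (8, 128), device=tp.device),
        "is_prompt": True,                   # non-tensor values travel as metadata
    }
    tp.broadcast_tensor_dict(payload, src=0)
else:
    payload = tp.broadcast_tensor_dict(src=0)
```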
@@ -54,8 +840,9 @@ def init_distributed_environment(
     local_rank: int = -1,
     backend: str = "nccl",
 ):
-    logger.debug(f"{world_size=} {rank=} {local_rank=} "
-                 f"{distributed_init_method=} {backend=}")
+    logger.debug(
+        f"world_size={world_size} rank={rank} local_rank={local_rank} "
+        f"distributed_init_method={distributed_init_method} backend={backend}")
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
@@ -66,13 +853,23 @@ def init_distributed_environment(
             init_method=distributed_init_method,
             world_size=world_size,
             rank=rank)
-        global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
-        _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
+    # set the local rank
+    # local_rank is not available in torch ProcessGroup,
+    # see https://github.com/pytorch/pytorch/issues/122816
+    if local_rank == -1:
+        # local rank not set, this usually happens in single-node
+        # setting, where we can use rank as local rank
+        if distributed_init_method == "env://":
+            local_rank = int(os.getenv("LOCAL_RANK", rank))
+        else:
+            local_rank = rank
+    global _WORLD
+    if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
-        _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
-                                                       backend="gloo")
-        global _LOCAL_RANK
-        _LOCAL_RANK = local_rank
+        _WORLD = init_world_group(ranks, local_rank, backend)
+    else:
+        assert _WORLD.world_size == torch.distributed.get_world_size(), (
+            "world group already initialized with a different world size")
 
 
 def initialize_model_parallel(
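To tie these entry points together, here is a hedged sketch of the lifecycle the new module docstring describes (the world size, ranks, and init method below are placeholders):

```python
# Illustrative lifecycle sketch; argument values are placeholders.
from aphrodite.distributed.parallel_state import (
    destroy_distributed_environment, destroy_model_parallel,
    ensure_model_parallel_initialized, init_distributed_environment)

init_distributed_environment(world_size=4, rank=0, local_rank=0,
                             distributed_init_method="tcp://127.0.0.1:29500")
ensure_model_parallel_initialized(tensor_model_parallel_size=2,
                                  pipeline_model_parallel_size=2)

# ... run code that needs the TP/PP groups ...

destroy_model_parallel()
destroy_distributed_environment()
```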
@@ -105,8 +902,8 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
-    # get the backend of _DEVICE_WORLD_GROUP
-    backend = backend or torch.distributed.get_backend()
+    backend = backend or torch.distributed.get_backend(
+        get_world_group().device_group)
 
     if (world_size !=
             tensor_model_parallel_size * pipeline_model_parallel_size):
@@ -115,34 +912,39 @@ def initialize_model_parallel(
             f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
             f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
 
+    # Build the tensor model-parallel groups.
     num_tensor_model_parallel_groups: int = (world_size //
                                              tensor_model_parallel_size)
-    num_pipeline_model_parallel_groups: int = (world_size //
-                                               pipeline_model_parallel_size)
-    rank = torch.distributed.get_rank()
-
-    # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
-        "tensor model parallel group is already initialized")
+    global _TP
+    assert _TP is None, ("tensor model parallel group is already initialized")
+    group_ranks = []
     for i in range(num_tensor_model_parallel_groups):
-        ranks = range(i * tensor_model_parallel_size,
-                      (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks, backend=backend)
-        if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+        ranks = list(
+            range(i * tensor_model_parallel_size,
+                  (i + 1) * tensor_model_parallel_size))
+        group_ranks.append(ranks)
+
+    # message queue broadcaster is only used in tensor model parallel group
+    _TP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_message_queue_broadcaster=True)
 
     # Build the pipeline model-parallel groups.
-    global _PIPELINE_MODEL_PARALLEL_GROUP
-    global _PIPELINE_GLOBAL_RANKS
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
+    num_pipeline_model_parallel_groups: int = (world_size //
+                                               pipeline_model_parallel_size)
+    global _PP
+    assert _PP is None, (
         "pipeline model parallel group is already initialized")
+    group_ranks = []
     for i in range(num_pipeline_model_parallel_groups):
-        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks, backend=backend)
-        if rank in ranks:
-            _PIPELINE_MODEL_PARALLEL_GROUP = group
-            _PIPELINE_GLOBAL_RANKS = ranks
+        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
+        group_ranks.append(ranks)
+    # pipeline parallel does not need custom allreduce
+    _PP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)
 
 
 def ensure_model_parallel_initialized(
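A small worked example (not part of the diff) of the group layout the two loops above produce, assuming 8 GPUs with tensor parallel size 2 and pipeline parallel size 4:

```python
# Reproduces the group_ranks built by initialize_model_parallel for
# world_size=8, tensor_model_parallel_size=2, pipeline_model_parallel_size=4.
world_size, tp_size, pp_size = 8, 2, 4

tp_group_ranks = [
    list(range(i * tp_size, (i + 1) * tp_size))
    for i in range(world_size // tp_size)
]
pp_group_ranks = [
    list(range(i, world_size, world_size // pp_size))
    for i in range(world_size // pp_size)
]

assert tp_group_ranks == [[0, 1], [2, 3], [4, 5], [6, 7]]
assert pp_group_ranks == [[0, 2, 4, 6], [1, 3, 5, 7]]
```

Each rank therefore belongs to exactly one TP group and one PP group, with the PP groups striding across the contiguous TP groups.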
@@ -154,8 +956,8 @@ def ensure_model_parallel_initialized(
     or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
     values if the model parallel groups are initialized.
     """
-    # get the backend of _DEVICE_WORLD_GROUP
-    backend = backend or torch.distributed.get_backend()
+    backend = backend or torch.distributed.get_backend(
+        get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
                                   pipeline_model_parallel_size, backend)
@@ -166,149 +968,167 @@ def ensure_model_parallel_initialized(
     ), ("tensor parallel group already initialized, but of unexpected size: "
         f"{get_tensor_model_parallel_world_size()=} vs. "
         f"{tensor_model_parallel_size=}")
-    assert (get_pipeline_model_parallel_world_size(
-    ) == pipeline_model_parallel_size), (
+    pp_world_size = get_pp_group().world_size
+    assert (pp_world_size == pipeline_model_parallel_size), (
         "pipeline parallel group already initialized, but of unexpected size: "
-        f"{get_pipeline_model_parallel_world_size()=} vs. "
+        f"{pp_world_size=} vs. "
         f"{pipeline_model_parallel_size=}")
 
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
-            and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
-
+    return (_TP is not None and _PP is not None)
 
-def get_cpu_world_group():
-    """Get the CPU world group."""
-    assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
-    return _CPU_WORLD_GROUP
 
+_TP_STATE_PATCHED = False
 
-def get_tensor_model_parallel_group():
-    """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
-        "tenosr model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
 
+@contextmanager
+def patch_tensor_parallel_group(tp_group: GroupCoordinator):
+    """Patch the tp group temporarily until this function ends.
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _TP_STATE_PATCHED
+    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
 
-def get_pipeline_model_parallel_group():
-    """Get the pipeline model parallel group the caller rank belongs to."""
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
-        "pipeline model parallel group is not initialized")
-    return _PIPELINE_MODEL_PARALLEL_GROUP
+    _TP_STATE_PATCHED = True
+    old_tp_group = get_tp_group()
+    global _TP
+    _TP = tp_group
+    try:
+        yield
+    finally:
+        # restore the original state
+        _TP_STATE_PATCHED = False
+        _TP = old_tp_group
 
 
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
-    return torch.distributed.get_world_size(
-        group=get_tensor_model_parallel_group())
-
-
-def get_pipeline_model_parallel_world_size():
-    """Return world size for the pipeline model parallel group."""
-    return torch.distributed.get_world_size(
-        group=get_pipeline_model_parallel_group())
+    return get_tp_group().world_size
 
 
 def get_tensor_model_parallel_rank():
     """Return my rank for the tensor model parallel group."""
-    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
+    return get_tp_group().rank_in_group
+
+
+def destroy_model_parallel():
+    """Set the groups to none and destroy them."""
+    global _TP
+    if _TP:
+        _TP.destroy()
+    _TP = None
 
+    global _PP
+    if _PP:
+        _PP.destroy()
+    _PP = None
 
-def get_pipeline_model_parallel_rank():
-    """Return my rank for the pipeline model parallel group."""
-    return torch.distributed.get_rank(
-        group=get_pipeline_model_parallel_group())
 
+def destroy_distributed_environment():
+    global _WORLD
+    if _WORLD:
+        _WORLD.destroy()
+    _WORLD = None
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
 
-def get_tensor_model_parallel_src_rank():
-    """Calculate the global rank corresponding to the first local rank
-    in the tensor model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    local_world_size = get_tensor_model_parallel_world_size()
-    return (global_rank // local_world_size) * local_world_size
 
+def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
+    """
+    This is a collective operation that returns if each rank is in the same node
+    as the source rank. It tests if processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "in_the_same_node_as should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
 
-def get_pipeline_model_parallel_first_rank():
-    """Return the global rank of the first process in the pipeline for the
-    current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
-        "Pipeline parallel group is not initialized")
-    return _PIPELINE_GLOBAL_RANKS[0]
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
 
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
 
-def get_pipeline_model_parallel_last_rank():
-    """Return the global rank of the last process in the pipeline for the
-    current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
-        "Pipeline parallel group is not initialized")
-    last_rank_local = get_pipeline_model_parallel_world_size() - 1
-    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
+    magic_message = b"magic_message"
+    shm = None
 
+    try:
+        with contextlib.suppress(OSError):
+            if rank == source_rank:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                        src=ranks[source_rank],
+                                                        group=pg)
+                is_in_the_same_node[rank] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[source_rank],
+                                                        group=pg)
+                name = recv[0]
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                # Python incorrectly tracks shared memory even if it is not
+                # created by the process. The following patch is a workaround.
+                with patch("multiprocessing.resource_tracker.register",
+                           lambda *args, **kwargs: None):
+                    shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error(f"Error ignored in is_in_the_same_node: {e}")
+    finally:
+        if shm:
+            shm.close()
 
-def get_pipeline_model_parallel_next_rank():
-    """Return the global rank that follows the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
-        "Pipeline parallel group is not initialized")
-    rank_in_pipeline = get_pipeline_model_parallel_rank()
-    world_size = get_pipeline_model_parallel_world_size()
-    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+    torch.distributed.barrier(group=pg)
 
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == source_rank and shm:
+            shm.unlink()
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
-def get_pipeline_model_parallel_prev_rank():
-    """Return the global rank that precedes the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
-        "Pipeline parallel group is not initialized")
-    rank_in_pipeline = get_pipeline_model_parallel_rank()
-    world_size = get_pipeline_model_parallel_world_size()
-    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
+    return [x == 1 for x in is_in_the_same_node.tolist()]
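A minimal sketch of calling this helper, assuming torch.distributed is already initialized; a non-NCCL (e.g. gloo) group is required because of the object broadcast above:

    import torch.distributed as dist

    cpu_group = dist.new_group(backend="gloo")
    same_node = in_the_same_node_as(cpu_group, source_rank=0)
    # e.g. a 4-rank job spread across two nodes might return
    # [True, True, False, False] on every rank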
 
 
-def destroy_model_parallel():
-    """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
-    global _PIPELINE_MODEL_PARALLEL_GROUP
-    if _PIPELINE_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
-    _PIPELINE_MODEL_PARALLEL_GROUP = None
-    global _PIPELINE_GLOBAL_RANKS
-    _PIPELINE_GLOBAL_RANKS = None
-    from aphrodite.distributed.device_communicators import pynccl_utils
-    # Destroy the pynccl states if any.
-    pynccl_utils.destroy_process_group()
-
-
-# Whether to use pynccl for nccl all reduce.
-# We use pynccl for all reduce when using CUDA graph, because torch.distributed
-# is not well supported by CUDA graph.
-_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
-
-
-@contextlib.contextmanager
-def with_pynccl_for_all_reduce():
-    """use Pynccl instead of torch.distributed for all reduce"""
-    from aphrodite.distributed.device_communicators import pynccl_utils
-    tp_size = get_tensor_model_parallel_world_size()
-    if tp_size == 1:
-        # No-op.
-        # NOTE: We don't initialize Pynccl when tp_size is 1.
-        yield
-    else:
-        global _ENABLE_PYNCCL_FOR_ALL_REDUCE
-        old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
-        _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
+def get_current_tp_rank_partition_offset(total_size: int,
+                                         tp_rank: Optional[int] = None,
+                                         tp_size: Optional[int] = None,
+                                         multiple_of: int = 1) -> int:
+    if tp_rank is None:
+        tp_rank = get_tensor_model_parallel_rank()
+
+    if tp_size is None:
+        tp_size = get_tensor_model_parallel_world_size()
+
+    assert total_size % multiple_of == 0
+    total_size = total_size // multiple_of
+    return ((total_size // tp_size) * tp_rank +
+            min(total_size % tp_size, tp_rank)) * multiple_of
+
 
-        stream = torch.cuda.current_stream()
-        with pynccl_utils.set_pynccl_stream(stream):
-            yield
-        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
+def get_current_tp_rank_partition_size(total_size: int,
+                                       tp_rank: Optional[int] = None,
+                                       tp_size: Optional[int] = None,
+                                       multiple_of: int = 1) -> int:
+    if tp_rank is None:
+        tp_rank = get_tensor_model_parallel_rank()
 
+    if tp_size is None:
+        tp_size = get_tensor_model_parallel_world_size()
 
-def is_pynccl_enabled_for_all_reduce():
-    """check if Pynccl is enabled for all reduce"""
-    global _ENABLE_PYNCCL_FOR_ALL_REDUCE
-    return _ENABLE_PYNCCL_FOR_ALL_REDUCE
+    assert total_size % multiple_of == 0
+    total_size = total_size // multiple_of
+    return ((total_size // tp_size) +
+            (total_size % tp_size > tp_rank)) * multiple_of
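A worked example of the asymmetric split these two helpers produce (the shapes are illustrative): with a dimension of 10 and tp_size=4, the remainder of 2 is assigned to the first two ranks.

    sizes = [get_current_tp_rank_partition_size(10, r, 4) for r in range(4)]
    offsets = [get_current_tp_rank_partition_offset(10, r, 4) for r in range(4)]
    assert sizes == [3, 3, 2, 2]
    assert offsets == [0, 3, 6, 8]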

+ 33 - 79
aphrodite/distributed/utils.py

@@ -3,15 +3,12 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-import json
 import os
-from typing import Dict, Optional, Sequence
+from typing import Sequence, Tuple
 
 import torch
-import torch.distributed as dist
-from loguru import logger
 
-from .parallel_state import get_cpu_world_group, get_local_rank
+APHRODITE_PP_LAYER_PARTITION = os.getenv("APHRODITE_PP_LAYER_PARTITION", None)
 
 
 def ensure_divisibility(numerator, denominator):
@@ -55,77 +52,34 @@ def split_tensor_along_last_dim(
     return tensor_list
 
 
-# code partly borrowed from
-# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
-# License: MIT
-def _can_actually_p2p(idx_a, idx_b):
-    dev_i = f"cuda:{idx_a}"
-    dev_j = f"cuda:{idx_b}"
-    a = torch.randn(5, device=dev_i) + 123.0
-    b = a.to(dev_j)
-    c = b.to(dev_i)
-    return torch.all(a == c).cpu().item()
-
-
-# why do we need this cache?
-# 1. we can have runtime checks for P2P access, where every process checks
-#  P2P access to all other GPUs. Unfortunately, the test might cost many
-#  (world_size * world_size) cuda context, and reduce the memory available
-#  for the model.
-# 2. alternatively, we can have a p2p map that is generated by the master
-#  process and broadcasted to all other processes. This still requires
-#  #world_size of cuda context, belonging to the master process, on each GPU.
-# 3. we can have a cache file, that records the p2p access status. The first
-#  time the master process checks the p2p access, it will generate the cache
-#  file, at the cost of #world_size of cuda context. Later on, all processes
-#  can read the cache file to check the p2p access status without any cost of
-#  additional cuda context.
-# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
-#  can have different cache files for different CUDA_VISIBLE_DEVICES settings,
-#  e.g. used by different aphrodite engines. The device id in the cache file is
-#  a **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
-#  of visible devices in the aphrodite engine.
-_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
-
-
-def gpu_p2p_access_check(i: int, j: int) -> bool:
-    """Check if GPU i can access GPU j."""
-
-    # if the cache variable is already calculated,
-    # read from the cache instead of checking it again
-    global _gpu_p2p_access_cache
-    if _gpu_p2p_access_cache is not None:
-        return _gpu_p2p_access_cache[f"{i}->{j}"]
-
-    is_distributed = dist.is_initialized()
-
-    num_dev = torch.cuda.device_count()
-    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-    if cuda_visible_devices is None:
-        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
-    path = os.path.expanduser(
-        f"~/.config/aphrodite/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
-    )
-    os.makedirs(os.path.dirname(path), exist_ok=True)
-    if (not is_distributed or get_local_rank() == 0) \
-        and (not os.path.exists(path)):
-        # only the local master process (with local_rank == 0) can
-        #  enter this block to calculate the cache
-        logger.info(f"generating GPU P2P access cache for in {path}")
-        cache = {}
-        for _i in range(num_dev):
-            for _j in range(num_dev):
-                # on some platforms, P2P support might be buggy and we need
-                # additional checks.
-                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
-                    _i, _j) and _can_actually_p2p(_i, _j)
-        with open(path, "w") as f:
-            json.dump(cache, f, indent=4)
-    if is_distributed:
-        cpu_world_group = get_cpu_world_group()
-        dist.barrier(cpu_world_group)
-    logger.info(f"reading GPU P2P access cache from {path}")
-    with open(path, "r") as f:
-        cache = json.load(f)
-    _gpu_p2p_access_cache = cache
-    return _gpu_p2p_access_cache[f"{i}->{j}"]
+def get_pp_indices(num_hidden_layers: int, pp_rank: int,
+                   pp_size: int) -> Tuple[int, int]:
+    """Try to evenly distribute layers across partitions.
+    If the number of layers is not divisible by the number of partitions,
+    the last partition will have the remaining layers.
+    """
+    partition_list_str = APHRODITE_PP_LAYER_PARTITION
+    if partition_list_str is not None:
+        try:
+            partitions = [
+                int(layer) for layer in partition_list_str.split(",")
+            ]
+        except ValueError as err:
+            raise ValueError("Invalid partition string: {}".format(
+                partition_list_str)) from err
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(
+                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+        start_layer = sum(partitions[:pp_rank])
+        end_layer = start_layer + partitions[pp_rank]
+    else:
+        layers_per_partition = num_hidden_layers // pp_size
+        start_layer = pp_rank * layers_per_partition
+        end_layer = start_layer + layers_per_partition
+
+        if pp_rank == pp_size - 1:
+            end_layer = num_hidden_layers
+
+    return (start_layer, end_layer)
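A worked example of the default split (numbers are illustrative): with 22 hidden layers and pp_size=4, each rank gets 5 layers and the last rank absorbs the remainder; exporting APHRODITE_PP_LAYER_PARTITION (read at import time) as a comma-separated list such as "6,5,5,6" overrides this behavior.

    assert [get_pp_indices(22, r, 4) for r in range(4)] == [
        (0, 5), (5, 10), (10, 15), (15, 22)
    ]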

+ 217 - 0
aphrodite/endpoints/chat_utils.py

@@ -0,0 +1,217 @@
+import codecs
+import tempfile
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
+                    final)
+
+import requests
+from loguru import logger
+# yapf conflicts with isort for this block
+# yapf: disable
+from openai.types.chat import ChatCompletionContentPartImageParam
+from openai.types.chat import (
+    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
+from openai.types.chat import ChatCompletionContentPartTextParam
+from openai.types.chat import (
+    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
+# yapf: enable
+# pydantic needs the TypedDict from typing_extensions
+from pydantic import ConfigDict
+from transformers import PreTrainedTokenizer
+from typing_extensions import Required, TypedDict
+
+from aphrodite.common.config import ModelConfig
+from aphrodite.multimodal import MultiModalDataDict
+from aphrodite.multimodal.utils import async_get_and_parse_image
+
+
+class CustomChatCompletionContentPartParam(TypedDict, total=False):
+    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
+
+    type: Required[str]
+    """The type of the content part."""
+
+
+ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
+                                       CustomChatCompletionContentPartParam]
+
+
+class CustomChatCompletionMessageParam(TypedDict, total=False):
+    """Enables custom roles in the Chat Completion API."""
+    role: Required[str]
+    """The role of the message's author."""
+
+    content: Union[str, List[ChatCompletionContentPartParam]]
+    """The contents of the message."""
+
+    name: str
+    """An optional name for the participant.
+    Provides the model information to differentiate between participants of the
+    same role.
+    """
+
+
+ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
+                                   CustomChatCompletionMessageParam]
+
+
+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
+@dataclass(frozen=True)
+class ChatMessageParseResult:
+    messages: List[ConversationMessage]
+    mm_futures: List[Awaitable[MultiModalDataDict]]
+
+
+def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+    if chat_template is None:
+        return None
+    try:
+        if chat_template.startswith(('http')):
+            response = requests.get(chat_template)
+            temp = tempfile.NamedTemporaryFile(delete=False)
+            temp.write(response.content)
+            temp.close()
+            chat_template = temp.name
+
+        with open(chat_template, "r") as f:
+            resolved_chat_template = f.read()
+    except OSError as e:
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (f"The supplied chat template ({chat_template}) "
+                   "looks like a file path, but it failed to be "
+                   f"opened. Reason: {e}")
+            raise ValueError(msg) from e
+
+        # If opening a file fails, treat the argument as the template itself
+        # and decode it so escape sequences are interpreted correctly
+        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
+
+    logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
+    return resolved_chat_template
+
+
+@lru_cache(maxsize=None)
+def _image_token_str(model_config: ModelConfig,
+                     tokenizer: PreTrainedTokenizer) -> Optional[str]:
+    # TODO: Let user specify how to insert image tokens into prompt
+    # (similar to chat template)
+    model_type = model_config.hf_config.model_type
+    if model_type == "phi3_v":
+        # Workaround since this token is not defined in the tokenizer
+        return "<|image_1|>"
+    if model_type == "minicpmv":
+        return "(<image>./</image>)"
+    if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
+        # These models do not use image tokens in the prompt
+        return None
+    if model_type.startswith("llava"):
+        return tokenizer.decode(model_config.hf_config.image_token_index)
+    if model_type in ("chameleon", "internvl_chat"):
+        return "<image>"
+
+    raise TypeError(f"Unknown model type: {model_type}")
+
+
+# TODO: Let user specify how to insert image tokens into prompt
+# (similar to chat template)
+def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
+    """Combine image and text prompts for vision language model"""
+
+    # NOTE: For now we assume all model architectures use the same
+    # image + text prompt format. This may change in the future.
+    return f"{image_token_str}\n{text_prompt}"
+
+
+def _parse_chat_message_content_parts(
+    role: str,
+    parts: Iterable[ChatCompletionContentPartParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
+) -> ChatMessageParseResult:
+    texts: List[str] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for part in parts:
+        part_type = part["type"]
+        if part_type == "text":
+            text = cast(ChatCompletionContentPartTextParam, part)["text"]
+            texts.append(text)
+        elif part_type == "image_url":
+            if len(mm_futures) > 0:
+                raise NotImplementedError(
+                    "Multiple 'image_url' input is currently not supported.")
+
+            image_url = cast(ChatCompletionContentPartImageParam,
+                             part)["image_url"]
+
+            if image_url.get("detail", "auto") != "auto":
+                logger.warning(
+                    "'image_url.detail' is currently not supported and "
+                    "will be ignored.")
+
+            image_future = async_get_and_parse_image(image_url["url"])
+            mm_futures.append(image_future)
+        else:
+            raise NotImplementedError(f"Unknown part type: {part_type}")
+
+    text_prompt = "\n".join(texts)
+
+    if mm_futures:
+        image_token_str = _image_token_str(model_config, tokenizer)
+        if image_token_str is not None:
+            if image_token_str in text_prompt:
+                logger.warning(
+                    "Detected image token string in the text prompt. "
+                    "Skipping prompt formatting.")
+            else:
+                text_prompt = _get_full_image_text_prompt(
+                    image_token_str=image_token_str,
+                    text_prompt=text_prompt,
+                )
+
+    messages = [ConversationMessage(role=role, content=text_prompt)]
+
+    return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
+
+
+def _parse_chat_message_content(
+    message: ChatCompletionMessageParam,
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
+) -> ChatMessageParseResult:
+    role = message["role"]
+    content = message.get("content")
+
+    if content is None:
+        return ChatMessageParseResult(messages=[], mm_futures=[])
+    if isinstance(content, str):
+        messages = [ConversationMessage(role=role, content=content)]
+        return ChatMessageParseResult(messages=messages, mm_futures=[])
+
+    return _parse_chat_message_content_parts(role, content, model_config,
+                                             tokenizer)
+
+
+def parse_chat_messages(
+    messages: List[ChatCompletionMessageParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
+) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
+    conversation: List[ConversationMessage] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for msg in messages:
+        parse_result = _parse_chat_message_content(msg, model_config,
+                                                   tokenizer)
+
+        conversation.extend(parse_result.messages)
+        mm_futures.extend(parse_result.mm_futures)
+
+    return conversation, mm_futures
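A minimal sketch of the message shapes this module accepts (the values are illustrative); parse_chat_messages additionally needs the engine's ModelConfig and tokenizer, and at most one image_url part per request is currently supported:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/cat.png"}},
        ]},
    ]
    # conversation, mm_futures = parse_chat_messages(messages, model_config,
    #                                                tokenizer)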

+ 190 - 10
aphrodite/endpoints/cli.py

@@ -1,28 +1,208 @@
+# The CLI entrypoint to Aphrodite.
 import argparse
+import asyncio
+import os
+import signal
+import subprocess
+import sys
+from typing import Optional
 
+import yaml
+from openai import OpenAI
+
+from aphrodite.common.utils import FlexibleArgumentParser
 from aphrodite.endpoints.openai.api_server import run_server
 from aphrodite.endpoints.openai.args import make_arg_parser
 
 
+def register_signal_handlers():
+
+    def signal_handler(sig, frame):
+        sys.exit(0)
+
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTSTP, signal_handler)
+
+
+def serve(args: argparse.Namespace) -> None:
+    # EngineArgs expects the model name to be passed as --model.
+    args.model = args.model_tag
+
+    asyncio.run(run_server(args))
+
+
+def interactive_cli(args: argparse.Namespace) -> None:
+    register_signal_handlers()
+
+    base_url = args.url
+    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
+    openai_client = OpenAI(api_key=api_key, base_url=base_url)
+
+    if args.model_name:
+        model_name = args.model_name
+    else:
+        available_models = openai_client.models.list()
+        model_name = available_models.data[0].id
+
+    print(f"Using model: {model_name}")
+
+    if args.command == "complete":
+        complete(model_name, openai_client)
+    elif args.command == "chat":
+        chat(args.system_prompt, model_name, openai_client)
+
+
+def complete(model_name: str, client: OpenAI) -> None:
+    print("Please enter prompt to complete:")
+    while True:
+        input_prompt = input("> ")
+
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt)
+        output = completion.choices[0].text
+        print(output)
+
+
+def chat(system_prompt: Optional[str], model_name: str,
+         client: OpenAI) -> None:
+    conversation = []
+    if system_prompt is not None:
+        conversation.append({"role": "system", "content": system_prompt})
+
+    print("Please enter a message for the chat model:")
+    while True:
+        input_message = input("> ")
+        message = {"role": "user", "content": input_message}
+        conversation.append(message)
+
+        chat_completion = client.chat.completions.create(model=model_name,
+                                                         messages=conversation)
+
+        response_message = chat_completion.choices[0].message
+        output = response_message.content
+
+        conversation.append(response_message)
+        print(output)
+
+
+STR_BOOLS = ['enforce_eager', 'enable_chunked_prefill']
+ADAPTERS = ['lora_modules', 'prompt_adapters']
+
+
+# TODO: refactor this to directly call run_server with the config file
+def serve_yaml(args: argparse.Namespace) -> None:
+
+    def append_cmd_args(cmd, key, value):
+        if value:  # Skip appending if value is empty
+            if key in ADAPTERS and isinstance(value, list):
+                adapters = [f"{k}={v}" for k, v in value[0].items() if v]
+                if adapters:
+                    cmd.append(f"--{key}")
+                    cmd.extend(adapters)
+            else:
+                cmd.append(f"--{key}")
+                if isinstance(value, bool):
+                    if key in STR_BOOLS:
+                        cmd.append(str(value).lower())
+                    elif value:
+                        cmd.append(str(value))
+                else:
+                    cmd.append(str(value))
+
+    with open(args.config_file, 'r') as f:
+        config = yaml.safe_load(f)
+
+    cmd = ["python", "-m", "aphrodite.endpoints.openai.api_server"]
+    for key, value in config.items():
+        if isinstance(value, list):
+            for item in value:
+                for sub_key, sub_value in item.items():
+                    append_cmd_args(cmd, sub_key, sub_value)
+        else:
+            append_cmd_args(cmd, key, value)
+
+    process = subprocess.Popen(cmd)
+    try:
+        process.wait()
+    except KeyboardInterrupt:
+        process.terminate()
+        process.wait()
+
+
+def _add_query_options(
+        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="http://localhost:2242/v1",
+        help="url of the running OpenAI-Compatible RESTful API server")
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=None,
+        help=("The model name used in prompt completion, default to "
+              "the first model in list models API call."))
+    parser.add_argument(
+        "--api-key",
+        type=str,
+        default=None,
+        help=(
+            "API key for OpenAI services. If provided, this api key "
+            "will overwrite the api key obtained through environment variables."
+        ))
+    return parser
+
+
 def main():
-    parser = argparse.ArgumentParser(description="Aphrodite CLI")
-    subparsers = parser.add_subparsers()
+    parser = FlexibleArgumentParser(description="Aphrodite CLI")
+    subparsers = parser.add_subparsers(required=True)
 
     serve_parser = subparsers.add_parser(
         "run",
         help="Start the Aphrodite OpenAI Compatible API server",
         usage="aphrodite run <model_tag> [options]")
-    make_arg_parser(serve_parser)
-    # Override the `--model` optional argument, make it positional.
-    serve_parser.add_argument("model",
+    serve_parser.add_argument("model_tag",
                               type=str,
-                              help="The model tag or path to"
-                              " run.")
-    serve_parser.set_defaults(func=run_server)
+                              help="The model tag to serve")
+    serve_parser = make_arg_parser(serve_parser)
+    serve_parser.set_defaults(dispatch_function=serve)
+
+    complete_parser = subparsers.add_parser(
+        "complete",
+        help=("Generate text completions based on the given prompt "
+              "via the running API server"),
+        usage="aphrodite complete [options]")
+    _add_query_options(complete_parser)
+    complete_parser.set_defaults(dispatch_function=interactive_cli,
+                                 command="complete")
+
+    chat_parser = subparsers.add_parser(
+        "chat",
+        help="Generate chat completions via the running API server",
+        usage="aphrodite chat [options]")
+    _add_query_options(chat_parser)
+    chat_parser.add_argument(
+        "--system-prompt",
+        type=str,
+        default=None,
+        help=("The system prompt to be added to the chat template, "
+              "used for models that support system prompts."))
+    chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
+
+    yaml_parser = subparsers.add_parser(
+        "yaml",
+        help="Start the Aphrodite OpenAI Compatible API server with a YAML "
+        "config file",
+        usage="aphrodite yaml <config.yaml>")
+    yaml_parser.add_argument("config_file",
+                             type=str,
+                             help="The YAML configuration file to use")
+    yaml_parser.set_defaults(dispatch_function=serve_yaml)
 
     args = parser.parse_args()
-    if hasattr(args, "func"):
-        args.func(args)
+    # One of the sub commands should be executed.
+    if hasattr(args, "dispatch_function"):
+        args.dispatch_function(args)
     else:
         parser.print_help()
 

Fișier diff suprimat deoarece este prea mare
+ 5 - 4
aphrodite/endpoints/kobold/klite.embd


+ 487 - 79
aphrodite/endpoints/llm.py

@@ -1,16 +1,23 @@
-from typing import List, Optional, Union
+from contextlib import contextmanager
+from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
 
-import torch
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
-from aphrodite.lora.request import LoRARequest
-from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.common.outputs import RequestOutput
+from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
+from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import MultiModalData
-from aphrodite.common.utils import Counter
+from aphrodite.common.utils import Counter, deprecate_kwargs
+from aphrodite.engine.aphrodite_engine import AphroditeEngine
+from aphrodite.engine.args_tools import EngineArgs
+from aphrodite.inputs import (PromptInputs, TextPrompt, TokensPrompt,
+                              parse_and_batch_prompt)
+from aphrodite.lora.request import LoRARequest
+from aphrodite.modeling.guided_decoding import (
+    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
+from aphrodite.modeling.guided_decoding.guided_fields import LLMGuidedOptions
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
+from aphrodite.transformers_utils.tokenizer import get_cached_tokenizer
 
 
 class LLM:
@@ -22,15 +29,14 @@ class LLM:
     this class generates texts from the model, using an intelligent batching
     mechanism and efficient memory management.
 
-    NOTE: This class is intended to be used for offline inference. For online
-    serving, use the `AsyncLLMEngine` class instead.
-    NOTE: For the comprehensive list of arguments, see `EngineArgs`.
-
     Args:
         model: The name or path of a HuggingFace Transformers model.
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer. Expect valid prompt_token_ids and None for prompt
+            from the input.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
@@ -41,12 +47,15 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq", "gptq", "quip" and "squeezellm". If None,
-            we first check the `quantization_config` attribute in the model
-            config file. If that is None, we assume the model weights are not
-            quantized and use `dtype` to determine the data type of the weights.
+            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+            If None, we first check the `quantization_config` attribute in the
+            model config file. If that is None, we assume the model weights are
+            not quantized and use `dtype` to determine the data type of
+            the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id.
         seed: The seed to initialize the random number generator for sampling.
         gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
             reserve for the model weights, activations, and KV cache. Higher
@@ -58,49 +67,87 @@ class LLM:
             when their `best_of` sampling parameters are larger than 1. If all
             requests will have `best_of=1`, you can safely set this to 0.
             Otherwise, too small values may cause out-of-memory (OOM) errors.
+        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
+            the model weights. This virtually increases the GPU memory space
+            you can use to hold the model weights, at the cost of CPU-GPU data
+            transfer for every forward pass.
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
             If False, we will use CUDA graph and eager execution in hybrid.
         max_context_len_to_capture: Maximum context len covered by CUDA graphs.
+            When a sequence has context length larger than this, we fall back
+            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
+        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode.
-        disable_custom_all_reduce: See ParallelConfig.
+        disable_custom_all_reduce: See ParallelConfig
+        **kwargs: Arguments for :class:`~aphrodite.EngineArgs`. (See
+            :ref:`engine_args`)
+    
+    Note:
+        This class is intended to be used for offline inference. For online
+        serving, use the :class:`~aphrodite.AsyncAphrodite` class instead.
     """
 
+    DEPRECATE_LEGACY: ClassVar[bool] = False
+    """A flag to toggle whether to deprecate the legacy generate/encode API."""
+
+    @classmethod
+    @contextmanager
+    def deprecate_legacy_api(cls):
+        cls.DEPRECATE_LEGACY = True
+
+        yield
+
+        cls.DEPRECATE_LEGACY = False
+
     def __init__(
         self,
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
-        enforce_eager: bool = True,
-        max_context_len_to_capture: int = 8192,
+        cpu_offload_gb: float = 0,
+        enforce_eager: bool = False,
+        max_context_len_to_capture: Optional[int] = None,
+        max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
+        removed_vision_keys = ("image_token_id", "image_feature_size",
+                               "image_input_shape", "image_input_type")
+        if any(k in kwargs for k in removed_vision_keys):
+            raise TypeError(
+                "There is no need to pass vision-related arguments anymore.")
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            skip_tokenizer_init=skip_tokenizer_init,
             trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             quantization=quantization,
             revision=revision,
+            tokenizer_revision=tokenizer_revision,
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
+            cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
             max_context_len_to_capture=max_context_len_to_capture,
+            max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             **kwargs,
         )
@@ -109,115 +156,476 @@ class LLM:
 
     def get_tokenizer(
             self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_engine.tokenizer
+        return self.llm_engine.tokenizer.tokenizer
 
     def set_tokenizer(
         self,
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     ) -> None:
-        self.llm_engine.tokenizer = tokenizer
+        # While CachedTokenizer is dynamic, have no choice but
+        # compare class name. Misjudgment will arise from
+        # user-defined tokenizer started with 'Cached'
+        if tokenizer.__class__.__name__.startswith("Cached"):
+            self.llm_engine.tokenizer.tokenizer = tokenizer
+        else:
+            self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
+                tokenizer)
 
+    @overload  # LEGACY: single (prompt + optional token ids)
     def generate(
         self,
-        prompts: Optional[Union[str, List[str]]] = None,
-        sampling_params: Optional[SamplingParams] = None,
+        prompts: str,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        prompt_token_ids: Optional[List[int]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (prompt + optional token ids)
+    def generate(
+        self,
+        prompts: List[str],
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: single (token ids + optional prompt)
+    def generate(
+        self,
+        prompts: Optional[str] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        *,
+        prompt_token_ids: List[int],
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (token ids + optional prompt)
+    def generate(
+        self,
+        prompts: Optional[List[str]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        *,
+        prompt_token_ids: List[List[int]],
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: single or multi token ids [pos-only]
+    def generate(
+        self,
+        prompts: None,
+        sampling_params: None,
+        prompt_token_ids: Union[List[int], List[List[int]]],
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload
+    def generate(
+        self,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
+        *,
+        sampling_params: Optional[Union[SamplingParams,
+                                        Sequence[SamplingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @deprecate_kwargs("prompts",
+                      "prompt_token_ids",
+                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+                      additional_message="Please use the 'inputs' parameter "
+                      "instead.")
+    def generate(
+        self,
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
+                       Optional[Union[str, List[str]]]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        Sequence[SamplingParams]]] = None,
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        guided_options_request: Optional[Union[LLMGuidedOptions,
+                                               GuidedDecodingRequest]] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
 
-        NOTE: This class automatically batches the given prompts, considering
+        This class automatically batches the given prompts, considering
         the memory constraint. For the best performance, put all of your prompts
         into a single list and pass it to this method.
 
         Args:
-            prompts: A list of prompts to generate completions for.
+            inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
-                None, we use the default sampling parameters.
-            prompt_token_ids: A list of token IDs for the prompts. If None, we
-                use the tokenizer to convert the prompts to token IDs.
+                None, we use the default sampling parameters. 
+                When it is a single value, it is applied to every prompt. 
+                When it is a list, the list must have the same length as the 
+                prompts and it is paired one by one with the prompt.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi-modal data to use for generation, if any.
+            prompt_adapter_request: Prompt Adapter request to use for 
+                generation, if any.
 
         Returns:
-            A list of `RequestOutput` objects containing the generated
-            completions in the same order as the input prompts.
+            A list of `RequestOutput` objects containing the
+            generated completions in the same order as the input prompts.
+
+        Note:
+            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
+            considered legacy and may be deprecated in the future. You should
+            instead pass them via the ``inputs`` parameter.
         """
-        if prompts is None and prompt_token_ids is None:
-            raise ValueError("Either prompts or prompt_token_ids must be "
-                             "provided.")
-        if isinstance(prompts, str):
-            # Convert a single prompt to a list.
-            prompts = [prompts]
-        if prompts is not None and prompt_token_ids is not None and len(
-                prompts) != len(prompt_token_ids):
+        if self.llm_engine.model_config.embedding_mode:
             raise ValueError(
-                "The lengths of prompts and prompt_token_ids must "
-                "be the same.")
+                "LLM.generate() is only supported for generation models "
+                "(XForCausalLM).")
+
+        if prompt_token_ids is not None:
+            inputs = self._convert_v1_inputs(
+                prompts=cast(Optional[Union[str, List[str]]], prompts),
+                prompt_token_ids=prompt_token_ids,
+            )
+        else:
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
+
+        if isinstance(guided_options_request, dict):
+            if len(guided_options_request) > 1:
+                raise ValueError(
+                    "You can only use one guided decoding but multiple is "
+                    f"specified: {guided_options_request}")
+            guided_options_request = GuidedDecodingRequest(
+                **guided_options_request)
+
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
 
-        if multi_modal_data:
-            multi_modal_data.data = multi_modal_data.data.to(torch.float16)
+        self._validate_and_add_requests(
+            inputs=inputs,
+            params=sampling_params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+            guided_options=guided_options_request,
+        )
+
+        outputs = self._run_engine(use_tqdm=use_tqdm)
+        return AphroditeEngine.validate_outputs(outputs, RequestOutput)
 
-        # Add requests to the engine.
-        num_requests = len(prompts) if prompts is not None else len(
-            prompt_token_ids)
+    @overload  # LEGACY: single (prompt + optional token ids)
+    def encode(
+        self,
+        prompts: str,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        prompt_token_ids: Optional[List[int]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (prompt + optional token ids)
+    def encode(
+        self,
+        prompts: List[str],
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: single (token ids + optional prompt)
+    def encode(
+        self,
+        prompts: Optional[str] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        *,
+        prompt_token_ids: List[int],
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (token ids + optional prompt)
+    def encode(
+        self,
+        prompts: Optional[List[str]] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        *,
+        prompt_token_ids: List[List[int]],
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: single or multi token ids [pos-only]
+    def encode(
+        self,
+        prompts: None,
+        pooling_params: None,
+        prompt_token_ids: Union[List[int], List[List[int]]],
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload
+    def encode(
+        self,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
+        *,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @deprecate_kwargs("prompts",
+                      "prompt_token_ids",
+                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+                      additional_message="Please use the 'inputs' parameter "
+                      "instead.")
+    def encode(
+        self,
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
+                       Optional[Union[str, List[str]]]] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        """Generates the completions for the input prompts.
+
+        This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            inputs: The inputs to the LLM. You may pass a sequence of inputs for
+                batch inference. See
+                :class:`~aphrodite.inputs.PromptInputs`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            prompt_adapter_request: Prompt Adapter request to use for 
+                generation, if any.
+
+        Returns:
+            A list of `EmbeddingRequestOutput` objects containing the
+            generated embeddings in the same order as the input prompts.
+
+        Note:
+            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
+            considered legacy and may be deprecated in the future. You should
+            instead pass them via the ``inputs`` parameter.
+        """
+        if not self.llm_engine.model_config.embedding_mode:
+            raise ValueError(
+                "LLM.encode() is only supported for embedding models (XModel)."
+            )
+
+        if prompt_token_ids is not None:
+            inputs = self._convert_v1_inputs(
+                prompts=cast(Optional[Union[str, List[str]]], prompts),
+                prompt_token_ids=prompt_token_ids,
+            )
+        else:
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
+
+        if pooling_params is None:
+            # Use default pooling params.
+            pooling_params = PoolingParams()
+
+        self._validate_and_add_requests(
+            inputs=inputs,
+            params=pooling_params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request)
+
+        outputs = self._run_engine(use_tqdm=use_tqdm)
+        return AphroditeEngine.validate_outputs(outputs,
+                                                EmbeddingRequestOutput)
+
+    # LEGACY
+    def _convert_v1_inputs(
+        self,
+        prompts: Optional[Union[str, List[str]]],
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
+    ):
+        # skip_tokenizer_init is now checked in engine
+
+        if prompts is not None:
+            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
+        if prompt_token_ids is not None:
+            prompt_token_ids = [
+                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
+            ]
+
+        num_requests = None
+        if prompts is not None:
+            num_requests = len(prompts)
+        if prompt_token_ids is not None:
+            if (num_requests is not None
+                    and num_requests != len(prompt_token_ids)):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
+
+            num_requests = len(prompt_token_ids)
+        if num_requests is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
+
+        inputs: List[PromptInputs] = []
         for i in range(num_requests):
-            prompt = prompts[i] if prompts is not None else None
-            token_ids = None if prompt_token_ids is None else prompt_token_ids[
-                i]
+            if prompts is not None:
+                item = TextPrompt(prompt=prompts[i])
+            elif prompt_token_ids is not None:
+                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
+            else:
+                raise AssertionError
+
+            inputs.append(item)
+
+        return inputs
+
+    def _validate_and_add_requests(
+        self,
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
+        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
+                      Sequence[PoolingParams]],
+        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+        guided_options: Optional[GuidedDecodingRequest] = None,
+    ) -> None:
+        if isinstance(inputs, (str, dict)):
+            # Convert a single prompt to a list.
+            inputs = [inputs]
+
+        num_requests = len(inputs)
+
+        if isinstance(params, list) and len(params) != num_requests:
+            raise ValueError("The lengths of prompts and params "
+                             "must be the same.")
+        if isinstance(lora_request,
+                      list) and len(lora_request) != num_requests:
+            raise ValueError("The lengths of prompts and lora_request "
+                             "must be the same.")
+
+        if isinstance(params, list):
+            params = [
+                self._add_guided_processor(param, guided_options)
+                if isinstance(param, SamplingParams) else param
+                for param in params
+            ]
+        elif isinstance(params, SamplingParams):
+            params = self._add_guided_processor(params, guided_options)
+
+        # Add requests to the engine.
+        for i, request_inputs in enumerate(inputs):
             self._add_request(
-                prompt,
-                sampling_params,
-                token_ids,
-                lora_request=lora_request,
-                # Get ith image while maintaining the batch dim.
-                multi_modal_data=MultiModalData(
-                    type=multi_modal_data.type,
-                    data=multi_modal_data.data[i].unsqueeze(0))
-                if multi_modal_data else None,
+                request_inputs,
+                params[i] if isinstance(params, Sequence) else params,
+                lora_request=lora_request[i] if isinstance(
+                    lora_request, Sequence) else lora_request,
+                prompt_adapter_request=prompt_adapter_request,
             )
-        return self._run_engine(use_tqdm)
 
     def _add_request(
-        self,
-        prompt: Optional[str],
-        sampling_params: SamplingParams,
-        prompt_token_ids: Optional[List[int]],
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+            self,
+            inputs: PromptInputs,
+            params: Union[SamplingParams, PoolingParams],
+            lora_request: Optional[Union[List[LoRARequest],
+                                         LoRARequest]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids,
-                                    lora_request=lora_request,
-                                    multi_modal_data=multi_modal_data)
-
-    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
+        self.llm_engine.add_request(
+            request_id,
+            inputs,
+            params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request)
+
+    def _add_guided_processor(
+            self,
+            params: SamplingParams,
+            guided_options: Optional[GuidedDecodingRequest] = None):
+        if guided_options:
+            if guided_options.guided_decoding_backend is None:
+                decoding_config = self.llm_engine.get_decoding_config()
+                guided_options.guided_decoding_backend = (
+                    decoding_config.guided_decoding_backend)
+            guided_logits_processor = get_local_guided_decoding_logits_processor(  #noqa
+                guided_options.guided_decoding_backend, guided_options,
+                self.get_tokenizer())
+            if guided_logits_processor:
+                if params.logits_processors is None:
+                    params.logits_processors = []
+                params.logits_processors.append(guided_logits_processor)
+        return params
+
+    def _run_engine(
+            self, *, use_tqdm: bool
+    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         # Initialize tqdm.
         if use_tqdm:
             num_requests = self.llm_engine.get_num_unfinished_requests()
-            pbar = tqdm(total=num_requests, desc="Processed prompts")
+            pbar = tqdm(
+                total=num_requests,
+                desc="Processed prompts",
+                dynamic_ncols=True,
+                postfix=(f"est. speed input: {0:.2f} toks/s, "
+                         f"output: {0:.2f} toks/s"),
+            )
         # Run the engine.
-        outputs: List[RequestOutput] = []
+        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
+        total_in_toks = 0
+        total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished:
                     outputs.append(output)
                     if use_tqdm:
+                        if isinstance(output, RequestOutput):
+                            # Calculate tokens only for RequestOutput
+                            total_in_toks += len(output.prompt_token_ids)
+                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
+                            total_out_toks += sum(
+                                len(stp.token_ids) for stp in output.outputs)
+                            out_spd = total_out_toks / pbar.format_dict[
+                                "elapsed"]
+                            pbar.postfix = (
+                                f"est. speed input: {in_spd:.2f} toks/s, "
+                                f"output: {out_spd:.2f} toks/s")
                         pbar.update(1)
         if use_tqdm:
             pbar.close()
         # Sort the outputs by request ID.
         # This is necessary because some requests may be finished earlier than
         # their preceding requests.
-        outputs = sorted(outputs, key=lambda x: int(x.request_id))
-        return outputs
+        return sorted(outputs, key=lambda x: int(x.request_id))
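The refactored `_validate_and_add_requests` above fans a single params object out across all prompts while requiring per-prompt lists to match in length. A minimal, self-contained sketch of that broadcasting rule (the helper name and toy types here are illustrative, not part of the diff):

```python
from typing import Sequence, TypeVar, Union

T = TypeVar("T")


def broadcast_params(prompts: Sequence[str],
                     params: Union[T, Sequence[T]]) -> list:
    """Pair each prompt with its params, mirroring _validate_and_add_requests."""
    if isinstance(params, Sequence) and not isinstance(params, str):
        if len(params) != len(prompts):
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        return list(zip(prompts, params))
    # A single params object is shared by every prompt.
    return [(prompt, params) for prompt in prompts]


print(broadcast_params(["a", "b"], {"temperature": 0.7}))
```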

+ 40 - 0
aphrodite/endpoints/logger.py

@@ -0,0 +1,40 @@
+from typing import List, Optional, Union
+
+from loguru import logger
+
+from aphrodite.common.pooling_params import PoolingParams
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.lora.request import LoRARequest
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
+
+
+class RequestLogger:
+
+    def __init__(self, *, max_log_len: Optional[int]) -> None:
+        super().__init__()
+
+        self.max_log_len = max_log_len
+
+    def log_inputs(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]],
+        params: Optional[Union[SamplingParams, PoolingParams]],
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+    ) -> None:
+        max_log_len = self.max_log_len
+        if max_log_len is not None:
+            if prompt is not None:
+                prompt = prompt[:max_log_len]
+
+            if prompt_token_ids is not None:
+                prompt_token_ids = prompt_token_ids[:max_log_len]
+
+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt}, "
+                    f"params: {params}, "
+                    f"prompt_token_ids: {prompt_token_ids}, "
+                    f"lora_request: {lora_request}, "
+                    f"prompt_adapter_request: {prompt_adapter_request}.")

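Taken on its own, the new `RequestLogger` is easy to exercise; a minimal usage sketch (the truncation length and request values are illustrative):

```python
from aphrodite.endpoints.logger import RequestLogger

# Truncate logged prompts/token ids to 32 items so logs stay readable.
request_logger = RequestLogger(max_log_len=32)

request_logger.log_inputs(
    request_id="cmpl-demo-1",
    prompt="Once upon a time, " * 16,  # cut to 32 characters in the log
    prompt_token_ids=None,
    params=None,                       # SamplingParams/PoolingParams in real use
    lora_request=None,
    prompt_adapter_request=None,
)
```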
+ 263 - 164
aphrodite/endpoints/openai/api_server.py

@@ -3,44 +3,66 @@ import importlib
 import inspect
 import json
 import os
+import re
+from argparse import Namespace
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import AsyncGenerator, List, Optional, Tuple
+from multiprocessing import Process
+from typing import AsyncGenerator, AsyncIterator, List, Set, Tuple
 
-import fastapi
-import uvicorn
-from fastapi import APIRouter, Header, Request
+import uvloop
+from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import (HTMLResponse, JSONResponse, Response,
                                StreamingResponse)
 from loguru import logger
 from prometheus_client import make_asgi_app
+from starlette.routing import Mount
 
-import aphrodite
-import aphrodite.endpoints.openai.embeddings as OAIembeddings
-from aphrodite.common.logger import UVICORN_LOG_CONFIG
+from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
-from aphrodite.common.utils import random_uuid
+from aphrodite.common.utils import (FlexibleArgumentParser, get_open_port,
+                                    random_uuid)
+from aphrodite.endpoints.logger import RequestLogger
 from aphrodite.endpoints.openai.args import make_arg_parser
-from aphrodite.endpoints.openai.protocol import (
-    ChatCompletionRequest, CompletionRequest, EmbeddingsRequest,
-    EmbeddingsResponse, ErrorResponse, KAIGenerationInputSchema, Prompt)
+# yapf: disable
+from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
+                                                 ChatCompletionResponse,
+                                                 CompletionRequest,
+                                                 DetokenizeRequest,
+                                                 DetokenizeResponse,
+                                                 EmbeddingRequest,
+                                                 ErrorResponse,
+                                                 KAIGenerationInputSchema,
+                                                 TokenizeRequest,
+                                                 TokenizeResponse)
+from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
+from aphrodite.endpoints.openai.rpc.server import run_rpc_server
+# yapf: enable
 from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
-from aphrodite.endpoints.openai.serving_completions import \
-    OpenAIServingCompletion
+from aphrodite.endpoints.openai.serving_completions import (
+    OpenAIServingCompletion)
+from aphrodite.endpoints.openai.serving_embedding import OpenAIServingEmbedding
+from aphrodite.endpoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.engine.protocol import AsyncEngineClient
+from aphrodite.server import serve_http
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
-from aphrodite.endpoints.openai.serving_engine import LoRA
+from aphrodite.version import __version__ as APHRODITE_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
-
-engine: Optional[AsyncAphrodite] = None
-engine_args: Optional[AsyncEngineArgs] = None
-openai_serving_chat: OpenAIServingChat = None
-openai_serving_completion: OpenAIServingCompletion = None
+APHRODITE_RPC_PORT = int(os.getenv("APHRODITE_RPC_PORT", '5570'))
+
+async_engine_client: AsyncEngineClient
+engine_args: AsyncEngineArgs
+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
+openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization
 router = APIRouter()
 kai_api = APIRouter()
 extra_api = APIRouter()
@@ -48,117 +70,129 @@ kobold_lite_ui = ""
 sampler_json = ""
 gen_cache: dict = {}
 
+_running_tasks: Set[asyncio.Task] = set()
+
+
+def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
+    return ModelConfig(model=model_name,
+                       tokenizer=model_name,
+                       tokenizer_mode="auto",
+                       trust_remote_code=trust_remote_code,
+                       seed=0,
+                       dtype="float16").embedding_mode
+
 
 @asynccontextmanager
-async def lifespan(app: fastapi.FastAPI):
+async def lifespan(app: FastAPI):
 
     async def _force_log():
         while True:
             await asyncio.sleep(10)
-            await engine.do_log_stats()
+            await async_engine_client.do_log_stats()
 
     if not engine_args.disable_log_stats:
-        asyncio.create_task(_force_log())
+        task = asyncio.create_task(_force_log())
+        _running_tasks.add(task)
+        task.add_done_callback(_running_tasks.remove)
 
     yield
 
 
-# Add prometheus asgi middleware to route /metrics requests
-metrics_app = make_asgi_app()
-router.mount("/metrics", metrics_app)
+@asynccontextmanager
+async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
+    # Context manager to handle async_engine_client lifecycle
+    # Ensures everything is shutdown and cleaned up on error/exit
+    global engine_args
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+
+    # Backend itself still global for the silly lil' health handler
+    global async_engine_client
+
+    # If manually triggered or embedding model, use AsyncAphrodite in process.
+    # TODO: support embedding model via RPC.
+    if (model_is_embedding(args.model, args.trust_remote_code)
+            or args.disable_frontend_multiprocessing):
+        async_engine_client = AsyncAphrodite.from_engine_args(engine_args)
+        yield async_engine_client
+        return
+
+    # Otherwise, use the multiprocessing AsyncAphrodite.
+    else:
+        # Start RPCServer in separate process (holds the AsyncAphrodite).
+        port = get_open_port(APHRODITE_RPC_PORT)
+        rpc_server_process = Process(target=run_rpc_server,
+                                     args=(engine_args, port))
+        rpc_server_process.start()
+
+        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
+        async_engine_client = AsyncEngineRPCClient(port)
+        await async_engine_client.setup()
+
+        try:
+            yield async_engine_client
+        finally:
+            # Ensure rpc server process was terminated
+            rpc_server_process.terminate()
+
+            # Close all open connections to the backend
+            async_engine_client.close()
+
+            # Wait for server process to join
+            rpc_server_process.join()
+
+
+def mount_metrics(app: FastAPI):
+    # Add prometheus asgi middleware to route /metrics requests
+    metrics_route = Mount("/metrics", make_asgi_app())
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+    app.routes.append(metrics_route)
 
 
 @router.get("/health")
 async def health() -> Response:
     """Health check."""
-    await openai_serving_chat.engine.check_health()
-    await openai_serving_completion.engine.check_health()
+    await async_engine_client.check_health()
     return Response(status_code=200)
 
 
-@router.get("/v1/models")
-async def show_available_models(x_api_key: Optional[str] = Header(None)):
-    models = await openai_serving_chat.show_available_models()
-    return JSONResponse(content=models.model_dump())
-
-
 @router.post("/v1/tokenize")
-@router.post("/v1/token/encode")
-async def tokenize(request: Request,
-                   prompt: Prompt,
-                   x_api_key: Optional[str] = Header(None)):
-    tokenized = await openai_serving_chat.tokenize(prompt)
-    return JSONResponse(content=tokenized)
+async def tokenize(request: TokenizeRequest):
+    generator = await openai_serving_tokenization.create_tokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, TokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
 
 
 @router.post("/v1/detokenize")
-@router.post("/v1/token/decode")
-async def detokenize(request: Request,
-                     token_ids: List[int],
-                     x_api_key: Optional[str] = Header(None)):
-    detokenized = await openai_serving_chat.detokenize(token_ids)
-    return JSONResponse(content=detokenized)
-
-
-@router.post("/v1/embeddings", response_model=EmbeddingsResponse)
-async def handle_embeddings(request: EmbeddingsRequest,
-                            x_api_key: Optional[str] = Header(None)):
-    input = request.input
-    if not input:
-        raise JSONResponse(
-            status_code=400,
-            content={"error": "Missing required argument input"})
-
-    model = request.model if request.model else None
-    response = await OAIembeddings.embeddings(input, request.encoding_format,
-                                              model)
-    return JSONResponse(response)
-
-
-@router.get("/version", description="Fetch the Aphrodite Engine version.")
-async def show_version(x_api_key: Optional[str] = Header(None)):
-    ver = {"version": aphrodite.__version__}
-    return JSONResponse(content=ver)
-
-
-@router.get("/v1/samplers")
-async def show_samplers(x_api_key: Optional[str] = Header(None)):
-    """Get the available samplers."""
-    global sampler_json
-    if not sampler_json:
-        jsonpath = os.path.dirname(os.path.abspath(__file__))
-        samplerpath = os.path.join(jsonpath, "./samplers.json")
-        samplerpath = os.path.normpath(samplerpath)  # Normalize the path
-        if os.path.exists(samplerpath):
-            with open(samplerpath, "r") as f:
-                sampler_json = json.load(f)
-        else:
-            logger.error("Sampler JSON not found at " + samplerpath)
-    return sampler_json
+async def detokenize(request: DetokenizeRequest):
+    generator = await openai_serving_tokenization.create_detokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, DetokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
 
 
-@router.post("/v1/lora/load")
-async def load_lora(lora: LoRA, x_api_key: Optional[str] = Header(None)):
-    openai_serving_chat.add_lora(lora)
-    openai_serving_completion.add_lora(lora)
-    if engine_args.enable_lora is False:
-        logger.error("LoRA is not enabled in the engine. "
-                     "Please start the server with the "
-                     "--enable-lora flag.")
-    return JSONResponse(content={"result": "success"})
+@router.get("/v1/models")
+async def show_available_models():
+    models = await openai_serving_completion.show_available_models()
+    return JSONResponse(content=models.model_dump())
 
 
-@router.delete("/v1/lora/unload")
-async def unload_lora(lora_name: str, x_api_key: Optional[str] = Header(None)):
-    openai_serving_chat.remove_lora(lora_name)
-    openai_serving_completion.remove_lora(lora_name)
-    return JSONResponse(content={"result": "success"})
+@router.get("/version")
+async def show_version():
+    ver = {"version": APHRODITE_VERSION}
+    return JSONResponse(content=ver)
 
 
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
-                                 raw_request: Request,
-                                 x_api_key: Optional[str] = Header(None)):
+                                 raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -168,13 +202,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return StreamingResponse(content=generator,
                                  media_type="text/event-stream")
     else:
+        assert isinstance(generator, ChatCompletionResponse)
         return JSONResponse(content=generator.model_dump())
 
 
 @router.post("/v1/completions")
-async def create_completion(request: CompletionRequest,
-                            raw_request: Request,
-                            x_api_key: Optional[str] = Header(None)):
+async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
@@ -187,6 +220,17 @@ async def create_completion(request: CompletionRequest,
         return JSONResponse(content=generator.model_dump())
 
 
+@router.post("/v1/embeddings")
+async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    generator = await openai_serving_embedding.create_embedding(
+        request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        return JSONResponse(content=generator.model_dump())
+
+
 # ============ KoboldAI API ============ #
 
 
@@ -226,18 +270,12 @@ def prepare_engine_payload(
         kai_payload.n = 1
         kai_payload.top_p = 1.0
         kai_payload.top_k = -1
-    if kai_payload.dynatemp_range is not None:
-        dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range
-        dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range
 
     sampling_params = SamplingParams(
         n=kai_payload.n,
         best_of=kai_payload.n,
         repetition_penalty=kai_payload.rep_pen,
         temperature=kai_payload.temperature,
-        dynatemp_min=dynatemp_min if kai_payload.dynatemp_range > 0 else 0.0,
-        dynatemp_max=dynatemp_max if kai_payload.dynatemp_range > 0 else 0.0,
-        dynatemp_exponent=kai_payload.dynatemp_exponent,
         smoothing_factor=kai_payload.smoothing_factor,
         smoothing_curve=kai_payload.smoothing_curve,
         tfs=kai_payload.tfs,
@@ -248,9 +286,6 @@ def prepare_engine_payload(
         typical_p=kai_payload.typical,
         eta_cutoff=kai_payload.eta_cutoff,
         epsilon_cutoff=kai_payload.eps_cutoff,
-        mirostat_mode=kai_payload.mirostat,
-        mirostat_tau=kai_payload.mirostat_tau,
-        mirostat_eta=kai_payload.mirostat_eta,
         stop=kai_payload.stop_sequence,
         include_stop_str_in_output=kai_payload.include_stop_str_in_output,
         custom_token_bans=badwordsids
@@ -269,8 +304,14 @@ def prepare_engine_payload(
 @kai_api.post("/generate")
 async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
     sampling_params, input_tokens = prepare_engine_payload(kai_payload)
-    result_generator = engine.generate(None, sampling_params,
-                                       kai_payload.genkey, input_tokens)
+    result_generator = async_engine_client.generate(
+        {
+            "prompt": kai_payload.prompt,
+            "prompt_token_ids": input_tokens,
+        },
+        sampling_params,
+        kai_payload.genkey,
+    )
 
     final_res: RequestOutput = None
     previous_output = ""
@@ -294,8 +335,14 @@ async def generate_stream(
         kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
 
     sampling_params, input_tokens = prepare_engine_payload(kai_payload)
-    results_generator = engine.generate(None, sampling_params,
-                                        kai_payload.genkey, input_tokens)
+    results_generator = async_engine_client.generate(
+        {
+            "prompt": kai_payload.prompt,
+            "prompt_token_ids": input_tokens,
+        },
+        sampling_params,
+        kai_payload.genkey,
+    )
 
     async def stream_kobold() -> AsyncGenerator[bytes, None]:
         previous_output = ""
@@ -332,7 +379,7 @@ async def abort_generation(request: Request):
     try:
         request_dict = await request.json()
         if "genkey" in request_dict:
-            await engine.abort(request_dict["genkey"])
+            await async_engine_client.abort(request_dict["genkey"])
     except json.JSONDecodeError:
         pass
 
@@ -340,13 +387,11 @@ async def abort_generation(request: Request):
 
 
 @extra_api.post("/tokencount")
-async def count_tokens(request: Request):
+async def count_tokens(request: TokenizeRequest):
     """Tokenize string and return token count"""
 
-    request_dict = await request.json()
-    tokenizer_result = await openai_serving_chat.tokenize(
-        Prompt(**request_dict))
-    return JSONResponse({"value": tokenizer_result["value"]})
+    generator = await openai_serving_tokenization.create_tokenize(request)
+    return JSONResponse({"value": len(generator.model_dump()["tokens"])})
 
 
 @kai_api.get("/info/version")
@@ -357,7 +402,7 @@ async def get_version():
 
 @kai_api.get("/model")
 async def get_model():
-    return JSONResponse({"result": f"aphrodite/{served_model}"})
+    return JSONResponse({"result": f"aphrodite/{served_model_names[0]}"})
 
 
 @kai_api.get("/config/soft_prompts_list")
@@ -423,9 +468,13 @@ async def get_kobold_lite_ui():
 # ============ KoboldAI API ============ #
 
 
-def build_app(args):
-    app = fastapi.FastAPI(lifespan=lifespan)
+def build_app(args: Namespace) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
     if args.launch_kobold_api:
         logger.warning("Launching Kobold API server in addition to OpenAI. "
@@ -437,6 +486,8 @@ def build_app(args):
                            include_in_schema=False)
         app.include_router(extra_api, prefix="/api/extra")
 
+    mount_metrics(app)
+
     app.add_middleware(
         CORSMiddleware,
         allow_origins=args.allowed_origins,
@@ -478,12 +529,6 @@ def build_app(args):
             auth_header = request.headers.get("Authorization")
             api_key_header = request.headers.get("x-api-key")
 
-            if request.url.path.startswith("/v1/lora"):
-                if admin_key is not None and api_key_header == admin_key:
-                    return await call_next(request)
-                return JSONResponse(content={"error": "Unauthorized"},
-                                    status_code=401)
-
             if auth_header != "Bearer " + token and api_key_header != token:
                 return JSONResponse(content={"error": "Unauthorized"},
                                     status_code=401)
@@ -503,60 +548,114 @@ def build_app(args):
     return app
 
 
-def run_server(args):
+async def init_app(
+    async_engine_client: AsyncEngineClient,
+    args: Namespace,
+) -> FastAPI:
     app = build_app(args)
 
     logger.debug(f"args: {args}")
 
-    global engine, engine_args, openai_serving_chat, openai_serving_completion,\
-        tokenizer, served_model
+    global served_model_names
     if args.served_model_name is not None:
-        served_model = args.served_model_name
+        served_model_names = args.served_model_name
     else:
-        served_model = args.model
+        served_model_names = [args.model]
+
+    if args.uvloop:
+        uvloop.install()
+
+    global tokenizer
+
+    model_config = await async_engine_client.get_model_config()
+
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+    global openai_serving_tokenization
+
+    openai_serving_chat = OpenAIServingChat(
+        async_engine_client,
+        model_config,
+        served_model_names,
+        args.response_role,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+    )
+    openai_serving_completion = OpenAIServingCompletion(
+        async_engine_client,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+    )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        async_engine_client,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )
+    openai_serving_tokenization = OpenAIServingTokenization(
+        async_engine_client,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
+    app.root_path = args.root_path
 
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncAphrodite.from_engine_args(engine_args)
     tokenizer = get_tokenizer(
-        engine_args.tokenizer,
+        tokenizer_name=engine_args.tokenizer,
         tokenizer_mode=engine_args.tokenizer_mode,
         trust_remote_code=engine_args.trust_remote_code,
         revision=engine_args.revision,
     )
 
-    chat_template = args.chat_template
-    if chat_template is None and tokenizer.chat_template is not None:
-        chat_template = tokenizer.chat_template
+    if args.launch_kobold_api:
+        _set_badwords(tokenizer, model_config.hf_config)
 
-    openai_serving_chat = OpenAIServingChat(engine, served_model,
-                                            args.response_role,
-                                            args.lora_modules,
-                                            args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model, args.lora_modules)
-    engine_model_config = asyncio.run(engine.get_model_config())
+    return app
 
-    if args.launch_kobold_api:
-        _set_badwords(tokenizer, engine_model_config.hf_config)
-    try:
-        uvicorn.run(app,
-                    host=args.host,
-                    port=args.port,
-                    log_level="info",
-                    timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-                    ssl_keyfile=args.ssl_keyfile,
-                    ssl_certfile=args.ssl_certfile,
-                    log_config=UVICORN_LOG_CONFIG)
-    except KeyboardInterrupt:
-        logger.info("API server stopped by user. Exiting.")
-    except asyncio.exceptions.CancelledError:
-        logger.info("API server stopped due to a cancelled request. Exiting.")
+
+async def run_server(args, **uvicorn_kwargs) -> None:
+
+    async with build_async_engine_client(args) as async_engine_client:
+        app = await init_app(async_engine_client, args)
+
+        shutdown_task = await serve_http(
+            app,
+            host=args.host,
+            port=args.port,
+            log_level=args.uvicorn_log_level,
+            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+            ssl_keyfile=args.ssl_keyfile,
+            ssl_certfile=args.ssl_certfile,
+            ssl_ca_certs=args.ssl_ca_certs,
+            ssl_cert_reqs=args.ssl_cert_reqs,
+            **uvicorn_kwargs,
+        )
+
+    # NB: Await server shutdown only after the backend context is exited
+    await shutdown_task
 
 
 if __name__ == "__main__":
     # NOTE:
     # This section should be in sync with aphrodite/endpoints/cli.py
     # for CLI entrypoints.
-    parser = make_arg_parser()
+    parser = FlexibleArgumentParser(
+        description="Aphrodite OpenAI-Compatible RESTful API Server")
+    parser = make_arg_parser(parser)
     args = parser.parse_args()
-    run_server(args)
+    asyncio.run(run_server(args))
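The `__main__` block above can also be reproduced programmatically, which is handy for embedding the server in scripts or tests. A sketch assuming the modules are importable as laid out in this diff (the model name is a placeholder):

```python
import asyncio

from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.endpoints.openai.api_server import run_server
from aphrodite.endpoints.openai.args import make_arg_parser

# Build the same parser the __main__ block uses, then launch with explicit args.
parser = make_arg_parser(FlexibleArgumentParser(
    description="Aphrodite OpenAI-Compatible RESTful API Server"))
args = parser.parse_args(["--model", "mistralai/Mistral-7B-Instruct-v0.2",
                          "--port", "2242"])
asyncio.run(run_server(args))
```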

+ 82 - 42
aphrodite/endpoints/openai/args.py

@@ -1,14 +1,17 @@
 """
-This file contains the command line arguments for Aphrodite's
+This file contains the command line arguments for the Aphrodite's
 OpenAI-compatible server. It is kept in a separate file for documentation
 purposes.
 """
 
 import argparse
 import json
+import ssl
 
+from aphrodite.common.utils import FlexibleArgumentParser
+from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
+                                                       PromptAdapterPath)
 from aphrodite.engine.args_tools import AsyncEngineArgs
-from aphrodite.endpoints.openai.serving_engine import LoRA
 
 
 class LoRAParserAction(argparse.Action):
@@ -17,16 +20,29 @@ class LoRAParserAction(argparse.Action):
         lora_list = []
         for item in values:
             name, path = item.split('=')
-            lora_list.append(LoRA(name, path))
+            lora_list.append(LoRAModulePath(name, path))
         setattr(namespace, self.dest, lora_list)
 
 
-def make_arg_parser(parser=None):
-    if parser is None:
-        parser = argparse.ArgumentParser(
-            description="Aphrodite OpenAI-Compatible RESTful API server.")
+class PromptAdapterParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        adapter_list = []
+        for item in values:
+            name, path = item.split('=')
+            adapter_list.append(PromptAdapterPath(name, path))
+        setattr(namespace, self.dest, adapter_list)
+
+
+def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument("--host", type=str, default=None, help="host name")
     parser.add_argument("--port", type=int, default=2242, help="port number")
+    parser.add_argument(
+        "--uvicorn-log-level",
+        type=str,
+        default="info",
+        choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
+        help="log level for uvicorn")
     parser.add_argument("--allow-credentials",
                         action="store_true",
                         help="allow credentials")
@@ -42,45 +58,32 @@ def make_arg_parser(parser=None):
                         type=json.loads,
                         default=["*"],
                         help="allowed headers")
-    parser.add_argument(
-        "--api-keys",
-        type=str,
-        default=None,
-        help=
-        "If provided, the server will require this key to be presented in the "
-        "header.")
-    parser.add_argument(
-        "--admin-key",
-        type=str,
-        default=None,
-        help=
-        "If provided, the server will require this key to be presented in the "
-        "header for admin operations.")
-    parser.add_argument(
-        "--launch-kobold-api",
-        action="store_true",
-        help=
-        "Launch the Kobold API server in addition to the OpenAI API server.")
-    parser.add_argument("--max-length",
-                        type=int,
-                        default=256,
-                        help="The maximum length of the generated response. "
-                        "For use with Kobold Horde.")
-    parser.add_argument("--served-model-name",
+    parser.add_argument("--api-keys",
                         type=str,
                         default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
+                        help="If provided, the server will require this key "
+                        "to be presented in the header.")
+    parser.add_argument("--admin-key",
+                        type=str,
+                        default=None,
+                        help="If provided, the server will require this key "
+                        "to be presented in the header for admin operations.")
     parser.add_argument(
         "--lora-modules",
         type=str,
         default=None,
         nargs='+',
         action=LoRAParserAction,
-        help=
-        "LoRA module configurations in the format name=path. Multiple modules "
-        "can be specified.")
+        help="LoRA module configurations in the format name=path. "
+        "Multiple modules can be specified.")
+    parser.add_argument(
+        "--prompt-adapters",
+        type=str,
+        default=None,
+        nargs='+',
+        action=PromptAdapterParserAction,
+        help="Prompt adapter configurations in the format name=path. "
+        "Multiple adapters can be specified.")
     parser.add_argument("--chat-template",
                         type=str,
                         default=None,
@@ -100,6 +103,16 @@ def make_arg_parser(parser=None):
                         type=str,
                         default=None,
                         help="The file path to the SSL cert file")
+    parser.add_argument("--ssl-ca-certs",
+                        type=str,
+                        default=None,
+                        help="The CA certificates file")
+    parser.add_argument(
+        "--ssl-cert-reqs",
+        type=int,
+        default=int(ssl.CERT_NONE),
+        help="Whether client certificate is required (see stdlib ssl module's)"
+    )
     parser.add_argument(
         "--root-path",
         type=str,
@@ -113,10 +126,37 @@ def make_arg_parser(parser=None):
         help="Additional ASGI middleware to apply to the app. "
         "We accept multiple --middleware arguments. "
         "The value should be an import path. "
-        "If a function is provided, Aphrodite will add it to the server using "
-        "@app.middleware('http'). "
-        "If a class is provided, Aphrodite will add it to the server using "
-        "app.add_middleware(). ")
+        "If a function is provided, Aphrodite will add it to the server "
+        "using @app.middleware('http'). "
+        "If a class is provided, Aphrodite will add it to the server "
+        "using app.add_middleware(). ")
+    parser.add_argument(
+        "--launch-kobold-api",
+        action="store_true",
+        help="Launch the Kobold API server alongside the OpenAI server")
+    parser.add_argument("--max-log-len",
+                        type=int,
+                        default=0,
+                        help="Max number of prompt characters or prompt "
+                        "ID numbers being printed in log."
+                        "\n\nDefault: 0")
+    parser.add_argument(
+        "--return-tokens-as-token-ids",
+        action="store_true",
+        help="When --max-logprobs is specified, represents single tokens as"
+        "strings of the form 'token_id:{token_id}' so that tokens that"
+        "are not JSON-encodable can be identified.")
+    parser.add_argument(
+        "--disable-frontend-multiprocessing",
+        action="store_true",
+        help="If specified, will run the OpenAI frontend server in the same "
+        "process as the model serving engine.")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser
+
+
+def create_parser_for_docs() -> FlexibleArgumentParser:
+    parser_for_docs = FlexibleArgumentParser(
+        prog="-m aphrodite.endpoints.openai.api_server")
+    return make_arg_parser(parser_for_docs)
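For a quick sanity check of the new flags (prompt adapters, uvicorn log level, SSL CA options), the documentation parser can be reused directly. A sketch assuming the engine args contribute `--model` as usual; the model and adapter paths are placeholders:

```python
from aphrodite.endpoints.openai.args import create_parser_for_docs

parser = create_parser_for_docs()
args = parser.parse_args([
    "--model", "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    "--uvicorn-log-level", "warning",
    "--lora-modules", "sql-lora=/adapters/sql",        # parsed into LoRAModulePath
    "--prompt-adapters", "faq=/adapters/faq",          # parsed into PromptAdapterPath
])
print(args.uvicorn_log_level, args.lora_modules, args.prompt_adapters)
```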

+ 0 - 144
aphrodite/endpoints/openai/embeddings.py

@@ -1,144 +0,0 @@
-"""
-This file is derived from
-[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
-and modified.
-The changes introduced are: Suppression of progress bar,
-typing/pydantic classes moved into this file,
-embeddings function declared async.
-"""
-
-import os
-import base64
-import numpy as np
-from transformers import AutoModel
-
-embeddings_params_initialized = False
-
-
-def initialize_embedding_params():
-    '''
-    using 'lazy loading' to avoid circular import
-    so this function will be executed only once
-    '''
-    global embeddings_params_initialized
-    if not embeddings_params_initialized:
-
-        global st_model, embeddings_model, embeddings_device
-
-        st_model = os.environ.get("OPENAI_EMBEDDING_MODEL",
-                                  'all-mpnet-base-v2')
-        embeddings_model = None
-        # OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
-        # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
-        # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
-        # hpu, mtia, privateuseone
-        embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu')
-        if embeddings_device.lower() == 'auto':
-            embeddings_device = None
-
-        embeddings_params_initialized = True
-
-
-def load_embedding_model(model: str):
-    try:
-        from sentence_transformers import SentenceTransformer
-    except ModuleNotFoundError:
-        print("The sentence_transformers module has not been found. " +
-              "Please install it manually with " +
-              "pip install -U sentence-transformers.")
-        raise ModuleNotFoundError from None
-
-    initialize_embedding_params()
-    global embeddings_device, embeddings_model
-    try:
-        print(f"Try embedding model: {model} on {embeddings_device}")
-        if 'jina-embeddings' in model:
-            # trust_remote_code is needed to use the encode method
-            embeddings_model = AutoModel.from_pretrained(
-                model, trust_remote_code=True)
-            embeddings_model = embeddings_model.to(embeddings_device)
-        else:
-            embeddings_model = SentenceTransformer(
-                model,
-                device=embeddings_device,
-            )
-
-        print(f"Loaded embedding model: {model}")
-    except Exception as e:
-        embeddings_model = None
-        raise Exception(f"Error: Failed to load embedding model: {model}",
-                        internal_message=repr(e)) from None
-
-
-def get_embeddings_model():
-    initialize_embedding_params()
-    global embeddings_model, st_model
-    if st_model and not embeddings_model:
-        load_embedding_model(st_model)  # lazy load the model
-
-    return embeddings_model
-
-
-def get_embeddings_model_name() -> str:
-    initialize_embedding_params()
-    global st_model
-    return st_model
-
-
-def get_embeddings(input: list) -> np.ndarray:
-    model = get_embeddings_model()
-    embedding = model.encode(input,
-                             convert_to_numpy=True,
-                             normalize_embeddings=True,
-                             convert_to_tensor=False,
-                             show_progress_bar=False)
-    return embedding
-
-
-async def embeddings(input: list,
-                     encoding_format: str,
-                     model: str = None) -> dict:
-    if model is None:
-        model = st_model
-    else:
-        load_embedding_model(model)
-
-    embeddings = get_embeddings(input)
-    if encoding_format == "base64":
-        data = [{
-            "object": "embedding",
-            "embedding": float_list_to_base64(emb),
-            "index": n
-        } for n, emb in enumerate(embeddings)]
-    else:
-        data = [{
-            "object": "embedding",
-            "embedding": emb.tolist(),
-            "index": n
-        } for n, emb in enumerate(embeddings)]
-
-    response = {
-        "object": "list",
-        "data": data,
-        "model": st_model if model is None else model,
-        "usage": {
-            "prompt_tokens": 0,
-            "total_tokens": 0,
-        }
-    }
-    return response
-
-
-def float_list_to_base64(float_array: np.ndarray) -> str:
-    # Convert the list to a float32 array that the OpenAPI client expects
-    # float_array = np.array(float_list, dtype="float32")
-
-    # Get raw bytes
-    bytes_array = float_array.tobytes()
-
-    # Encode bytes into base64
-    encoded_bytes = base64.b64encode(bytes_array)
-
-    # Turn raw base64 encoded bytes into ASCII
-    ascii_string = encoded_bytes.decode('ascii')
-    return ascii_string
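The standalone sentence-transformers embeddings shim above is removed; embedding requests now flow through the engine-backed /v1/embeddings route added in api_server.py. A rough client sketch (the field names follow the OpenAI embeddings convention and the model name is a placeholder, since EmbeddingRequest is not shown in this diff):

```python
import requests  # any HTTP client works; requests is used here for brevity

resp = requests.post(
    "http://localhost:2242/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",  # placeholder embedding model
        "input": ["The quick brown fox"],
    },
)
resp.raise_for_status()
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimensionality
```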

+ 83 - 0
aphrodite/endpoints/openai/logits_processors.py

@@ -0,0 +1,83 @@
+from functools import lru_cache, partial
+from typing import Dict, FrozenSet, Iterable, List, Optional, Union
+
+import torch
+from transformers import PreTrainedTokenizer
+
+from aphrodite.common.sampling_params import LogitsProcessorFunc
+
+
+class AllowedTokenIdsLogitsProcessor:
+    """Logits processor for constraining generated tokens to a
+    specific set of token ids."""
+
+    def __init__(self, allowed_ids: Iterable[int]):
+        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
+        self.mask: Optional[torch.Tensor] = None
+
+    def __call__(self, token_ids: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        if self.mask is None:
+            self.mask = torch.ones((logits.shape[-1], ),
+                                   dtype=torch.bool,
+                                   device=logits.device)
+            self.mask[self.allowed_ids] = False
+            self.allowed_ids = None
+        logits.masked_fill_(self.mask, float("-inf"))
+        return logits
+
+
+@lru_cache(maxsize=32)
+def _get_allowed_token_ids_logits_processor(
+    allowed_token_ids: FrozenSet[int],
+    vocab_size: int,
+) -> LogitsProcessorFunc:
+    if not allowed_token_ids:
+        raise ValueError("Empty allowed_token_ids provided")
+    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
+        raise ValueError("allowed_token_ids contains "
+                         "out-of-vocab token id")
+    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
+
+
+def logit_bias_logits_processor(logit_bias: Dict[int, float],
+                                token_ids: List[int],
+                                logits: torch.Tensor) -> torch.Tensor:
+    for token_id, bias in logit_bias.items():
+        logits[token_id] += bias
+    return logits
+
+
+def get_logits_processors(
+        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
+        allowed_token_ids: Optional[List[int]],
+        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessorFunc]:
+    logits_processors = []
+    if logit_bias:
+        try:
+            # Convert token_id to integer
+            # Clamp the bias between -100 and 100 per OpenAI API spec
+            clamped_logit_bias: Dict[int, float] = {
+                int(token_id): min(100.0, max(-100.0, bias))
+                for token_id, bias in logit_bias.items()
+            }
+        except ValueError as exc:
+            raise ValueError(
+                "Found token_id in logit_bias that is not "
+                "an integer or string representing an integer") from exc
+
+        # Check if token_id is within the vocab size
+        for token_id, bias in clamped_logit_bias.items():
+            if token_id < 0 or token_id >= tokenizer.vocab_size:
+                raise ValueError("token_id in logit_bias contains "
+                                 "out-of-vocab token id")
+
+        logits_processors.append(
+            partial(logit_bias_logits_processor, clamped_logit_bias))
+
+    if allowed_token_ids is not None:
+        logits_processors.append(
+            _get_allowed_token_ids_logits_processor(
+                frozenset(allowed_token_ids), tokenizer.vocab_size))
+
+    return logits_processors
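The allowed-token masking above is easy to verify in isolation; a self-contained check (the toy vocab size of 8 is arbitrary):

```python
import torch

from aphrodite.endpoints.openai.logits_processors import (
    AllowedTokenIdsLogitsProcessor)

proc = AllowedTokenIdsLogitsProcessor(allowed_ids=[2, 5])
logits = torch.zeros(8)             # pretend vocab of 8 tokens
masked = proc(token_ids=[], logits=logits)
print(masked)                       # everything except indices 2 and 5 is -inf
```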

+ 510 - 216
aphrodite/endpoints/openai/protocol.py

@@ -1,17 +1,27 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
-from typing import Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
-from pydantic import (AliasChoices, BaseModel, Field, conint, model_validator,
+import torch
+from pydantic import (BaseModel, ConfigDict, Field, model_validator,
                       root_validator)
+from transformers import PreTrainedTokenizer
+from typing_extensions import Annotated
 
-from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.pooling_params import PoolingParams
+from aphrodite.common.sampling_params import (LogitsProcessorFunc,
+                                              SamplingParams)
 from aphrodite.common.utils import random_uuid
-from aphrodite.common.logits_processor import BiasLogitsProcessor
+from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
+from aphrodite.endpoints.openai.logits_processors import get_logits_processors
 
 
-class ErrorResponse(BaseModel):
+class OpenAIBaseModel(BaseModel):
+    model_config = ConfigDict(extra="ignore")
+
+
+class ErrorResponse(OpenAIBaseModel):
     object: str = "error"
     message: str
     type: str
@@ -19,7 +29,7 @@ class ErrorResponse(BaseModel):
     code: int
 
 
-class ModelPermission(BaseModel):
+class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -34,143 +44,250 @@ class ModelPermission(BaseModel):
     is_blocking: bool = False
 
 
-class ModelCard(BaseModel):
+class ModelCard(OpenAIBaseModel):
     id: str
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: str = "pygmalionai"
     root: Optional[str] = None
     parent: Optional[str] = None
+    max_model_len: Optional[int] = None
     permission: List[ModelPermission] = Field(default_factory=list)
 
 
-class ModelList(BaseModel):
+class ModelList(OpenAIBaseModel):
     object: str = "list"
     data: List[ModelCard] = Field(default_factory=list)
 
 
-class UsageInfo(BaseModel):
+class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
 
 
-class ResponseFormat(BaseModel):
+class ResponseFormat(OpenAIBaseModel):
     # type must be "json_object" or "text"
-    type: str = Literal["text", "json_object"]
+    type: Literal["text", "json_object"]
+
+
+class StreamOptions(OpenAIBaseModel):
+    include_usage: Optional[bool] = True
+    continuous_usage_stats: Optional[bool] = True
+
+
+class FunctionDefinition(OpenAIBaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, Any]] = None
+
+
+class ChatCompletionToolsParam(OpenAIBaseModel):
+    type: Literal["function"] = "function"
+    function: FunctionDefinition
+
 
+class ChatCompletionNamedFunction(OpenAIBaseModel):
+    name: str
 
-class ChatCompletionRequest(BaseModel):
+
+class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
+    function: ChatCompletionNamedFunction
+    type: Literal["function"] = "function"
+
+
+class ChatCompletionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/chat/create
+    messages: List[ChatCompletionMessageParam]
     model: str
-    messages: List[Dict[str, str]]
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 1.0
-    tfs: Optional[float] = 1.0
-    eta_cutoff: Optional[float] = 0.0
-    epsilon_cutoff: Optional[float] = 0.0
-    typical_p: Optional[float] = 1.0
-    n: Optional[int] = 1
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = 0
     max_tokens: Optional[int] = None
-    min_tokens: Optional[int] = 0
-    seed: Optional[int] = None
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0.0
+    response_format: Optional[ResponseFormat] = None
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
-    include_stop_str_in_output: Optional[bool] = False
     stream: Optional[bool] = False
-    logprobs: Optional[bool] = False
-    top_logprobs: Optional[int] = None
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    repetition_penalty: Optional[float] = 1.0
-    logit_bias: Optional[Dict[str, float]] = None
+    stream_options: Optional[StreamOptions] = None
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    tools: Optional[List[ChatCompletionToolsParam]] = None
+    tool_choice: Optional[Union[Literal["none"],
+                                ChatCompletionNamedToolChoiceParam]] = "none"
     user: Optional[str] = None
+
+    # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
+    use_beam_search: Optional[bool] = False
     top_k: Optional[int] = -1
-    top_a: Optional[float] = 0.0
     min_p: Optional[float] = 0.0
-    mirostat_mode: Optional[int] = 0
-    mirostat_tau: Optional[float] = 0.0
-    mirostat_eta: Optional[float] = 0.0
-    dynatemp_min: Optional[float] = 0.0
-    dynatemp_max: Optional[float] = 0.0
-    dynatemp_exponent: Optional[float] = 1.0
+    top_a: Optional[float] = 0.0
+    tfs: Optional[float] = 1.0
+    eta_cutoff: Optional[float] = 0.0
+    epsilon_cutoff: Optional[float] = 0.0
+    typical_p: Optional[float] = 1.0
     smoothing_factor: Optional[float] = 0.0
     smoothing_curve: Optional[float] = 1.0
+    repetition_penalty: Optional[float] = 1.0
+    length_penalty: Optional[float] = 1.0
+    early_stopping: Optional[bool] = False
     ignore_eos: Optional[bool] = False
-    use_beam_search: Optional[bool] = False
-    prompt_logprobs: Optional[int] = None
+    min_tokens: Optional[int] = 0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
-    add_generation_prompt: Optional[bool] = True
-    echo: Optional[bool] = False
-    length_penalty: Optional[float] = 1.0
-    guided_json: Optional[Union[str, dict, BaseModel]] = None
-    guided_regex: Optional[str] = None
-    guided_choice: Optional[List[str]] = None
-    guided_grammar: Optional[str] = None
-    response_format: Optional[ResponseFormat] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    # doc: end-chat-completion-sampling-params
+
+    # doc: begin-chat-completion-extra-params
+    echo: Optional[bool] = Field(
+        default=False,
+        description=(
+            "If true, the new message will be prepended with the last message "
+            "if they belong to the same role."),
+    )
+    add_generation_prompt: Optional[bool] = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    add_special_tokens: Optional[bool] = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to False (as is the "
+            "default)."),
+    )
+    documents: Optional[List[Dict[str, str]]] = Field(
+        default=None,
+        description=
+        ("A list of dicts representing documents that will be accessible to "
+         "the model if it is performing RAG (retrieval-augmented generation)."
+         " If the template does not support RAG, this argument will have no "
+         "effect. We recommend that each document should be a dict containing "
+         "\"title\" and \"text\" keys."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "If this is not passed, the model's default chat template will be "
+            "used instead."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )
+    include_stop_str_in_output: Optional[bool] = Field(
+        default=False,
+        description=(
+            "Whether to include the stop string in the output. "
+            "This is only applied when the stop or stop_token_ids is set."),
+    )
+    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
+        default=None,
+        description=("If specified, the output will follow the JSON schema."),
+    )
+    guided_regex: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the regex pattern."),
+    )
+    guided_choice: Optional[List[str]] = Field(
+        default=None,
+        description=(
+            "If specified, the output will be exactly one of the choices."),
+    )
+    guided_grammar: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the context free grammar."),
+    )
     guided_decoding_backend: Optional[str] = Field(
-        default="outlines",
+        default=None,
         description=(
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be either "
             "'outlines' / 'lm-format-enforcer'"))
-
-    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
-        if self.logprobs and not self.top_logprobs:
-            raise ValueError("Top logprobs must be set when logprobs is.")
-        if self.top_k == 0:
-            self.top_k = -1
-
-        logits_processors = []
-        if self.logit_bias:
-            biases = {
-                int(tok): max(-100, min(float(bias), 100))
-                for tok, bias in self.logit_bias.items()
-                if 0 < int(tok) < vocab_size
-            }
-            logits_processors.append(BiasLogitsProcessor(biases))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
+
+    # doc: end-chat-completion-extra-params
+
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        # We now allow logprobs being true without top_logprobs.
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=None,
+            tokenizer=tokenizer,
+        )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
-            max_tokens=self.max_tokens,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            min_p=self.min_p,
+            seed=self.seed,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
-            temperature=self.temperature,
-            top_p=self.top_p,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            top_a=self.top_a,
             tfs=self.tfs,
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
-            presence_penalty=self.presence_penalty,
-            frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            top_k=self.top_k,
-            top_a=self.top_a,
-            min_p=self.min_p,
-            mirostat_mode=self.mirostat_mode,
-            mirostat_tau=self.mirostat_tau,
-            mirostat_eta=self.mirostat_eta,
-            dynatemp_min=self.dynatemp_min,
-            dynatemp_max=self.dynatemp_max,
-            dynatemp_exponent=self.dynatemp_exponent,
             smoothing_factor=self.smoothing_factor,
             smoothing_curve=self.smoothing_curve,
             ignore_eos=self.ignore_eos,
             use_beam_search=self.use_beam_search,
-            stop_token_ids=self.stop_token_ids,
-            custom_token_bans=self.custom_token_bans,
+            early_stopping=self.early_stopping,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
-            stop=self.stop,
-            best_of=self.best_of,
             include_stop_str_in_output=self.include_stop_str_in_output,
-            seed=self.seed,
+            length_penalty=self.length_penalty,
             logits_processors=logits_processors,
         )
 
+    @model_validator(mode='before')
+    @classmethod
+    def validate_stream_options(cls, values):
+        if (values.get('stream_options') is not None
+                and not values.get('stream')):
+            raise ValueError(
+                "stream_options can only be set if stream is true")
+        return values
+
     @model_validator(mode="before")
     @classmethod
     def check_guided_decoding_count(cls, data):
@@ -179,129 +296,182 @@ class ChatCompletionRequest(BaseModel):
             "guided_regex" in data and data["guided_regex"] is not None,
             "guided_choice" in data and data["guided_choice"] is not None
         ])
+        # you can only use one kind of guided decoding
         if guide_count > 1:
             raise ValueError(
                 "You can only use one kind of guided decoding "
                 "('guided_json', 'guided_regex' or 'guided_choice').")
+        # you can use either guided decoding or tools, not both
+        if guide_count > 0 and "tool_choice" in data and data[
+                "tool_choice"] != "none":
+            raise ValueError(
+                "You can use either guided decoding or tools, not both.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_tool_choice(cls, data):
+        if "tool_choice" in data and data["tool_choice"] != "none":
+            if not isinstance(data["tool_choice"], dict):
+                raise ValueError("Currently only named tools are supported.")
+            if "tools" not in data or data["tools"] is None:
+                raise ValueError(
+                    "When using `tool_choice`, `tools` must be set.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "top_logprobs" in data and data["top_logprobs"] is not None:
+            if "logprobs" not in data or data["logprobs"] is False:
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+            elif data["top_logprobs"] < 0:
+                raise ValueError(
+                    "`top_logprobs` must be a non-negative value.")
         return data
 
 
-class CompletionRequest(BaseModel):
+class CompletionRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/completions/create
     model: str
-    # a string, array of strings, array of tokens, or array of token arrays
     prompt: Union[List[int], List[List[int]], str, List[str]]
-    suffix: Optional[str] = None
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
     max_tokens: Optional[int] = 16
-    min_tokens: Optional[int] = 0
+    n: int = 1
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = Field(None,
+                                ge=torch.iinfo(torch.long).min,
+                                le=torch.iinfo(torch.long).max)
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
+    suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
+    user: Optional[str] = None
+
+    # doc: begin-completion-sampling-params
+    use_beam_search: Optional[bool] = False
+    top_k: Optional[int] = -1
+    min_p: Optional[float] = 0.0
+    top_a: Optional[float] = 0.0
     tfs: Optional[float] = 1.0
     eta_cutoff: Optional[float] = 0.0
     epsilon_cutoff: Optional[float] = 0.0
     typical_p: Optional[float] = 1.0
-    n: Optional[int] = 1
-    stream: Optional[bool] = False
-    logprobs: Optional[int] = None
-    echo: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
-    seed: Optional[int] = None
-    include_stop_str_in_output: Optional[bool] = False
-    presence_penalty: Optional[float] = 0.0
-    frequency_penalty: Optional[float] = 0.0
-    repetition_penalty: Optional[float] = 1.0
-    best_of: Optional[int] = None
-    logit_bias: Optional[Dict[str, float]] = None
-    user: Optional[str] = None
-    top_k: Optional[int] = -1
-    top_a: Optional[float] = 0.0
-    min_p: Optional[float] = 0.0
-    mirostat_mode: Optional[int] = 0
-    mirostat_tau: Optional[float] = 0.0
-    mirostat_eta: Optional[float] = 0.0
-    dynatemp_min: Optional[float] = Field(0.0,
-                                          validation_alias=AliasChoices(
-                                              "dynatemp_min", "dynatemp_low"),
-                                          description="Aliases: dynatemp_low")
-    dynatemp_max: Optional[float] = Field(0.0,
-                                          validation_alias=AliasChoices(
-                                              "dynatemp_max", "dynatemp_high"),
-                                          description="Aliases: dynatemp_high")
-    dynatemp_exponent: Optional[float] = 1.0
     smoothing_factor: Optional[float] = 0.0
     smoothing_curve: Optional[float] = 1.0
-    ignore_eos: Optional[bool] = False
-    use_beam_search: Optional[bool] = False
-    logprobs: Optional[int] = None
-    prompt_logprobs: Optional[int] = None
+    repetition_penalty: Optional[float] = 1.0
+    length_penalty: Optional[float] = 1.0
+    early_stopping: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
+    ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
-    truncate_prompt_tokens: Optional[conint(ge=1)] = None
-    grammar: Optional[str] = None
-    length_penalty: Optional[float] = 1.0
-    guided_json: Optional[Union[str, dict, BaseModel]] = None
-    guided_regex: Optional[str] = None
-    guided_choice: Optional[List[str]] = None
-    guided_grammar: Optional[str] = None
-    response_format: Optional[ResponseFormat] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    allowed_token_ids: Optional[List[int]] = None
+    include_stop_str_in_output: Optional[bool] = False
+    add_special_tokens: Optional[bool] = False
+    # doc: end-completion-sampling-params
+
+    # doc: begin-completion-extra-params
+    response_format: Optional[ResponseFormat] = Field(
+        default=None,
+        description=(
+            "Similar to chat completion, this parameter specifies the format "
+            "of the output. Only {'type': 'json_object'} or "
+            "{'type': 'text'} is supported."),
+    )
+    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
+        default=None,
+        description=("If specified, the output will follow the JSON schema."),
+    )
+    guided_regex: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the regex pattern."),
+    )
+    guided_choice: Optional[List[str]] = Field(
+        default=None,
+        description=(
+            "If specified, the output will be exactly one of the choices."),
+    )
+    guided_grammar: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the output will follow the context free grammar."),
+    )
     guided_decoding_backend: Optional[str] = Field(
-        default="outlines",
+        default=None,
         description=(
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be one of "
             "'outlines' / 'lm-format-enforcer'"))
+    guided_whitespace_pattern: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default whitespace pattern "
+            "for guided json decoding."))
+
+    # doc: end-completion-extra-params
+
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
 
-    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
         echo_without_generation = self.echo and self.max_tokens == 0
-        if self.top_k == 0:
-            self.top_k = -1
-
-        logits_processors = []
-        if self.logit_bias:
-            biases = {
-                int(tok): max(-100, min(float(bias), 100))
-                for tok, bias in self.logit_bias.items()
-                if 0 < int(tok) < vocab_size
-            }
-            logits_processors.append(BiasLogitsProcessor(biases))
+
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=self.allowed_token_ids,
+            tokenizer=tokenizer,
+        )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
-            max_tokens=self.max_tokens if not echo_without_generation else 1,
-            min_tokens=self.min_tokens,
+            best_of=self.best_of,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
             temperature=self.temperature,
             top_p=self.top_p,
+            top_k=self.top_k,
+            min_p=self.min_p,
+            top_a=self.top_a,
             tfs=self.tfs,
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
-            presence_penalty=self.presence_penalty,
-            frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            top_k=self.top_k,
-            top_a=self.top_a,
-            min_p=self.min_p,
-            mirostat_mode=self.mirostat_mode,
-            mirostat_tau=self.mirostat_tau,
-            mirostat_eta=self.mirostat_eta,
-            dynatemp_min=self.dynatemp_min,
-            dynatemp_max=self.dynatemp_max,
-            dynatemp_exponent=self.dynatemp_exponent,
             smoothing_factor=self.smoothing_factor,
             smoothing_curve=self.smoothing_curve,
+            seed=self.seed,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,
-            use_beam_search=self.use_beam_search,
+            max_tokens=max_tokens if not echo_without_generation else 1,
+            min_tokens=self.min_tokens,
             logprobs=self.logprobs,
-            prompt_logprobs=self.prompt_logprobs if self.echo else None,
-            stop_token_ids=self.stop_token_ids,
-            custom_token_bans=self.custom_token_bans,
+            use_beam_search=self.use_beam_search,
+            early_stopping=self.early_stopping,
+            prompt_logprobs=self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
-            spaces_between_special_tokens=self.spaces_between_special_tokens,
-            stop=self.stop,
-            best_of=self.best_of,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
-            seed=self.seed,
+            length_penalty=self.length_penalty,
             logits_processors=logits_processors,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
         )
@@ -320,20 +490,55 @@ class CompletionRequest(BaseModel):
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if ("logprobs" in data and data["logprobs"] is not None
+                and data["logprobs"] < 0):
+            raise ValueError(
+                "if passed, `logprobs` must be a non-negative value.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        if data.get("stream_options") and not data.get("stream"):
+            raise ValueError(
+                "Stream options can only be defined when stream is True.")
+        return data
+
+
+class EmbeddingRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings
+    model: str
+    input: Union[List[int], List[List[int]], str, List[str]]
+    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
+    dimensions: Optional[int] = None
+    user: Optional[str] = None
+
+    # doc: begin-embedding-pooling-params
+    additional_data: Optional[Any] = None
+
+    # doc: end-embedding-pooling-params
+
+    def to_pooling_params(self):
+        return PoolingParams(additional_data=self.additional_data)
 
-class LogProbs(BaseModel):
+
+class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)
 
 
-class CompletionResponseChoice(BaseModel):
+class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    logprobs: Optional[CompletionLogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -342,7 +547,7 @@ class CompletionResponseChoice(BaseModel):
     )
 
 
-class CompletionResponse(BaseModel):
+class CompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -351,12 +556,12 @@ class CompletionResponse(BaseModel):
     usage: UsageInfo
 
 
-class CompletionResponseStreamChoice(BaseModel):
+class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    logprobs: Optional[CompletionLogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "
@@ -365,7 +570,7 @@ class CompletionResponseStreamChoice(BaseModel):
     )
 
 
-class CompletionStreamResponse(BaseModel):
+class CompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -374,83 +579,172 @@ class CompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
-class ChatMessage(BaseModel):
+class EmbeddingResponseData(OpenAIBaseModel):
+    index: int
+    object: str = "embedding"
+    embedding: Union[List[float], str]
+
+
+class EmbeddingResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    object: str = "list"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    data: List[EmbeddingResponseData]
+    usage: UsageInfo
+
+
+class FunctionCall(OpenAIBaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
+
+class ChatMessage(OpenAIBaseModel):
     role: str
     content: str
+    tool_calls: List[ToolCall] = Field(default_factory=list)
+
+
+class ChatCompletionLogProb(OpenAIBaseModel):
+    token: str
+    logprob: float = -9999.0
+    bytes: Optional[List[int]] = None
+
+
+class ChatCompletionLogProbsContent(ChatCompletionLogProb):
+    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
+
+
+class ChatCompletionLogProbs(OpenAIBaseModel):
+    content: Optional[List[ChatCompletionLogProbsContent]] = None
 
 
-class ChatCompletionResponseChoice(BaseModel):
+class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
 
 
-class ChatCompletionResponse(BaseModel):
+class ChatCompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object: str = "chat.completion"
+    object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
 
 
-class DeltaMessage(BaseModel):
+class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
-class ChatCompletionResponseStreamChoice(BaseModel):
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
 
 
-class ChatCompletionStreamResponse(BaseModel):
+class ChatCompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object: str = "chat.completion.chunk"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
     usage: Optional[UsageInfo] = Field(default=None)
-    logprobs: Optional[LogProbs] = None
 
 
-class EmbeddingsRequest(BaseModel):
-    input: List[str] = Field(
-        ..., description="List of input texts to generate embeddings for.")
-    encoding_format: str = Field(
-        "float",
-        description="Encoding format for the embeddings. "
-        "Can be 'float' or 'base64'.")
-    model: Optional[str] = Field(
-        None,
-        description="Name of the embedding model to use. "
-        "If not provided, the default model will be used.")
+class BatchRequestInput(OpenAIBaseModel):
+    """
+    The per-line object of the batch input file.
+
+    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
+    """
+
+    # A developer-provided per-request id that will be used to match outputs to
+    # inputs. Must be unique for each request in a batch.
+    custom_id: str
+
+    # The HTTP method to be used for the request. Currently only POST is
+    # supported.
+    method: str
+
+    # The OpenAI API relative URL to be used for the request. Currently
+    # /v1/chat/completions is supported.
+    url: str
+
+    # The parameters of the request.
+    body: ChatCompletionRequest
+
+
+class BatchResponseData(OpenAIBaseModel):
+    # HTTP status code of the response.
+    status_code: int = 200
+
+    # A unique identifier for the API request.
+    request_id: str
+
+    # The body of the response.
+    body: Optional[ChatCompletionResponse] = None
+
+
+class BatchRequestOutput(OpenAIBaseModel):
+    """
+    The per-line object of the batch output and error files
+    """
+
+    id: str
+
+    # A developer-provided per-request id that will be used to match outputs to
+    # inputs.
+    custom_id: str
+
+    response: Optional[BatchResponseData]
+
+    # For requests that failed with a non-HTTP error, this will contain more
+    # information on the cause of the failure.
+    error: Optional[Any]
+
+
+class TokenizeCompletionRequest(OpenAIBaseModel):
+    model: str
+    prompt: str
+    add_special_tokens: bool = Field(default=True)
+
+
+class TokenizeChatRequest(OpenAIBaseModel):
+    model: str
+    messages: List[ChatCompletionMessageParam]
+    add_generation_prompt: bool = Field(default=True)
+    add_special_tokens: bool = Field(default=False)
+
+
+TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
 
 
-class EmbeddingObject(BaseModel):
-    object: str = Field("embedding", description="Type of the object.")
-    embedding: List[float] = Field(
-        ..., description="Embedding values as a list of floats.")
-    index: int = Field(
-        ...,
-        description="Index of the input text corresponding to "
-        "the embedding.")
+class TokenizeResponse(OpenAIBaseModel):
+    tokens: List[int]
+    count: int
+    max_model_len: int
 
 
-class EmbeddingsResponse(BaseModel):
-    object: str = Field("list", description="Type of the response object.")
-    data: List[EmbeddingObject] = Field(
-        ..., description="List of embedding objects.")
-    model: str = Field(..., description="Name of the embedding model used.")
-    usage: UsageInfo = Field(..., description="Information about token usage.")
+class DetokenizeRequest(OpenAIBaseModel):
+    model: Optional[str]
+    tokens: List[int]
 
 
-class Prompt(BaseModel):
+class DetokenizeResponse(OpenAIBaseModel):
     prompt: str
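
The request classes above now hand the tokenizer and an optional guided-decoding logits processor into to_sampling_params, with default_max_tokens filling in max_tokens when the client leaves it unset. A minimal sketch of that conversion, assuming aphrodite and transformers are installed; the model name is an arbitrary placeholder and no guided decoding is used:

from transformers import AutoTokenizer

from aphrodite.endpoints.openai.protocol import CompletionRequest

# Placeholder model; any tokenizer with the usual HF interface works here.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

request = CompletionRequest(
    model="EleutherAI/pythia-70m",
    prompt="The capital of France is",
    max_tokens=None,  # left unset on purpose
)

params = request.to_sampling_params(
    tokenizer=tokenizer,
    guided_decode_logits_processor=None,  # no guided decoding in this sketch
    default_max_tokens=64,
)
assert params.max_tokens == 64  # fell back to default_max_tokens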
 
 

+ 40 - 0
aphrodite/endpoints/openai/rpc/__init__.py

@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Union
+
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.inputs import PromptInputs
+from aphrodite.lora.request import LoRARequest
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
+
+APHRODITE_RPC_SUCCESS_STR = "SUCCESS"
+APHRODITE_RPC_HEALTHY_STR = "HEALTHY"
+
+
+@dataclass
+class RPCGenerateRequest:
+    inputs: PromptInputs
+    sampling_params: SamplingParams
+    request_id: str
+    lora_request: Optional[LoRARequest] = None
+    prompt_adapter_request: Optional[PromptAdapterRequest] = None
+
+
+@dataclass
+class RPCAbortRequest:
+    request_id: str
+
+
+class RPCUtilityRequest(Enum):
+    IS_SERVER_READY = 1
+    GET_MODEL_CONFIG = 2
+    GET_DECODING_CONFIG = 3
+    GET_PARALLEL_CONFIG = 4
+    GET_SCHEDULER_CONFIG = 5
+    GET_LORA_CONFIG = 6
+    DO_LOG_STATS = 7
+    CHECK_HEALTH = 8
+
+
+RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
+                         RPCUtilityRequest]
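
These dataclasses and the RPCUtilityRequest enum are the only message types that cross the ZeroMQ socket; both ends serialize them with cloudpickle. A small sketch of the expected round trip (prompt and request id are placeholders):

import cloudpickle

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.endpoints.openai.rpc import (RPCGenerateRequest,
                                            RPCUtilityRequest)

# What the client sends over its DEALER socket for a generation request.
msg = RPCGenerateRequest(
    inputs="Hello, world",
    sampling_params=SamplingParams(max_tokens=8),
    request_id="cmpl-example",
)
wire = cloudpickle.dumps(msg)

# What the server's ROUTER socket hands to its handler dispatch.
decoded = cloudpickle.loads(wire)
assert isinstance(decoded, RPCGenerateRequest)

# Utility requests are plain enum members on the wire.
ping = cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
assert cloudpickle.loads(ping) is RPCUtilityRequest.CHECK_HEALTH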

+ 236 - 0
aphrodite/endpoints/openai/rpc/client.py

@@ -0,0 +1,236 @@
+from contextlib import contextmanager
+from typing import Any, AsyncIterator, Optional
+
+import cloudpickle
+import zmq
+import zmq.asyncio
+
+from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig)
+from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
+                                            APHRODITE_RPC_SUCCESS_STR,
+                                            RPC_REQUEST_TYPE, RPCAbortRequest,
+                                            RPCGenerateRequest,
+                                            RPCUtilityRequest)
+from aphrodite.inputs import PromptInputs
+from aphrodite.lora.request import LoRARequest
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
+from aphrodite.transformers_utils.tokenizer_group import (
+    init_tokenizer_from_configs)
+
+
+class AsyncEngineRPCClient:
+
+    def __init__(self, port: int):
+        self.context = zmq.asyncio.Context()
+        self.path = f"tcp://localhost:{port}"
+
+    async def setup(self):
+        """Setup the client before it starts sending server requests."""
+
+        # Wait until server is ready.
+        await self.wait_for_server()
+
+        # Get the configs.
+        self.model_config = await self._get_model_config_rpc()
+        self.decoding_config = await self._get_decoding_config_rpc()
+
+        # Create the tokenizer group.
+        # TODO: refactor OAI server to avoid needing this info.
+        self.tokenizer = init_tokenizer_from_configs(
+            model_config=self.model_config,
+            scheduler_config=(await self._get_scheduler_config_rpc()),
+            parallel_config=(await self._get_parallel_config_rpc()),
+            enable_lora=bool(await self._get_lora_config_rpc()),
+        )
+
+    def close(self):
+        """Destroy the ZeroMQ Context."""
+        self.context.destroy()
+
+    @contextmanager
+    def socket(self):
+        # Ensure client sockets are always closed after use
+
+        # Connect to the RPC socket for the request-reply pattern.
+        # Note that we use a DEALER socket to enable asynchronous
+        # communication, which is required for streaming.
+        socket = self.context.socket(zmq.constants.DEALER)
+        try:
+            socket.connect(self.path)
+            yield socket
+        finally:
+            socket.close()
+
+    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
+                                         expected_type: Any,
+                                         error_message: str) -> Any:
+        """Send an RPC request that is expecting data back."""
+
+        with self.socket() as socket:
+
+            # Ping RPCServer with a request.
+            await socket.send(cloudpickle.dumps(request))
+
+            # Await the data from the Server.
+            data = cloudpickle.loads(await socket.recv())
+
+        if not isinstance(data, expected_type):
+            # LoRAConfig can be None.
+            if expected_type == LoRAConfig and data is None:
+                pass
+            else:
+                raise ValueError(error_message)
+
+        return data
+
+    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
+                                        error_message: str):
+        """Send one-way RPC request to trigger an action."""
+        with self.socket() as socket:
+            # Ping RPC Server with request.
+            await socket.send(cloudpickle.dumps(request))
+
+            # Await acknowledgement from RPCServer.
+            response = cloudpickle.loads(await socket.recv())
+
+        if (not isinstance(response, str)
+                or response != APHRODITE_RPC_SUCCESS_STR):
+            raise ValueError(error_message)
+
+        return response
+
+    async def get_tokenizer(self, lora_request: LoRARequest):
+        return await self.tokenizer.get_lora_tokenizer_async(lora_request)
+
+    async def get_decoding_config(self) -> DecodingConfig:
+        return self.decoding_config
+
+    async def get_model_config(self) -> ModelConfig:
+        return self.model_config
+
+    async def wait_for_server(self):
+        """Wait for the RPCServer to start up."""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.IS_SERVER_READY,
+            error_message="Unable to start RPC Server.")
+
+    async def _get_model_config_rpc(self) -> ModelConfig:
+        """Get the ModelConfig object from the RPC Server"""
+
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_MODEL_CONFIG,
+            expected_type=ModelConfig,
+            error_message="Could not get ModelConfig from RPC Server")
+
+    async def _get_decoding_config_rpc(self) -> DecodingConfig:
+        """Get DecodingConfig from the RPCServer"""
+
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_DECODING_CONFIG,
+            expected_type=DecodingConfig,
+            error_message="Could not get DecodingConfig from RPC Server")
+
+    async def _get_parallel_config_rpc(self) -> ParallelConfig:
+        """Get ParallelConfig from the RPCServer"""
+
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_PARALLEL_CONFIG,
+            expected_type=ParallelConfig,
+            error_message="Could not get ParallelConfig from RPC Server")
+
+    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
+        """Get SchedulerConfig from the RPCServer"""
+
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
+            expected_type=SchedulerConfig,
+            error_message="Could not get SchedulerConfig from RPC Server")
+
+    async def _get_lora_config_rpc(self):
+        """Get LoRAConfig from the RPCServer"""
+
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_LORA_CONFIG,
+            expected_type=LoRAConfig,
+            error_message="Could not get LoRAConfig from RPC Server")
+
+    async def abort(self, request_id: str):
+        """Send an ABORT_REQUEST signal to the RPC Server"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCAbortRequest(request_id),
+            error_message=f"RPCAbortRequest {request_id} failed")
+
+    async def do_log_stats(self):
+        """Send a DO_LOG_STATS signal to the RPC Server"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.DO_LOG_STATS,
+            error_message="RPCRequest DO_LOG_STATS failed.")
+
+    async def generate(
+        self,
+        inputs: PromptInputs,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> AsyncIterator[RequestOutput]:
+        """Send an RPCGenerateRequest to the RPCServer and stream responses."""
+
+        with self.socket() as socket:
+
+            # Send RPCGenerateRequest to the RPCServer.
+            await socket.send_multipart([
+                cloudpickle.dumps(
+                    RPCGenerateRequest(
+                        inputs=inputs,
+                        sampling_params=sampling_params,
+                        request_id=request_id,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request))
+            ])
+
+            # Stream back the results from the RPC Server.
+            while True:
+                message = await socket.recv()
+                request_output = cloudpickle.loads(message)
+
+                if isinstance(request_output, Exception):
+                    raise request_output
+
+                if request_output.finished:
+                    break
+                yield request_output
+
+            yield request_output
+
+    async def check_health(self) -> None:
+        """Raise if unhealthy"""
+
+        with self.socket() as socket:
+
+            # Ping RPCServer with CHECK_HEALTH request.
+            await socket.send(
+                cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH))
+
+            # Await the reply from the server.
+            # TODO: do we need an internal timeout here?
+            # Or do we expect the external probe to time out and let this chill?
+            health_message = cloudpickle.loads(await socket.recv())
+
+        if isinstance(health_message, Exception):
+            raise health_message
+
+        if health_message != APHRODITE_RPC_HEALTHY_STR:
+            raise ValueError("Expected healthy response from backend but got "
+                             f"{health_message}")
+
+    async def encode(self, *args,
+                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
+        raise NotImplementedError(
+            "Embeddings not supported with multiprocessing backend")

+ 205 - 0
aphrodite/endpoints/openai/rpc/server.py

@@ -0,0 +1,205 @@
+import asyncio
+import signal
+from typing import Any, Coroutine
+
+import cloudpickle
+import zmq
+import zmq.asyncio
+from loguru import logger
+from typing_extensions import Never
+
+from aphrodite import AsyncAphrodite, AsyncEngineArgs
+from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
+                                            APHRODITE_RPC_SUCCESS_STR,
+                                            RPCAbortRequest,
+                                            RPCGenerateRequest,
+                                            RPCUtilityRequest)
+
+
+class AsyncEngineRPCServer:
+
+    def __init__(self, async_engine_args: AsyncEngineArgs, port: int):
+        # Initialize engine first.
+        self.engine = AsyncAphrodite.from_engine_args(async_engine_args)
+
+        # Initialize context.
+        self.context = zmq.asyncio.Context()
+
+        # Init the socket that serves all client requests.
+        self.socket = self.context.socket(zmq.constants.ROUTER)
+        # NOTE: the numeric form of localhost should be used for zmq bind(),
+        # see https://stackoverflow.com/a/8958414
+        self.socket.bind(f"tcp://127.0.0.1:{port}")
+
+    def cleanup(self):
+        """Cleanup all resources."""
+        self.socket.close()
+        self.context.destroy()
+
+    async def get_model_config(self, identity):
+        """Send the ModelConfig"""
+        model_config = await self.engine.get_model_config()
+
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(model_config)])
+
+    async def get_decoding_config(self, identity):
+        """Send the DecodingConfig"""
+        decoding_config = await self.engine.get_decoding_config()
+
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(decoding_config)])
+
+    async def get_lora_config(self, identity):
+        lora_config = await self.engine.get_lora_config()
+
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(lora_config)])
+
+    async def get_scheduler_config(self, identity):
+        """Send the SchedulerConfig"""
+        scheduler_config = await self.engine.get_scheduler_config()
+
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(scheduler_config)])
+
+    async def get_parallel_config(self, identity):
+        """Send the ParallelConfig"""
+        parallel_config = await self.engine.get_parallel_config()
+
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(parallel_config)])
+
+    async def do_log_stats(self, identity):
+        """Log stats and confirm success."""
+        await self.engine.do_log_stats()
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
+        ])
+
+    async def is_server_ready(self, identity):
+        """Notify the client that we are ready."""
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
+        ])
+
+    async def abort(self, identity, request: RPCAbortRequest):
+        """Abort request and notify the client of success."""
+        # Abort the request in the llm engine.
+        await self.engine.abort(request.request_id)
+
+        # Send confirmation to the client.
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
+        ])
+
+    async def generate(self, identity, generate_request: RPCGenerateRequest):
+        try:
+            results_generator = self.engine.generate(
+                generate_request.inputs,
+                sampling_params=generate_request.sampling_params,
+                request_id=generate_request.request_id,
+                lora_request=generate_request.lora_request,
+                prompt_adapter_request=generate_request.prompt_adapter_request)
+
+            async for request_output in results_generator:
+                await self.socket.send_multipart(
+                    [identity, cloudpickle.dumps(request_output)])
+
+        except Exception as e:
+            # Notify the client of all failures.
+            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+
+    async def check_health(self, identity):
+        try:
+            await self.engine.check_health()
+            await self.socket.send_multipart(
+                [identity,
+                 cloudpickle.dumps(APHRODITE_RPC_HEALTHY_STR)])
+        except Exception as e:
+            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+
+    def _make_handler_coro(self, identity,
+                           message) -> Coroutine[Any, Any, Never]:
+        """Route the zmq message to the handler coroutine."""
+
+        request = cloudpickle.loads(message)
+
+        if isinstance(request, RPCGenerateRequest):
+            return self.generate(identity, request)
+
+        elif isinstance(request, RPCAbortRequest):
+            return self.abort(identity, request)
+
+        elif isinstance(request, RPCUtilityRequest):
+            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
+                return self.get_model_config(identity)
+            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
+                return self.get_parallel_config(identity)
+            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
+                return self.get_decoding_config(identity)
+            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
+                return self.get_scheduler_config(identity)
+            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
+                return self.get_lora_config(identity)
+            elif request == RPCUtilityRequest.DO_LOG_STATS:
+                return self.do_log_stats(identity)
+            elif request == RPCUtilityRequest.IS_SERVER_READY:
+                return self.is_server_ready(identity)
+            elif request == RPCUtilityRequest.CHECK_HEALTH:
+                return self.check_health(identity)
+            else:
+                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
+
+        else:
+            raise ValueError(f"Unknown RPCRequest type: {request}")
+
+    async def run_server_loop(self):
+        """Inner RPC Server Loop"""
+
+        running_tasks = set()
+        while True:
+            # Wait for a request.
+            identity, message = await self.socket.recv_multipart()
+
+            # Process the request async.
+            task = asyncio.create_task(
+                self._make_handler_coro(identity, message))
+
+            # We need to keep a strong reference to the task so that it is
+            # not garbage-collected and does not disappear mid-execution.
+            # This is the common pattern for "fire-and-forget" tasks, see
+            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
+            running_tasks.add(task)
+            task.add_done_callback(running_tasks.discard)
+
+
+async def run_server(server: AsyncEngineRPCServer):
+    # Put the server task into the asyncio loop.
+    loop = asyncio.get_running_loop()
+    server_task = loop.create_task(server.run_server_loop())
+
+    # Interruption handling.
+    def signal_handler() -> None:
+        # Kill the server on interrupt / terminate
+        server_task.cancel()
+
+    loop.add_signal_handler(signal.SIGINT, signal_handler)
+    loop.add_signal_handler(signal.SIGTERM, signal_handler)
+
+    try:
+        await server_task
+    except asyncio.CancelledError:
+        logger.info("Aphrodite ZMQ RPC Server was interrupted.")
+    finally:
+        # Clean up all resources.
+        server.cleanup()
+
+
+def run_rpc_server(async_engine_args: AsyncEngineArgs, port: int):
+    server = AsyncEngineRPCServer(async_engine_args, port)
+    asyncio.run(run_server(server))
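
The server side is designed to run in its own process, with run_rpc_server owning the engine and the frontend connecting through AsyncEngineRPCClient. A sketch of that wiring, under the assumption that the frontend spawns the process itself; port and model name are placeholders:

import asyncio
from multiprocessing import Process

from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
from aphrodite.engine.args_tools import AsyncEngineArgs

PORT = 5570  # arbitrary example port
engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder model

# The RPC server builds the AsyncAphrodite engine inside this process.
server = Process(target=run_rpc_server, args=(engine_args, PORT))
server.start()


async def probe() -> None:
    client = AsyncEngineRPCClient(port=PORT)
    await client.setup()         # blocks until IS_SERVER_READY is acknowledged
    await client.check_health()  # raises if the engine reports itself unhealthy
    client.close()


asyncio.run(probe())
server.terminate()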

+ 159 - 0
aphrodite/endpoints/openai/run_batch.py

@@ -0,0 +1,159 @@
+import asyncio
+from io import StringIO
+from typing import Awaitable, List
+
+import aiohttp
+from loguru import logger
+
+from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
+from aphrodite.endpoints.logger import RequestLogger
+from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
+                                                 BatchRequestOutput,
+                                                 BatchResponseData,
+                                                 ChatCompletionResponse,
+                                                 ErrorResponse)
+from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.version import __version__ as APHRODITE_VERSION
+
+
+def parse_args():
+    parser = FlexibleArgumentParser(
+        description="Aphrodite OpenAI-Compatible batch runner.")
+    parser.add_argument(
+        "-i",
+        "--input-file",
+        required=True,
+        type=str,
+        help="The path or URL to a single input file. Currently supports "
+        "local file paths and HTTP/HTTPS URLs. If a URL is specified, "
+        "the file should be available via HTTP GET.")
+    parser.add_argument(
+        "-o",
+        "--output-file",
+        required=True,
+        type=str,
+        help="The path or URL to a single output file. Currently supports "
+        "local file paths and HTTP/HTTPS URLs. If a URL is specified, "
+        "the file should be available via HTTP PUT.")
+    parser.add_argument("--response-role",
+                        type=str,
+                        default="assistant",
+                        help="The role name to return if "
+                        "`request.add_generation_prompt=true`.")
+    parser.add_argument("--max-log-len",
+                        type=int,
+                        default=0,
+                        help="Max number of prompt characters or prompt "
+                        "token IDs printed in the log."
+                        "\n\nDefault: 0")
+
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    return parser.parse_args()
+
+
+async def read_file(path_or_url: str) -> str:
+    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
+        async with aiohttp.ClientSession() as session, \
+                   session.get(path_or_url) as resp:
+            return await resp.text()
+    else:
+        with open(path_or_url, "r", encoding="utf-8") as f:
+            return f.read()
+
+
+async def write_file(path_or_url: str, data: str) -> None:
+    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
+        async with aiohttp.ClientSession() as session, \
+                   session.put(path_or_url, data=data.encode("utf-8")):
+            pass
+    else:
+        # We should make this async, but as long as this is always run as a
+        # standalone program, blocking the event loop won't affect performance
+        # in this particular case.
+        with open(path_or_url, "w", encoding="utf-8") as f:
+            f.write(data)
+
+
+async def run_request(chat_serving: OpenAIServingChat,
+                      request: BatchRequestInput) -> BatchRequestOutput:
+    chat_request = request.body
+    chat_response = await chat_serving.create_chat_completion(chat_request)
+
+    if isinstance(chat_response, ChatCompletionResponse):
+        batch_output = BatchRequestOutput(
+            id=f"aphrodite-{random_uuid()}",
+            custom_id=request.custom_id,
+            response=BatchResponseData(
+                body=chat_response,
+                request_id=f"aphrodite-batch-{random_uuid()}"),
+            error=None,
+        )
+    elif isinstance(chat_response, ErrorResponse):
+        batch_output = BatchRequestOutput(
+            id=f"aphrodite-{random_uuid()}",
+            custom_id=request.custom_id,
+            response=BatchResponseData(
+                status_code=chat_response.code,
+                request_id=f"aphrodite-batch-{random_uuid()}"),
+            error=chat_response,
+        )
+    else:
+        raise ValueError("Request must not be sent in stream mode")
+
+    return batch_output
+
+
+async def main(args):
+    if args.served_model_name is not None:
+        served_model_names = args.served_model_name
+    else:
+        served_model_names = [args.model]
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncAphrodite.from_engine_args(engine_args)
+
+    # When using a single Aphrodite instance without engine_use_ray.
+    model_config = await engine.get_model_config()
+
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+    openai_serving_chat = OpenAIServingChat(
+        engine,
+        model_config,
+        served_model_names,
+        args.response_role,
+        lora_modules=None,
+        prompt_adapters=None,
+        request_logger=request_logger,
+        chat_template=None,
+    )
+
+    # Submit all requests in the file to the engine "concurrently".
+    response_futures: List[Awaitable[BatchRequestOutput]] = []
+    for request_json in (await read_file(args.input_file)).strip().split("\n"):
+        request = BatchRequestInput.model_validate_json(request_json)
+        response_futures.append(run_request(openai_serving_chat, request))
+
+    responses = await asyncio.gather(*response_futures)
+
+    output_buffer = StringIO()
+    for response in responses:
+        print(response.model_dump_json(), file=output_buffer)
+
+    output_buffer.seek(0)
+    await write_file(args.output_file, output_buffer.read().strip())
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    logger.info(f"Aphrodite API server version {APHRODITE_VERSION}")
+    logger.info(f"args: {args}")
+
+    asyncio.run(main(args))
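
The runner reads one BatchRequestInput JSON object per line of the input file and writes one BatchRequestOutput per line of the output file. A sketch of producing a valid input line and invoking the runner; the model name, message, and file names are placeholders, and the module invocation is assumed from the file path above:

from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
                                                 ChatCompletionRequest)

line = BatchRequestInput(
    custom_id="request-1",
    method="POST",
    url="/v1/chat/completions",
    body=ChatCompletionRequest(
        model="NousResearch/Meta-Llama-3-8B-Instruct",
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=32,
    ),
)

with open("batch_input.jsonl", "w", encoding="utf-8") as f:
    f.write(line.model_dump_json() + "\n")

# Then run, e.g.:
#   python -m aphrodite.endpoints.openai.run_batch \
#       -i batch_input.jsonl -o batch_output.jsonl \
#       --model NousResearch/Meta-Llama-3-8B-Instruct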

+ 331 - 125
aphrodite/endpoints/openai/serving_chat.py

@@ -1,45 +1,66 @@
-import codecs
 import time
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Union
 
 from fastapi import Request
 from loguru import logger
+from transformers import PreTrainedTokenizer
 
+from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import RequestOutput
+from aphrodite.common.sequence import Logprob
 from aphrodite.common.utils import random_uuid
+from aphrodite.endpoints.chat_utils import (ConversationMessage,
+                                            load_chat_template,
+                                            parse_chat_messages)
+from aphrodite.endpoints.logger import RequestLogger
 from aphrodite.endpoints.openai.protocol import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
-    ErrorResponse,
-    UsageInfo,
-)
-from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
-from aphrodite.modeling.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    ChatCompletionLogProb, ChatCompletionLogProbs,
+    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    FunctionCall, ToolCall, UsageInfo)
+from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
+                                                       OpenAIServing,
+                                                       PromptAdapterPath)
+from aphrodite.engine.protocol import AsyncEngineClient
+from aphrodite.inputs import PromptInputs
+from aphrodite.multimodal import MultiModalDataDict
 
 
 class OpenAIServingChat(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncAphrodite,
-                 served_model: str,
-                 response_role: str,
-                 lora_modules: Optional[List[LoRA]] = None,
-                 chat_template=None):
-        super().__init__(engine=engine,
-                         served_model=served_model,
-                         lora_modules=lora_modules)
+    def __init__(
+        self,
+        async_engine_client: AsyncEngineClient,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+        return_tokens_as_token_ids: bool = False,
+    ):
+        super().__init__(async_engine_client=async_engine_client,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids)
+
         self.response_role = response_role
-        self._load_chat_template(chat_template)
+        # If this is None, we use the tokenizer's default chat template.
+        self.chat_template = load_chat_template(chat_template)
 
     async def create_chat_completion(
-        self, request: ChatCompletionRequest, raw_request: Request
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request] = None
     ) -> Union[ErrorResponse, AsyncGenerator[str, None],
                ChatCompletionResponse]:
         """Completion API similar to OpenAI's API.
@@ -56,44 +77,98 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
-            prompt = self.tokenizer.apply_chat_template(
-                conversation=request.messages,
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            model_config = self.model_config
+            tokenizer = await self.async_engine_client.get_tokenizer(
+                lora_request)
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
+
+            tool_dicts = None if request.tools is None else [
+                tool.model_dump() for tool in request.tools
+            ]
+
+            prompt = tokenizer.apply_chat_template(
+                conversation=conversation,
                 tokenize=False,
-                add_generation_prompt=request.add_generation_prompt)
+                add_generation_prompt=request.add_generation_prompt,
+                tools=tool_dicts,
+                documents=request.documents,
+                chat_template=request.chat_template or self.chat_template,
+                **(request.chat_template_kwargs or {}),
+            )
+            assert isinstance(prompt, str)
         except Exception as e:
-            logger.error(
-                f"Error in applying chat template from request: {str(e)}")
+            logger.error(f"Error in applying chat template from request: {e}")
             return self.create_error_response(str(e))
 
-        request_id = f"cmpl-{random_uuid()}"
+        mm_data: Optional[MultiModalDataDict] = None
+        try:
+            if len(mm_futures):
+                # Only a single multi-modal data item is currently supported.
+                assert len(
+                    mm_futures
+                ) == 1, "Multiple 'image_url' inputs are currently not supported."
+                mm_data = await mm_futures[0]
+        except Exception as e:
+            logger.error(f"Error in loading multi-modal data: {e}")
+            return self.create_error_response(str(e))
+
+        request_id = f"chat-{random_uuid()}"
         try:
-            # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
-                request, prompt=prompt)
-            sampling_params = request.to_sampling_params(
-                self.tokenizer.vocab_size)
-            lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
-                await get_guided_decoding_logits_processor(
-                    request.guided_decoding_backend, request, await
-                    self.engine.get_tokenizer()))
-            if guided_decode_logits_processor:
-                sampling_params.logits_processors.append(
-                    guided_decode_logits_processor)
+                await self._guided_decode_logits_processor(request, tokenizer))
+
+            prompt_inputs = self._tokenize_prompt_input(
+                request,
+                tokenizer,
+                prompt,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+
+            sampling_params = request.to_sampling_params(
+                tokenizer,
+                guided_decode_logits_processor,
+                default_max_tokens=self.max_model_len -
+                len(prompt_inputs["prompt_token_ids"]))
+
+            self._log_inputs(request_id,
+                             prompt_inputs,
+                             params=sampling_params,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
+
+            engine_inputs: PromptInputs = {
+                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
+            }
+            if mm_data is not None:
+                engine_inputs["multi_modal_data"] = mm_data
+
+            result_generator = self.async_engine_client.generate(
+                engine_inputs,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
         except ValueError as e:
+            # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
 
-        result_generator = self.engine.generate(prompt_text, sampling_params,
-                                                request_id, prompt_ids,
-                                                lora_request)
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, conversation, tokenizer)
         else:
             try:
                 return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id)
+                    request, raw_request, result_generator, request_id,
+                    conversation, tokenizer)
             except ValueError as e:
                 # TODO: Use an aphrodite-specific Validation Error
                 return self.create_error_response(str(e))
@@ -105,22 +180,25 @@ class OpenAIServingChat(OpenAIServing):
             return request.messages[-1]["role"]
 
     async def chat_completion_stream_generator(
-            self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
-    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
-
-        model_name = request.model
+        self,
+        request: ChatCompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
+    ) -> AsyncGenerator[str, None]:
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
         first_iteration = True
 
         # Send response for each token for each request.n (index)
-        previous_texts = [""] * request.n
-        previous_num_tokens = [0] * request.n
-        finish_reason_sent = [False] * request.n
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices
+        previous_num_tokens = [0] * num_choices
+        finish_reason_sent = [False] * num_choices
         try:
             async for res in result_generator:
-                res: RequestOutput
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
@@ -128,7 +206,7 @@ class OpenAIServingChat(OpenAIServing):
                     # Send first response for each request.n (index) with
                     # the role
                     role = self.get_chat_request_role(request)
-                    for i in range(request.n):
+                    for i in range(num_choices):
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=DeltaMessage(role=role),
@@ -140,6 +218,16 @@ class OpenAIServingChat(OpenAIServing):
                             created=created_time,
                             choices=[choice_data],
                             model=model_name)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            if (request.stream_options.continuous_usage_stats):
+                                prompt_tokens = len(res.prompt_token_ids)
+                                usage = UsageInfo(prompt_tokens=prompt_tokens,
+                                                  completion_tokens=0,
+                                                  total_tokens=prompt_tokens)
+                                chunk.usage = usage
+                            else:
+                                chunk.usage = None
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
 
@@ -147,28 +235,39 @@ class OpenAIServingChat(OpenAIServing):
                     # last message
                     if request.echo:
                         last_msg_content = ""
-                        if request.messages and isinstance(
-                                request.messages,
-                                list) and request.messages[-1].get(
-                                    "content") and request.messages[-1].get(
-                                        "role") == role:
-                            last_msg_content = request.messages[-1]["content"]
+                        if conversation and conversation[-1].get(
+                                "content") and conversation[-1].get(
+                                    "role") == role:
+                            last_msg_content = conversation[-1]["content"]
 
                         if last_msg_content:
-                            for i in range(request.n):
+                            for i in range(num_choices):
                                 choice_data = (
                                     ChatCompletionResponseStreamChoice(
                                         index=i,
                                         delta=DeltaMessage(
                                             content=last_msg_content),
+                                        logprobs=None,
                                         finish_reason=None))
                                 chunk = ChatCompletionStreamResponse(
                                     id=request_id,
                                     object=chunk_object_type,
                                     created=created_time,
                                     choices=[choice_data],
-                                    logprobs=None,
                                     model=model_name)
+                                if (request.stream_options and
+                                        request.stream_options.include_usage):
+                                    if (request.stream_options.
+                                            continuous_usage_stats):
+                                        prompt_tokens = len(
+                                            res.prompt_token_ids)
+                                        usage = UsageInfo(
+                                            prompt_tokens=prompt_tokens,
+                                            completion_tokens=0,
+                                            total_tokens=prompt_tokens)
+                                        chunk.usage = usage
+                                    else:
+                                        chunk.usage = None
                                 data = chunk.model_dump_json(
                                     exclude_unset=True)
                                 yield f"data: {data}\n\n"
@@ -181,15 +280,17 @@ class OpenAIServingChat(OpenAIServing):
                         continue
 
                     delta_token_ids = output.token_ids[previous_num_tokens[i]:]
-                    top_logprobs = output.logprobs[
+                    out_logprobs = output.logprobs[
                         previous_num_tokens[i]:] if output.logprobs else None
 
-                    if request.logprobs:
-                        logprobs = self._create_logprobs(
+                    if request.logprobs and request.top_logprobs is not None:
+                        assert out_logprobs is not None, (
+                            "Did not output logprobs")
+                        logprobs = self._create_chat_logprobs(
                             token_ids=delta_token_ids,
-                            top_logprobs=top_logprobs,
-                            num_output_top_logprobs=request.logprobs,
-                            initial_text_offset=len(previous_texts[i]),
+                            top_logprobs=out_logprobs,
+                            tokenizer=tokenizer,
+                            num_output_top_logprobs=request.top_logprobs,
                         )
                     else:
                         logprobs = None
@@ -197,11 +298,24 @@ class OpenAIServingChat(OpenAIServing):
                     delta_text = output.text[len(previous_texts[i]):]
                     previous_texts[i] = output.text
                     previous_num_tokens[i] = len(output.token_ids)
+
+                    if request.tool_choice and type(
+                            request.tool_choice
+                    ) is ChatCompletionNamedToolChoiceParam:
+                        delta_message = DeltaMessage(tool_calls=[
+                            ToolCall(function=FunctionCall(
+                                name=request.tool_choice.function.name,
+                                arguments=delta_text))
+                        ])
+                    else:
+                        delta_message = DeltaMessage(content=delta_text)
+
                     if output.finish_reason is None:
                         # Send token-by-token response for each request.n
+
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
-                            delta=DeltaMessage(content=delta_text),
+                            delta=delta_message,
                             logprobs=logprobs,
                             finish_reason=None)
                         chunk = ChatCompletionStreamResponse(
@@ -210,20 +324,28 @@ class OpenAIServingChat(OpenAIServing):
                             created=created_time,
                             choices=[choice_data],
                             model=model_name)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            if (request.stream_options.continuous_usage_stats):
+                                prompt_tokens = len(res.prompt_token_ids)
+                                completion_tokens = len(output.token_ids)
+                                usage = UsageInfo(
+                                    prompt_tokens=prompt_tokens,
+                                    completion_tokens=completion_tokens,
+                                    total_tokens=prompt_tokens +
+                                    completion_tokens,
+                                )
+                                chunk.usage = usage
+                            else:
+                                chunk.usage = None
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
                     else:
                         # Send the finish response for each request.n only once
                         prompt_tokens = len(res.prompt_token_ids)
-                        final_usage = UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=previous_num_tokens[i],
-                            total_tokens=prompt_tokens +
-                            previous_num_tokens[i],
-                        )
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
-                            delta=DeltaMessage(content=delta_text),
+                            delta=delta_message,
                             logprobs=logprobs,
                             finish_reason=output.finish_reason,
                             stop_reason=output.stop_reason)
@@ -233,12 +355,43 @@ class OpenAIServingChat(OpenAIServing):
                             created=created_time,
                             choices=[choice_data],
                             model=model_name)
-                        if final_usage is not None:
-                            chunk.usage = final_usage
-                        data = chunk.model_dump_json(exclude_unset=True,
-                                                     exclude_none=True)
+                        if (request.stream_options
+                                and request.stream_options.include_usage):
+                            if (request.stream_options.continuous_usage_stats):
+                                prompt_tokens = len(res.prompt_token_ids)
+                                completion_tokens = len(output.token_ids)
+                                usage = UsageInfo(
+                                    prompt_tokens=prompt_tokens,
+                                    completion_tokens=completion_tokens,
+                                    total_tokens=prompt_tokens +
+                                    completion_tokens,
+                                )
+                                chunk.usage = usage
+                            else:
+                                chunk.usage = None
+                        data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
                         finish_reason_sent[i] = True
+
+            if (request.stream_options
+                    and request.stream_options.include_usage):
+                final_usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=previous_num_tokens[i],
+                    total_tokens=prompt_tokens + previous_num_tokens[i],
+                )
+
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=request_id,
+                    object=chunk_object_type,
+                    created=created_time,
+                    choices=[],
+                    model=model_name,
+                    usage=final_usage)
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
+
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             data = self.create_streaming_error_response(str(e))
@@ -247,54 +400,71 @@ class OpenAIServingChat(OpenAIServing):
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
-            self, request: ChatCompletionRequest, raw_request: Request,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
-
-        model_name = request.model
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request],
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
+    ) -> Union[ErrorResponse, ChatCompletionResponse]:
+
+        model_name = self.served_model_names[0]
         created_time = int(time.time())
-        final_res: RequestOutput = None
+        final_res: Optional[RequestOutput] = None
 
         async for res in result_generator:
-            if await raw_request.is_disconnected():
+            if raw_request is not None and await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
+                await self.async_engine_client.abort(request_id)
                 return self.create_error_response("Client disconnected")
             final_res = res
         assert final_res is not None
 
-        choices = []
+        choices: List[ChatCompletionResponseChoice] = []
 
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
             token_ids = output.token_ids
-            top_logprobs = output.logprobs
+            out_logprobs = output.logprobs
 
-            if request.logprobs:
-                logprobs = self._create_logprobs(
+            if request.logprobs and request.top_logprobs is not None:
+                assert out_logprobs is not None, "Did not output logprobs"
+                logprobs = self._create_chat_logprobs(
                     token_ids=token_ids,
-                    top_logprobs=top_logprobs,
-                    num_output_top_logprobs=request.logprobs,
+                    top_logprobs=out_logprobs,
+                    tokenizer=tokenizer,
+                    num_output_top_logprobs=request.top_logprobs,
                 )
             else:
                 logprobs = None
 
+            if request.tool_choice and type(
+                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
+                message = ChatMessage(
+                    role=role,
+                    content="",
+                    tool_calls=[
+                        ToolCall(function=FunctionCall(
+                            name=request.tool_choice.function.name,
+                            arguments=output.text))
+                    ])
+            elif not request.tool_choice or request.tool_choice == "none":
+                message = ChatMessage(role=role, content=output.text)
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
-                message=ChatMessage(role=role, content=output.text),
+                message=message,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
-                stop_reason=output.stop_reason,
-            )
+                stop_reason=output.stop_reason)
             choices.append(choice_data)
 
         if request.echo:
             last_msg_content = ""
-            if request.messages and isinstance(
-                    request.messages, list) and request.messages[-1].get(
-                        "content") and request.messages[-1].get(
-                            "role") == role:
-                last_msg_content = request.messages[-1]["content"]
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == role:
+                last_msg_content = conversation[-1]["content"]
 
             for choice in choices:
                 full_message = last_msg_content + choice.message.content
@@ -318,20 +488,56 @@ class OpenAIServingChat(OpenAIServing):
 
         return response
 
-    def _load_chat_template(self, chat_template):
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    self.tokenizer.chat_template = f.read()
-            except OSError:
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                self.tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using the supplied chat template.")
-        elif self.tokenizer.chat_template is not None:
-            logger.info("Using the default chat template")
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
+    def _get_top_logprobs(
+            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
+            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
+        return [
+            ChatCompletionLogProb(token=(token := self._get_decoded_token(
+                p[1],
+                p[0],
+                tokenizer,
+                return_as_token_id=self.return_tokens_as_token_ids)),
+                                  logprob=max(p[1].logprob, -9999.0),
+                                  bytes=list(
+                                      token.encode("utf-8", errors="replace")))
+            for i, p in enumerate(logprobs.items())
+            if top_logprobs and i < top_logprobs
+        ]
+
+    def _create_chat_logprobs(
+        self,
+        token_ids: GenericSequence[int],
+        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        tokenizer: PreTrainedTokenizer,
+        num_output_top_logprobs: Optional[int] = None,
+    ) -> ChatCompletionLogProbs:
+        """Create OpenAI-style logprobs."""
+
+        logprobs_content = []
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is None:
+                token = tokenizer.decode(token_id)
+                if self.return_tokens_as_token_ids:
+                    token = f"token_id:{token_id}"
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=token,
+                        bytes=list(token.encode("utf-8", errors="replace"))))
+            else:
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=self._get_decoded_token(
+                            step_top_logprobs[token_id], token_id, tokenizer,
+                            self.return_tokens_as_token_ids),
+                        logprob=max(step_top_logprobs[token_id].logprob,
+                                    -9999.0),
+                        bytes=list(
+                            step_top_logprobs[token_id].decoded_token.encode(
+                                "utf-8", errors="replace")),
+                        top_logprobs=self._get_top_logprobs(
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer)))
+
+        return ChatCompletionLogProbs(content=logprobs_content)
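For reference, a minimal client sketch of the streaming behaviour added above (per-chunk usage reporting via stream_options, and chat-style logprobs/top_logprobs). The base URL, port, model name, and the /v1/chat/completions route are placeholders and assumptions, not values taken from this diff:

import json

import requests

resp = requests.post(
    "http://localhost:2242/v1/chat/completions",  # placeholder host/port
    json={
        "model": "my-model",  # placeholder model name
        "messages": [{"role": "user", "content": "Say hello."}],
        "stream": True,
        # include_usage appends a final usage-only chunk; the
        # continuous_usage_stats flag handled above attaches usage to every
        # chunk instead.
        "stream_options": {"include_usage": True,
                           "continuous_usage_stats": True},
        "logprobs": True,
        "top_logprobs": 2,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    delta = chunk["choices"][0]["delta"] if chunk["choices"] else None
    # "usage" may now appear on intermediate chunks, not only at the end.
    print(delta, chunk.get("usage"))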

aphrodite/endpoints/openai/serving_completions.py  (+236, -132)

@@ -1,70 +1,55 @@
 import time
-from typing import (
-    AsyncGenerator,
-    AsyncIterator,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
+                    Optional)
+from typing import Sequence as GenericSequence
+from typing import Tuple, cast
 
 from fastapi import Request
+from transformers import PreTrainedTokenizer
 
+from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import RequestOutput
+from aphrodite.common.sequence import Logprob
 from aphrodite.common.utils import merge_async_iterators, random_uuid
+from aphrodite.endpoints.logger import RequestLogger
+# yapf conflicts with isort for this block
+# yapf: disable
 from aphrodite.endpoints.openai.protocol import (
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    LogProbs,
-    UsageInfo,
-)
-from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
-from aphrodite.modeling.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    CompletionLogProbs, CompletionRequest, CompletionResponse,
+    CompletionResponseChoice, CompletionResponseStreamChoice,
+    CompletionStreamResponse, UsageInfo)
+# yapf: enable
+from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
+                                                       OpenAIServing,
+                                                       PromptAdapterPath)
+from aphrodite.engine.protocol import AsyncEngineClient
 
 TypeTokenIDs = List[int]
 TypeTopLogProbs = List[Optional[Dict[int, float]]]
 TypeCreateLogProbsFn = Callable[
-    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
-
-
-def parse_prompt_format(prompt) -> Tuple[bool, list]:
-    # get the prompt, openai supports the following
-    # "a string, array of strings, array of tokens, or array of token arrays."
-    prompt_is_tokens = False
-    prompts = [prompt]  # case 1: a string
-    if isinstance(prompt, list):
-        if len(prompt) == 0:
-            raise ValueError("please provide at least one prompt")
-        elif isinstance(prompt[0], str):
-            prompt_is_tokens = False
-            prompts = prompt  # case 2: array of strings
-        elif isinstance(prompt[0], int):
-            prompt_is_tokens = True
-            prompts = [prompt]  # case 3: array of tokens
-        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
-            prompt_is_tokens = True
-            prompts = prompt  # case 4: array of token arrays
-        else:
-            raise ValueError("prompt must be a string, array of strings, "
-                             "array of tokens, or array of token arrays")
-    return prompt_is_tokens, prompts
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
 
 
 class OpenAIServingCompletion(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncAphrodite,
-                 served_model: str,
-                 lora_modules: Optional[List[LoRA]] = None):
-        super().__init__(engine=engine,
-                         served_model=served_model,
-                         lora_modules=lora_modules)
+    def __init__(
+        self,
+        async_engine_client: AsyncEngineClient,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
+    ):
+        super().__init__(async_engine_client=async_engine_client,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger,
+                         return_tokens_as_token_ids=return_tokens_as_token_ids)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -86,51 +71,58 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "suffix is not currently supported")
 
-        model_name = request.model
+        model_name = self.served_model_names[0]
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
         # Schedule the request and get the result generator.
-        generators = []
+        generators: List[AsyncIterator[RequestOutput]] = []
         try:
-            sampling_params = request.to_sampling_params(
-                self.tokenizer.vocab_size)
-            lora_request = self._maybe_get_lora(request)
-            decoding_config = self.engine.engine.decoding_config
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
-            guided_decode_logit_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request, await
-                    self.engine.get_tokenizer()))
-            if guided_decode_logit_processor is not None:
-                sampling_params.logits_processors.append(
-                    guided_decode_logit_processor)
-            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
-
-            for i, prompt in enumerate(prompts):
-                if prompt_is_tokens:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt_ids=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
-                else:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
-                prompt_ids, prompt_text = prompt_formats
-
-                generators.append(
-                    self.engine.generate(prompt_text,
-                                         sampling_params,
-                                         f"{request_id}-{i}",
-                                         prompt_token_ids=prompt_ids,
-                                         lora_request=lora_request))
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            tokenizer = await self.async_engine_client.get_tokenizer(
+                lora_request)
+
+            guided_decode_logits_processor = (
+                await self._guided_decode_logits_processor(request, tokenizer))
+            prompts = list(
+                self._tokenize_prompt_input_or_inputs(
+                    request,
+                    tokenizer,
+                    request.prompt,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                ))
+
+            for i, prompt_inputs in enumerate(prompts):
+                sampling_params = request.to_sampling_params(
+                    tokenizer,
+                    guided_decode_logits_processor,
+                    default_max_tokens=self.max_model_len -
+                    len(prompt_inputs["prompt_token_ids"]))
+
+                request_id_item = f"{request_id}-{i}"
+
+                self._log_inputs(request_id_item,
+                                 prompt_inputs,
+                                 params=sampling_params,
+                                 lora_request=lora_request,
+                                 prompt_adapter_request=prompt_adapter_request)
+
+                generator = self.async_engine_client.generate(
+                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
+                    sampling_params,
+                    request_id_item,
+                    lora_request=lora_request,
+                    prompt_adapter_request=prompt_adapter_request,
+                )
+
+                generators.append(generator)
         except ValueError as e:
-            # TODO: Use a specific-specific Validation Error
+            # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
 
         result_generator: AsyncIterator[Tuple[
@@ -151,21 +143,40 @@ class OpenAIServingCompletion(OpenAIServing):
                                                     request_id,
                                                     created_time,
                                                     model_name,
-                                                    num_prompts=len(prompts))
+                                                    num_prompts=len(prompts),
+                                                    tokenizer=tokenizer)
 
         # Non-streaming response
-        final_res_batch: RequestOutput = [None] * len(prompts)
+        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
                 if await raw_request.is_disconnected():
                     # Abort the request if the client disconnects.
-                    await self.engine.abort(f"{request_id}-{i}")
+                    await self.async_engine_client.abort(f"{request_id}-{i}")
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
+
+            for i, final_res in enumerate(final_res_batch):
+                assert final_res is not None
+
+                # The output should contain the input text. We did not pass it
+                # into the Aphrodite engine to avoid duplicating the input
+                # token IDs.
+                if final_res.prompt is None:
+                    final_res.prompt = prompts[i]["prompt"]
+
+            final_res_batch_checked = cast(List[RequestOutput],
+                                           final_res_batch)
             response = self.request_output_to_completion_response(
-                final_res_batch, request, request_id, created_time, model_name)
+                final_res_batch_checked,
+                request,
+                request_id,
+                created_time,
+                model_name,
+                tokenizer,
+            )
         except ValueError as e:
-            # TODO: Use a aphrodite-specific Validation Error
+            # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
 
         # When user requests streaming but we don't stream, we still need to
@@ -190,29 +201,33 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
         num_prompts: int,
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
-        previous_texts = [""] * request.n * num_prompts
-        previous_num_tokens = [0] * request.n * num_prompts
-        has_echoed = [False] * request.n * num_prompts
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices * num_prompts
+        previous_num_tokens = [0] * num_choices * num_prompts
+        has_echoed = [False] * num_choices * num_prompts
 
         try:
             async for prompt_idx, res in result_generator:
 
                 # Abort the request if the client disconnects.
                 if await raw_request.is_disconnected():
-                    await self.engine.abort(f"{request_id}-{prompt_idx}")
+                    await self.async_engine_client.abort(
+                        f"{request_id}-{prompt_idx}")
                     raise StopAsyncIteration()
 
                 for output in res.outputs:
-                    i = output.index + prompt_idx * request.n
+                    i = output.index + prompt_idx * num_choices
                     # TODO: optimize the performance by avoiding full
                     # text O(n^2) sending.
 
+                    assert request.max_tokens is not None
                     if request.echo and request.max_tokens == 0:
                         # only return the prompt
                         delta_text = res.prompt
                         delta_token_ids = res.prompt_token_ids
-                        top_logprobs = res.prompt_logprobs
+                        out_logprobs = res.prompt_logprobs
                         has_echoed[i] = True
                     elif (request.echo and request.max_tokens > 0
                           and not has_echoed[i]):
@@ -220,7 +235,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_text = res.prompt + output.text
                         delta_token_ids = (res.prompt_token_ids +
                                            output.token_ids)
-                        top_logprobs = res.prompt_logprobs + (output.logprobs
+                        out_logprobs = res.prompt_logprobs + (output.logprobs
                                                               or [])
                         has_echoed[i] = True
                     else:
@@ -228,14 +243,17 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_text = output.text[len(previous_texts[i]):]
                         delta_token_ids = output.token_ids[
                             previous_num_tokens[i]:]
-                        top_logprobs = output.logprobs[previous_num_tokens[
+                        out_logprobs = output.logprobs[previous_num_tokens[
                             i]:] if output.logprobs else None
 
                     if request.logprobs is not None:
-                        logprobs = self._create_logprobs(
+                        assert out_logprobs is not None, (
+                            "Did not output logprobs")
+                        logprobs = self._create_completion_logprobs(
                             token_ids=delta_token_ids,
-                            top_logprobs=top_logprobs,
+                            top_logprobs=out_logprobs,
                             num_output_top_logprobs=request.logprobs,
+                            tokenizer=tokenizer,
                             initial_text_offset=len(previous_texts[i]),
                         )
                     else:
@@ -245,17 +263,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     previous_num_tokens[i] = len(output.token_ids)
                     finish_reason = output.finish_reason
                     stop_reason = output.stop_reason
-                    if output.finish_reason is not None:  # return final usage
-                        prompt_tokens = len(res.prompt_token_ids)
-                        completion_tokens = len(output.token_ids)
-                        final_usage = UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        )
-                    else:
-                        final_usage = None
-                    response_json = CompletionStreamResponse(
+
+                    chunk = CompletionStreamResponse(
                         id=request_id,
                         created=created_time,
                         model=model_name,
@@ -267,10 +276,39 @@ class OpenAIServingCompletion(OpenAIServing):
                                 finish_reason=finish_reason,
                                 stop_reason=stop_reason,
                             )
-                        ],
-                        usage=final_usage,
-                    ).model_dump_json(exclude_unset=True)
+                        ])
+                    if (request.stream_options
+                            and request.stream_options.include_usage):
+                        if (request.stream_options.continuous_usage_stats
+                                or output.finish_reason is not None):
+                            prompt_tokens = len(res.prompt_token_ids)
+                            completion_tokens = len(output.token_ids)
+                            usage = UsageInfo(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=prompt_tokens + completion_tokens,
+                            )
+                        if request.stream_options.continuous_usage_stats:
+                            chunk.usage = usage
+                        else:
+                            chunk.usage = None
+
+                    response_json = chunk.model_dump_json(exclude_unset=False)
                     yield f"data: {response_json}\n\n"
+
+            if (request.stream_options
+                    and request.stream_options.include_usage):
+                final_usage_chunk = CompletionStreamResponse(
+                    id=request_id,
+                    created=created_time,
+                    model=model_name,
+                    choices=[],
+                    usage=usage,
+                )
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=False, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
+
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             data = self.create_streaming_error_response(str(e))
@@ -284,38 +322,38 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id: str,
         created_time: int,
         model_name: str,
+        tokenizer: PreTrainedTokenizer,
     ) -> CompletionResponse:
-        choices = []
+        choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
         for final_res in final_res_batch:
-            assert final_res is not None
             prompt_token_ids = final_res.prompt_token_ids
             prompt_logprobs = final_res.prompt_logprobs
             prompt_text = final_res.prompt
 
             for output in final_res.outputs:
+                assert request.max_tokens is not None
                 if request.echo and request.max_tokens == 0:
                     token_ids = prompt_token_ids
-                    top_logprobs = prompt_logprobs
+                    out_logprobs = prompt_logprobs
                     output_text = prompt_text
                 elif request.echo and request.max_tokens > 0:
-                    token_ids = prompt_token_ids + output.token_ids
-                    top_logprobs = (prompt_logprobs + output.logprobs
-                                    if request.logprobs else None)
+                    token_ids = prompt_token_ids + list(output.token_ids)
+                    out_logprobs = (prompt_logprobs + output.logprobs
+                                    if request.logprobs is not None else None)
                     output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
-                    top_logprobs = output.logprobs
+                    out_logprobs = output.logprobs
                     output_text = output.text
 
                 if request.logprobs is not None:
-                    assert top_logprobs is not None, (
-                        "top_logprobs must be provided when logprobs "
-                        "is requested")
-                    logprobs = self._create_logprobs(
+                    assert out_logprobs is not None, "Did not output logprobs"
+                    logprobs = self._create_completion_logprobs(
                         token_ids=token_ids,
-                        top_logprobs=top_logprobs,
+                        top_logprobs=out_logprobs,
+                        tokenizer=tokenizer,
                         num_output_top_logprobs=request.logprobs,
                     )
                 else:
@@ -347,3 +385,69 @@ class OpenAIServingCompletion(OpenAIServing):
             choices=choices,
             usage=usage,
         )
+
+    def _create_completion_logprobs(
+        self,
+        token_ids: GenericSequence[int],
+        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        num_output_top_logprobs: int,
+        tokenizer: PreTrainedTokenizer,
+        initial_text_offset: int = 0,
+    ) -> CompletionLogProbs:
+        """Create logprobs for OpenAI Completion API."""
+        out_text_offset: List[int] = []
+        out_token_logprobs: List[Optional[float]] = []
+        out_tokens: List[str] = []
+        out_top_logprobs: List[Optional[Dict[str, float]]] = []
+
+        last_token_len = 0
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is None:
+                token = tokenizer.decode(token_id)
+                if self.return_tokens_as_token_ids:
+                    token = f"token_id:{token_id}"
+                out_tokens.append(token)
+                out_token_logprobs.append(None)
+                out_top_logprobs.append(None)
+            else:
+                token = self._get_decoded_token(
+                    step_top_logprobs[token_id],
+                    token_id,
+                    tokenizer,
+                    return_as_token_id=self.return_tokens_as_token_ids)
+                token_logprob = max(step_top_logprobs[token_id].logprob,
+                                    -9999.0)
+                out_tokens.append(token)
+                out_token_logprobs.append(token_logprob)
+
+                # Make sure to add the top num_output_top_logprobs + 1
+                # logprobs, as defined in the OpenAI API
+                # (cf. https://github.com/openai/openai-openapi/blob/
+                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
+                out_top_logprobs.append({
+                    # Convert float("-inf") to the
+                    # JSON-serializable float that OpenAI uses
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=self.return_tokens_as_token_ids):
+                    max(top_lp[1].logprob, -9999.0)
+                    for i, top_lp in enumerate(step_top_logprobs.items())
+                    if num_output_top_logprobs >= i
+                })
+
+            if len(out_text_offset) == 0:
+                out_text_offset.append(initial_text_offset)
+            else:
+                out_text_offset.append(out_text_offset[-1] + last_token_len)
+            last_token_len = len(token)
+
+        return CompletionLogProbs(
+            text_offset=out_text_offset,
+            token_logprobs=out_token_logprobs,
+            tokens=out_tokens,
+            top_logprobs=out_top_logprobs,
+        )
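The text_offset bookkeeping in _create_completion_logprobs above is the easiest part to get wrong, so here is a standalone sketch of the same accumulation with made-up token strings (not tied to any tokenizer):

def text_offsets(tokens, initial_text_offset=0):
    # Each token's offset is the previous offset plus the length of the
    # previous token's text, mirroring the loop above.
    offsets = []
    last_token_len = 0
    for token in tokens:
        if not offsets:
            offsets.append(initial_text_offset)
        else:
            offsets.append(offsets[-1] + last_token_len)
        last_token_len = len(token)
    return offsets

assert text_offsets(["Hello", ",", " world"]) == [0, 5, 6]
assert text_offsets(["a", "bc"], initial_text_offset=10) == [10, 11]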

aphrodite/endpoints/openai/serving_embedding.py  (+171, -0)

@@ -0,0 +1,171 @@
+import base64
+import time
+from typing import AsyncIterator, List, Optional, Tuple, cast
+
+import numpy as np
+from fastapi import Request
+from loguru import logger
+
+from aphrodite.common.config import ModelConfig
+from aphrodite.common.outputs import EmbeddingRequestOutput
+from aphrodite.common.utils import merge_async_iterators, random_uuid
+from aphrodite.endpoints.logger import RequestLogger
+from aphrodite.endpoints.openai.protocol import (EmbeddingRequest,
+                                                 EmbeddingResponse,
+                                                 EmbeddingResponseData,
+                                                 UsageInfo)
+from aphrodite.endpoints.openai.serving_engine import OpenAIServing
+from aphrodite.engine.protocol import AsyncEngineClient
+
+TypeTokenIDs = List[int]
+
+
+def request_output_to_embedding_response(
+        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        created_time: int, model_name: str,
+        encoding_format: str) -> EmbeddingResponse:
+    data: List[EmbeddingResponseData] = []
+    num_prompt_tokens = 0
+    for idx, final_res in enumerate(final_res_batch):
+        prompt_token_ids = final_res.prompt_token_ids
+        embedding = final_res.outputs.embedding
+        if encoding_format == "base64":
+            embedding_bytes = np.array(embedding).tobytes()
+            embedding = base64.b64encode(embedding_bytes).decode("utf-8")
+        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
+        data.append(embedding_data)
+
+        num_prompt_tokens += len(prompt_token_ids)
+
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        total_tokens=num_prompt_tokens,
+    )
+
+    return EmbeddingResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        data=data,
+        usage=usage,
+    )
+
+
+class OpenAIServingEmbedding(OpenAIServing):
+
+    def __init__(
+        self,
+        async_engine_client: AsyncEngineClient,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        request_logger: Optional[RequestLogger],
+    ):
+        super().__init__(async_engine_client=async_engine_client,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=None,
+                         prompt_adapters=None,
+                         request_logger=request_logger)
+        self._check_embedding_mode(model_config.embedding_mode)
+
+    async def create_embedding(self, request: EmbeddingRequest,
+                               raw_request: Request):
+        """Completion API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/embeddings/create
+        for the API specification. This API mimics the OpenAI Embedding API.
+        """
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        encoding_format = (request.encoding_format
+                           if request.encoding_format else "float")
+        if request.dimensions is not None:
+            return self.create_error_response(
+                "dimensions is currently not supported")
+
+        model_name = request.model
+        request_id = f"embd-{random_uuid()}"
+        created_time = int(time.monotonic())
+
+        # Schedule the request and get the result generator.
+        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
+        try:
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            tokenizer = await self.async_engine_client.get_tokenizer(
+                lora_request)
+            pooling_params = request.to_pooling_params()
+
+            prompts = list(
+                self._tokenize_prompt_input_or_inputs(
+                    request,
+                    tokenizer,
+                    request.input,
+                ))
+
+            for i, prompt_inputs in enumerate(prompts):
+                request_id_item = f"{request_id}-{i}"
+
+                self._log_inputs(request_id_item,
+                                 prompt_inputs,
+                                 params=pooling_params,
+                                 lora_request=lora_request,
+                                 prompt_adapter_request=prompt_adapter_request)
+
+                if prompt_adapter_request is not None:
+                    raise NotImplementedError(
+                        "Prompt adapter is not supported "
+                        "for embedding models")
+
+                generator = self.async_engine_client.encode(
+                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
+                    pooling_params,
+                    request_id_item,
+                    lora_request=lora_request,
+                )
+
+                generators.append(generator)
+        except ValueError as e:
+            # TODO: Use an aphrodite-specific Validation Error
+            return self.create_error_response(str(e))
+
+        result_generator: AsyncIterator[Tuple[
+            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
+
+        # Non-streaming response
+        final_res_batch: List[Optional[EmbeddingRequestOutput]]
+        final_res_batch = [None] * len(prompts)
+        try:
+            async for i, res in result_generator:
+                if await raw_request.is_disconnected():
+                    # Abort the request if the client disconnects.
+                    await self.async_engine_client.abort(f"{request_id}-{i}")
+                    return self.create_error_response("Client disconnected")
+                final_res_batch[i] = res
+
+            for final_res in final_res_batch:
+                assert final_res is not None
+
+            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+                                           final_res_batch)
+            response = request_output_to_embedding_response(
+                final_res_batch_checked, request_id, created_time, model_name,
+                encoding_format)
+        except ValueError as e:
+            # TODO: Use an aphrodite-specific Validation Error
+            return self.create_error_response(str(e))
+
+        return response
+
+    def _check_embedding_mode(self, embedding_mode: bool):
+        if not embedding_mode:
+            logger.warning(
+                "embedding_mode is False. Embedding API will not work.")
+        else:
+            logger.info("Activating the server engine with embedding enabled.")

aphrodite/endpoints/openai/serving_engine.py  (+328, -174)

@@ -1,157 +1,135 @@
-import asyncio
 import json
+import pathlib
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
 
-from loguru import logger
-from pydantic import conint
+from pydantic import Field
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing_extensions import Annotated
 
+from aphrodite.common.config import ModelConfig
+from aphrodite.common.pooling_params import PoolingParams
+from aphrodite.common.sampling_params import (LogitsProcessorFunc,
+                                              SamplingParams)
 from aphrodite.common.sequence import Logprob
-from aphrodite.endpoints.openai.protocol import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    ErrorResponse,
-    LogProbs,
-    ModelCard,
-    ModelList,
-    ModelPermission,
-    Prompt,
-)
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.endpoints.logger import RequestLogger
+# yapf conflicts with isort here
+# yapf: disable
+from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
+                                                 CompletionRequest,
+                                                 DetokenizeRequest,
+                                                 EmbeddingRequest,
+                                                 ErrorResponse, ModelCard,
+                                                 ModelList, ModelPermission,
+                                                 TokenizeChatRequest,
+                                                 TokenizeCompletionRequest,
+                                                 TokenizeRequest)
+# yapf: enable
+from aphrodite.engine.protocol import AsyncEngineClient
+from aphrodite.inputs import parse_and_batch_prompt
 from aphrodite.lora.request import LoRARequest
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
+from aphrodite.modeling.guided_decoding import (
+    get_guided_decoding_logits_processor)
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
 
 
 @dataclass
-class LoRA:
+class PromptAdapterPath:
     name: str
     local_path: str
 
 
+@dataclass
+class LoRAModulePath:
+    name: str
+    path: str
+
+
+AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
+                   EmbeddingRequest, TokenizeRequest]
+
+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+
+
+class TextTokensPrompt(TypedDict):
+    prompt: str
+    prompt_token_ids: List[int]
+
+
 class OpenAIServing:
 
-    def __init__(self,
-                 engine: AsyncAphrodite,
-                 served_model: str,
-                 lora_modules=Optional[List[LoRA]]):
-        self.engine = engine
-        self.served_model = served_model
-        if lora_modules is None:
-            self.lora_requests = []
-        else:
+    def __init__(
+        self,
+        async_engine_client: AsyncEngineClient,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        return_tokens_as_token_ids: bool = False,
+    ):
+        super().__init__()
+
+        self.async_engine_client = async_engine_client
+        self.model_config = model_config
+        self.max_model_len = model_config.max_model_len
+
+        self.served_model_names = served_model_names
+
+        self.lora_requests = []
+        if lora_modules is not None:
             self.lora_requests = [
                 LoRARequest(
                     lora_name=lora.name,
                     lora_int_id=i,
-                    lora_local_path=lora.local_path,
+                    lora_path=lora.path,
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]
 
-        self.max_model_len = 0
-        self.tokenizer = None
-
-        try:
-            event_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            event_loop = None
-
-        if event_loop is not None and event_loop.is_running():
-            # If the current is instanced by Ray Serve,
-            # there is already a running event loop
-            event_loop.create_task(self._post_init())
-        else:
-            # When using single Aphrodite without engine_use_ray
-            asyncio.run(self._post_init())
+        self.prompt_adapter_requests = []
+        if prompt_adapters is not None:
+            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
+                with pathlib.Path(prompt_adapter.local_path,
+                                  "adapter_config.json").open() as f:
+                    adapter_config = json.load(f)
+                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
+                self.prompt_adapter_requests.append(
+                    PromptAdapterRequest(
+                        prompt_adapter_name=prompt_adapter.name,
+                        prompt_adapter_id=i,
+                        prompt_adapter_local_path=prompt_adapter.local_path,
+                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))
 
-    async def _post_init(self):
-        engine_model_config = await self.engine.get_model_config()
-        self.max_model_len = engine_model_config.max_model_len
-
-        # A separate tokenizer to map token IDs to strings.
-        self.tokenizer = get_tokenizer(
-            engine_model_config.tokenizer,
-            tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code,
-            revision=engine_model_config.revision,
-            truncation_side="left")
+        self.request_logger = request_logger
+        self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
-            ModelCard(id=self.served_model,
-                      root=self.served_model,
+            ModelCard(id=served_model_name,
+                      max_model_len=self.max_model_len,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
+            for served_model_name in self.served_model_names
         ]
         lora_cards = [
             ModelCard(id=lora.lora_name,
-                      root=self.served_model,
+                      root=self.served_model_names[0],
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
+        prompt_adapter_cards = [
+            ModelCard(id=prompt_adapter.prompt_adapter_name,
+                      root=self.served_model_names[0],
+                      permission=[ModelPermission()])
+            for prompt_adapter in self.prompt_adapter_requests
+        ]
         model_cards.extend(lora_cards)
+        model_cards.extend(prompt_adapter_cards)
         return ModelList(data=model_cards)
 
-    async def tokenize(self, prompt: Prompt):
-        """Tokenize a given prompt."""
-        tokenized_prompt = self.tokenizer.tokenize(prompt.prompt)
-        token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_prompt)
-        return {"value": len(tokenized_prompt), "ids": token_ids}
-
-    async def detokenize(self, token_ids: List[int]):
-        """Detokenize a given list of token IDs."""
-        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
-        detokenized_text = self.tokenizer.convert_tokens_to_string(tokens)
-        return {"value": detokenized_text}
-
-    def _create_logprobs(
-        self,
-        token_ids: List[int],
-        top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
-        num_output_top_logprobs: Optional[int] = None,
-        initial_text_offset: int = 0,
-    ) -> LogProbs:
-        """Create OpenAI-style logprobs."""
-        logprobs = LogProbs()
-        last_token_len = 0
-        if num_output_top_logprobs:
-            logprobs.top_logprobs = []
-
-        for i, token_id in enumerate(token_ids):
-            step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is None:
-                token = self.tokenizer.decode(token_id)
-                logprobs.tokens.append(token)
-                logprobs.token_logprobs.append(None)
-                logprobs.top_logprobs.append(None)
-            else:
-                token_logprob = step_top_logprobs[token_id].logprob
-                token = step_top_logprobs[token_id].decoded_token
-                logprobs.tokens.append(token)
-                logprobs.token_logprobs.append(token_logprob)
-
-                if num_output_top_logprobs:
-                    logprobs.top_logprobs.append({
-                        p.decoded_token: p.logprob
-                        for i, p in step_top_logprobs.items()
-                    } if step_top_logprobs else None)
-
-            if logprobs.top_logprobs:
-                logprobs.top_logprobs = [{
-                    k: v if v > -1000 else -1000
-                    for k, v in top_logprob.items()
-                } for top_logprob in logprobs.top_logprobs
-                                         if top_logprob is not None
-                                         ]  # noqa: E501
-
-            if len(logprobs.text_offset) == 0:
-                logprobs.text_offset.append(initial_text_offset)
-            else:
-                logprobs.text_offset.append(logprobs.text_offset[-1] +
-                                            last_token_len)
-            last_token_len = len(token)
-        return logprobs
-
     def create_error_response(
             self,
             message: str,
@@ -174,81 +152,257 @@ class OpenAIServing:
         })
         return json_str
 
-    async def _check_model(self, request) -> Optional[ErrorResponse]:
-        if request.model == self.served_model:
-            return
-        if request.model in [lora.lora_name for lora in self.lora_requests]:
-            return
-        return self.create_error_response(
-            message=f"The model `{request.model}` does not exist.",
-            err_type="NotFoundError",
-            status_code=HTTPStatus.NOT_FOUND)
-
-    def add_lora(self, lora: LoRA):
-        if lora.name in [
-                existing_lora.lora_name for existing_lora in self.lora_requests
-        ]:
-            logger.error(f"LoRA with name {lora.name} already exists.")
-            return
-        self.lora_requests.append(
-            LoRARequest(
-                lora_name=lora.name,
-                lora_int_id=len(self.lora_requests) + 1,
-                lora_local_path=lora.local_path,
-            ))
+    async def _guided_decode_logits_processor(
+            self, request: Union[ChatCompletionRequest, CompletionRequest],
+            tokenizer: AnyTokenizer) -> Optional[LogitsProcessorFunc]:
+        decoding_config = await self.async_engine_client.get_decoding_config()
+        guided_decoding_backend = request.guided_decoding_backend \
+            or decoding_config.guided_decoding_backend
+        return await get_guided_decoding_logits_processor(
+            guided_decoding_backend, request, tokenizer)
 
-    def remove_lora(self, lora_name: str):
-        self.lora_requests = [
-            lora for lora in self.lora_requests if lora.lora_name != lora_name
-        ]
+    async def _check_model(
+        self,
+        request: AnyRequest,
+    ) -> Optional[ErrorResponse]:
+        # Only check these if it's not a Tokenize/Detokenize request
+        if not isinstance(request, (TokenizeRequest, DetokenizeRequest)):
+            if request.model in self.served_model_names:
+                return None
+            if request.model in [
+                    lora.lora_name for lora in self.lora_requests
+            ]:
+                return None
+            if request.model in [
+                    prompt_adapter.prompt_adapter_name
+                    for prompt_adapter in self.prompt_adapter_requests
+            ]:
+                return None
+            return self.create_error_response(
+                message=f"The model `{request.model}` does not exist.",
+                err_type="NotFoundError",
+                status_code=HTTPStatus.NOT_FOUND)
 
-    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
-        if request.model == self.served_model:
-            return
+    def _maybe_get_adapters(
+        self, request: AnyRequest
+    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
+            None, PromptAdapterRequest]]:
+        if request.model in self.served_model_names:
+            return None, None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
-                return lora
+                return lora, None
+        for prompt_adapter in self.prompt_adapter_requests:
+            if request.model == prompt_adapter.prompt_adapter_name:
+                return None, prompt_adapter
         # if _check_model has been called earlier, this will be unreachable
-        raise ValueError("The model `{request.model}` does not exist.")
+        raise ValueError(f"The model `{request.model}` does not exist.")
 
-    def _validate_prompt_and_tokenize(
+    def _normalize_prompt_text_to_input(
         self,
-        request: Union[ChatCompletionRequest, CompletionRequest],
-        prompt: Optional[str] = None,
-        prompt_ids: Optional[List[int]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None
-    ) -> Tuple[List[int], str]:
-        if not (prompt or prompt_ids):
-            raise ValueError("Either prompt or prompt_ids should be provided.")
-        if (prompt and prompt_ids):
-            raise ValueError(
-                "Only one of prompt or prompt_ids should be provided.")
-
-        if prompt_ids is None:
-            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
-                "truncation": True,
-                "max_length": truncate_prompt_tokens,
-            }
-            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
-        elif truncate_prompt_tokens is not None:
-            input_ids = prompt_ids[-truncate_prompt_tokens:]
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt: str,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+        add_special_tokens: bool,
+    ) -> TextTokensPrompt:
+        if truncate_prompt_tokens is None:
+            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
         else:
+            encoded = tokenizer(prompt,
+                                add_special_tokens=add_special_tokens,
+                                truncation=True,
+                                max_length=truncate_prompt_tokens)
+
+        input_ids = encoded.input_ids
+
+        input_text = prompt
+
+        return self._validate_input(request, input_ids, input_text)
+
+    def _normalize_prompt_tokens_to_input(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt_ids: List[int],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+    ) -> TextTokensPrompt:
+        if truncate_prompt_tokens is None:
             input_ids = prompt_ids
+        else:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+
+        input_text = tokenizer.decode(input_ids)
 
-        input_text = prompt if prompt is not None else self.tokenizer.decode(
-            prompt_ids)
+        return self._validate_input(request, input_ids, input_text)
+
+    def _validate_input(
+        self,
+        request: AnyRequest,
+        input_ids: List[int],
+        input_text: str,
+    ) -> TextTokensPrompt:
         token_num = len(input_ids)
 
-        if request.max_tokens is None:
-            request.max_tokens = self.max_model_len - token_num
+        # Note: EmbeddingRequest doesn't have max_tokens
+        if isinstance(request, EmbeddingRequest):
+            if token_num > self.max_model_len:
+                raise ValueError(
+                    f"This model's maximum context length is "
+                    f"{self.max_model_len} tokens. However, you requested "
+                    f"{token_num} tokens in the input for embedding "
+                    f"generation. Please reduce the length of the input.")
+            return TextTokensPrompt(prompt=input_text,
+                                    prompt_token_ids=input_ids)
+
+        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
+        # and don't require model context length validation
+        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
+                                DetokenizeRequest)):
+            return TextTokensPrompt(prompt=input_text,
+                                    prompt_token_ids=input_ids)
 
-        if token_num + request.max_tokens > self.max_model_len:
+        if request.max_tokens is None:
+            if token_num >= self.max_model_len:
+                raise ValueError(
+                    f"This model's maximum context length is "
+                    f"{self.max_model_len} tokens. However, you requested "
+                    f"{token_num} tokens in the messages, "
+                    f"Please reduce the length of the messages.")
+        elif token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
                 f"This model's maximum context length is "
                 f"{self.max_model_len} tokens. However, you requested "
                 f"{request.max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
-                f"Please reduce the length of the messages or completion.", )
+                f"Please reduce the length of the messages or completion.")
+
+        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
+
+    def _tokenize_prompt_input(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt_input: Union[str, List[int]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = True,
+    ) -> TextTokensPrompt:
+        """
+        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
+        that assumes a single input.
+        """
+        return next(
+            self._tokenize_prompt_inputs(
+                request,
+                tokenizer,
+                [prompt_input],
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=add_special_tokens,
+            ))
+
+    def _tokenize_prompt_inputs(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt_inputs: Iterable[Union[str, List[int]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = True,
+    ) -> Iterator[TextTokensPrompt]:
+        """
+        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
+        that assumes multiple inputs.
+        """
+        for text in prompt_inputs:
+            if isinstance(text, str):
+                yield self._normalize_prompt_text_to_input(
+                    request,
+                    tokenizer,
+                    prompt=text,
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens,
+                )
+            else:
+                yield self._normalize_prompt_tokens_to_input(
+                    request,
+                    tokenizer,
+                    prompt_ids=text,
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                )
+
+    def _tokenize_prompt_input_or_inputs(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = True,
+    ) -> Iterator[TextTokensPrompt]:
+        """
+        Tokenize/detokenize depending on the input format.
+
+        According to the
+        `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
+        each input can be a string or an array of tokens. Note that each
+        request can pass one or more inputs.
+        """
+        for prompt_input in parse_and_batch_prompt(input_or_inputs):
+            # Although our type checking is based on mypy,
+            # the VSCode Pyright extension should still work properly.
+            # "is True" is required for Pyright to perform type narrowing
+            # See: https://github.com/microsoft/pyright/issues/7672
+            if prompt_input["is_tokens"] is False:
+                yield self._normalize_prompt_text_to_input(
+                    request,
+                    tokenizer,
+                    prompt=prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens,
+                )
+            else:
+                yield self._normalize_prompt_tokens_to_input(
+                    request,
+                    tokenizer,
+                    prompt_ids=prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                )
+
+    def _log_inputs(
+        self,
+        request_id: str,
+        inputs: Union[str, List[int], TextTokensPrompt],
+        params: Optional[Union[SamplingParams, PoolingParams]],
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+    ) -> None:
+        if self.request_logger is None:
+            return
+
+        if isinstance(inputs, str):
+            prompt = inputs
+            prompt_token_ids = None
+        elif isinstance(inputs, list):
+            prompt = None
+            prompt_token_ids = inputs
         else:
-            return input_ids, input_text
+            prompt = inputs["prompt"]
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        self.request_logger.log_inputs(
+            request_id,
+            prompt,
+            prompt_token_ids,
+            params=params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+        )
+
+    @staticmethod
+    def _get_decoded_token(logprob: Logprob,
+                           token_id: int,
+                           tokenizer: AnyTokenizer,
+                           return_as_token_id: bool = False) -> str:
+        if return_as_token_id:
+            return f"token_id:{token_id}"
+
+        if logprob.decoded_token is not None:
+            return logprob.decoded_token
+        return tokenizer.decode(token_id)
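
For illustration only, here is a minimal, self-contained sketch of how
`_tokenize_prompt_input_or_inputs` above treats its four accepted input shapes
(a single string, a list of strings, a single list of token IDs, or a list of
token-ID lists) and normalizes each one into a `TextTokensPrompt`. The toy
tokenizer and the `normalize` helper are hypothetical stand-ins, not code from
this diff:

from typing import Iterator, List, TypedDict, Union

class TextTokensPrompt(TypedDict):
    # Redefined here only so the sketch is self-contained.
    prompt: str
    prompt_token_ids: List[int]

class ToyTokenizer:
    """Whitespace 'tokenizer' for illustration; stands in for AnyTokenizer."""
    vocab = ["<unk>", "hello", "world"]

    def encode(self, text: str) -> List[int]:
        return [self.vocab.index(w) if w in self.vocab else 0
                for w in text.split()]

    def decode(self, ids: List[int]) -> str:
        return " ".join(self.vocab[i] for i in ids)

def normalize(tokenizer: ToyTokenizer,
              inputs: Union[str, List[str], List[int], List[List[int]]]
              ) -> Iterator[TextTokensPrompt]:
    # Simplified mirror of parse_and_batch_prompt plus the two
    # _normalize_prompt_*_to_input helpers shown in the diff above.
    if isinstance(inputs, str):
        batch = [inputs]                    # "text"
    elif inputs and isinstance(inputs[0], int):
        batch = [inputs]                    # [1, 2, 3] -> one token-ID list
    else:
        batch = list(inputs)                # ["a", "b"] or [[1], [2, 3]]
    for item in batch:
        if isinstance(item, str):
            yield TextTokensPrompt(prompt=item,
                                   prompt_token_ids=tokenizer.encode(item))
        else:
            yield TextTokensPrompt(prompt=tokenizer.decode(item),
                                   prompt_token_ids=list(item))

print(list(normalize(ToyTokenizer(), "hello world")))      # one string
print(list(normalize(ToyTokenizer(), [[1, 2], [2, 1]])))   # two token lists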

+ 136 - 0
aphrodite/endpoints/openai/serving_tokenization.py

@@ -0,0 +1,136 @@
+from typing import List, Optional, Union
+
+from loguru import logger
+
+from aphrodite.common.config import ModelConfig
+from aphrodite.common.utils import random_uuid
+# yapf conflicts with isort
+# yapf: disable
+from aphrodite.endpoints.chat_utils import (load_chat_template,
+                                            parse_chat_messages)
+from aphrodite.endpoints.logger import RequestLogger
+from aphrodite.endpoints.openai.protocol import (DetokenizeRequest,
+                                                 DetokenizeResponse,
+                                                 ErrorResponse,
+                                                 TokenizeChatRequest,
+                                                 TokenizeRequest,
+                                                 TokenizeResponse)
+# yapf: enable
+from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
+                                                       OpenAIServing)
+from aphrodite.engine.protocol import AsyncEngineClient
+
+
+class OpenAIServingTokenization(OpenAIServing):
+
+    def __init__(
+        self,
+        async_engine_client: AsyncEngineClient,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+    ):
+        super().__init__(async_engine_client=async_engine_client,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=lora_modules,
+                         prompt_adapters=None,
+                         request_logger=request_logger)
+
+        # If this is None we use the tokenizer's default chat template
+        self.chat_template = load_chat_template(chat_template)
+
+    async def create_tokenize(
+        self,
+        request: TokenizeRequest,
+    ) -> Union[TokenizeResponse, ErrorResponse]:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        request_id = f"tokn-{random_uuid()}"
+
+        (
+            lora_request,
+            prompt_adapter_request,
+        ) = self._maybe_get_adapters(request)
+
+        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
+
+        if isinstance(request, TokenizeChatRequest):
+            model_config = self.model_config
+
+            conversation, mm_futures = parse_chat_messages(
+                request.messages, model_config, tokenizer)
+
+            if mm_futures:
+                logger.warning(
+                    "Multi-modal inputs are ignored during tokenization")
+
+            prompt = tokenizer.apply_chat_template(
+                add_generation_prompt=request.add_generation_prompt,
+                conversation=conversation,
+                tokenize=False,
+                chat_template=self.chat_template)
+            assert isinstance(prompt, str)
+        else:
+            prompt = request.prompt
+
+        self._log_inputs(request_id,
+                         prompt,
+                         params=None,
+                         lora_request=lora_request,
+                         prompt_adapter_request=prompt_adapter_request)
+
+        # Silently ignore prompt adapter since it does not affect tokenization
+
+        prompt_input = self._tokenize_prompt_input(
+            request,
+            tokenizer,
+            prompt,
+            add_special_tokens=request.add_special_tokens,
+        )
+        input_ids = prompt_input["prompt_token_ids"]
+
+        return TokenizeResponse(tokens=input_ids,
+                                count=len(input_ids),
+                                max_model_len=self.max_model_len)
+
+    async def create_detokenize(
+        self,
+        request: DetokenizeRequest,
+    ) -> Union[DetokenizeResponse, ErrorResponse]:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        request_id = f"tokn-{random_uuid()}"
+
+        (
+            lora_request,
+            prompt_adapter_request,
+        ) = self._maybe_get_adapters(request)
+
+        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
+
+        self._log_inputs(request_id,
+                         request.tokens,
+                         params=None,
+                         lora_request=lora_request,
+                         prompt_adapter_request=prompt_adapter_request)
+
+        if prompt_adapter_request is not None:
+            raise NotImplementedError("Prompt adapter is not supported "
+                                      "for tokenization")
+
+        prompt_input = self._tokenize_prompt_input(
+            request,
+            tokenizer,
+            request.tokens,
+        )
+        input_text = prompt_input["prompt"]
+
+        return DetokenizeResponse(prompt=input_text)
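
As a rough usage sketch (not part of the diff), the two endpoints above boil
down to the following round trip. The Hugging Face model name and the
`max_model_len` value are illustrative, and the plain dicts merely mirror the
fields of `TokenizeResponse` and `DetokenizeResponse`:

from transformers import AutoTokenizer

# Any HF tokenizer works here; "gpt2" is only an example choice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# create_tokenize: text in, token IDs and their count out.
ids = tokenizer("Hello, world!", add_special_tokens=True).input_ids
tokenize_response = {"tokens": ids, "count": len(ids), "max_model_len": 1024}

# create_detokenize: token IDs in, the decoded prompt text back out.
detokenize_response = {"prompt": tokenizer.decode(ids)}

print(tokenize_response)
print(detokenize_response)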

File diff suppressed because it is too large
+ 664 - 225
aphrodite/engine/aphrodite_engine.py


File diff suppressed because it is too large
+ 585 - 326
aphrodite/engine/args_tools.py


+ 449 - 189
aphrodite/engine/async_aphrodite.py

@@ -2,62 +2,76 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
+                    Set, Tuple, Type, Union)
+
 from loguru import logger
 from transformers import PreTrainedTokenizer
 
-from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import ModelConfig
-from aphrodite.engine.args_tools import AsyncEngineArgs
-from aphrodite.engine.aphrodite_engine import AphroditeEngine
-from aphrodite.engine.ray_tools import initialize_ray_cluster, ray
-from aphrodite.common.outputs import RequestOutput
+from aphrodite.common.config import (DecodingConfig, EngineConfig, LoRAConfig,
+                                     ModelConfig, ParallelConfig,
+                                     SchedulerConfig)
+from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
+from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import MultiModalData
+from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.engine.aphrodite_engine import AphroditeEngine
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.async_timeout import asyncio_timeout
+from aphrodite.engine.metrics import StatLoggerBase
+from aphrodite.executor.executor_base import ExecutorAsyncBase
+from aphrodite.executor.ray_utils import initialize_ray_cluster, ray
+from aphrodite.inputs import LLMInputs, PromptInputs
+from aphrodite.lora.request import LoRARequest
+from aphrodite.processing.scheduler import SchedulerOutputs
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
 
 ENGINE_ITERATION_TIMEOUT_S = int(
-    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "120"))
+    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "60"))
 
 
 class AsyncEngineDeadError(RuntimeError):
     pass
 
 
-def _raise_exception_on_finish(
-        task: asyncio.Task, error_callback: Callable[[Exception],
-                                                     None]) -> None:
-    msg = ("Task finished unexpectedly. This should never happen! "
-           "Please open an issue on Github. Include your full error "
-           "log after killing the process with Ctrl+C.")
+def _log_task_completion(task: asyncio.Task,
+                         error_callback: Callable[[Exception], None]) -> None:
+    """This function is only intended for the `engine.run_engine_loop()` task.
+    In particular, that task runs a `while True` loop that can only exit if
+    there is an exception.
+    """
 
     exception = None
     try:
-        task.result()
-        # NOTE: This will be thrown if task exits normally (which it should not)
-        raise AsyncEngineDeadError(msg)
+        return_value = task.result()
+        raise AssertionError(
+            f"The engine background task should never finish without an "
+            f"exception. {return_value}")
     except asyncio.exceptions.CancelledError:
-        pass
-    except KeyboardInterrupt:
-        raise
+        # We assume that if the task is cancelled, we are gracefully shutting
+        # down. This should only happen on program exit.
+        logger.info("Engine is gracefully shutting down.")
     except Exception as e:
         exception = e
         logger.error("Engine background task failed", exc_info=e)
         error_callback(exception)
         raise AsyncEngineDeadError(
-            msg + " See stack trace above for the actual cause.") from e
+            "Task finished unexpectedly. This should never happen! "
+            "Please open an issue on Github. See stack trace above for the"
+            "actual cause.") from e
 
 
 class AsyncStream:
-    """A stream of RequestOutputs for a request that can be
-    iterated over asynchronously."""
+    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    that can be iterated over asynchronously."""
 
     def __init__(self, request_id: str) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
+        self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
 
-    def put(self, item: Union[RequestOutput, Exception]) -> None:
+    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+                              Exception]) -> None:
         if self._finished:
             return
         self._queue.put_nowait(item)
@@ -73,7 +87,7 @@ class AsyncStream:
     def __aiter__(self):
         return self
 
-    async def __anext__(self) -> RequestOutput:
+    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
         result = await self._queue.get()
         if isinstance(result, Exception):
             raise result
@@ -110,13 +124,17 @@ class RequestTracker:
                 self.abort_request(rid)
 
     def process_request_output(self,
-                               request_output: RequestOutput,
+                               request_output: Union[RequestOutput,
+                                                     EmbeddingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
         request_id = request_output.request_id
 
-        self._request_streams[request_id].put(request_output)
+        # Guard against a KeyError which can occur if the request was aborted
+        # while the output was generated
+        if (stream := self._request_streams.get(request_id)) is not None:
+            stream.put(request_output)
         if request_output.finished:
             if verbose:
                 logger.info(f"Finished request {request_id}.")
@@ -133,7 +151,10 @@ class RequestTracker:
             logger.info(f"Finished request {request_id}.")
         self.abort_request(request_id)
 
-    def add_request(self, request_id: str,
+    def add_request(self,
+                    request_id: str,
+                    *,
+                    verbose: bool = False,
                     **engine_add_request_kwargs) -> AsyncStream:
         """Add a request to be sent to the engine on the next background
         loop iteration."""
@@ -148,6 +169,9 @@ class RequestTracker:
 
         self.new_requests_event.set()
 
+        if verbose:
+            logger.info(f"Added request {request_id}.")
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -198,7 +222,9 @@ class RequestTracker:
 class _AsyncAphrodite(AphroditeEngine):
     """Extension of AphroditeEngine to add async methods."""
 
-    async def step_async(self) -> List[RequestOutput]:
+    async def step_async(
+        self, virtual_engine: int
+    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are run asynchronously if possible.
 
@@ -208,73 +234,107 @@ class _AsyncAphrodite(AphroditeEngine):
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        seq_group_metadata_list, scheduler_outputs = self.scheduler[
+            virtual_engine].schedule()
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
+            finished_requests_ids = self.scheduler[
+                virtual_engine].get_and_reset_finished_requests_ids()
+            execute_model_req = ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                virtual_engine=virtual_engine,
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                running_queue_size=scheduler_outputs.running_queue_size,
+                finished_requests_ids=finished_requests_ids,
+            )
             output = await self.model_executor.execute_model_async(
-                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
-                scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy,
-                scheduler_outputs.num_lookahead_slots)
+                execute_model_req)
         else:
             output = []
 
         request_outputs = self._process_model_outputs(
             output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups)
+            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
 
         # Log stats.
-        if self.log_stats:
-            self.stat_logger.log(self._get_stats(scheduler_outputs))
+        self.do_log_stats(scheduler_outputs, output)
 
         return request_outputs
 
-    async def encode_request_async(
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        """Stop the remote worker execution loop."""
+        await self.model_executor.stop_remote_worker_execution_loop_async()
+
+    async def process_model_inputs_async(
         self,
-        request_id: str,  # pylint: disable=unused-argument
-        prompt: Optional[str],
-        prompt_token_ids: Optional[List[int]] = None,
+        request_id: str,
+        inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
-    ):
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = await self.tokenizer.encode_async(
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> LLMInputs:
+        if isinstance(inputs, str):
+            inputs = {"prompt": inputs}
+
+        if "prompt_token_ids" not in inputs:
+            tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                                 "skip_tokenizer_init is True")
+
+            prompt_token_ids = await tokenizer.encode_async(
                 request_id=request_id,
-                prompt=prompt,
+                prompt=inputs["prompt"],
                 lora_request=lora_request)
-        return prompt_token_ids
+        else:
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        if prompt_adapter_request:
+            prompt_token_ids = [
+                0
+            ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
+                prompt_token_ids
+
+        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
+                               prompt=inputs.get("prompt"),
+                               multi_modal_data=inputs.get("multi_modal_data"))
+
+        return self.input_processor(llm_inputs)
 
     async def add_request_async(
         self,
         request_id: str,
-        prompt: Optional[str],
-        sampling_params: SamplingParams,
-        prompt_token_ids: Optional[List[int]] = None,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
         if arrival_time is None:
             arrival_time = time.time()
-        prompt_token_ids = await self.encode_request_async(
+
+        processed_inputs = await self.process_model_inputs_async(
             request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)
-
-        return self.add_request(request_id,
-                                prompt=prompt,
-                                prompt_token_ids=prompt_token_ids,
-                                sampling_params=sampling_params,
-                                arrival_time=arrival_time,
-                                lora_request=lora_request,
-                                multi_modal_data=multi_modal_data)
+            inputs=inputs,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request)
+
+        self._add_processed_request(
+            request_id=request_id,
+            processed_inputs=processed_inputs,
+            params=params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
     async def check_health_async(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
         self.model_executor.check_health()
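
A tiny worked example (illustrative values only) of the placeholder
prepending that `process_model_inputs_async` above performs when a prompt
adapter is attached to the request:

# Hypothetical numbers: an adapter that reserves 8 virtual tokens.
prompt_adapter_num_virtual_tokens = 8
prompt_token_ids = [101, 2023, 2003, 102]   # IDs produced by the tokenizer

# Placeholder ID 0 is prepended once per virtual token, as in the diff above.
padded = [0] * prompt_adapter_num_virtual_tokens + prompt_token_ids
assert padded == [0, 0, 0, 0, 0, 0, 0, 0, 101, 2023, 2003, 102]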
 
 
@@ -298,8 +358,6 @@ class AsyncAphrodite:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
-        max_log_len: Maximum number of prompt characters or prompt ID numbers
-            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for AphroditeEngine.
@@ -313,66 +371,124 @@ class AsyncAphrodite:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 max_log_len: int = 0,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
-        self.background_loop = None
+        self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
         # task as well to prevent it from being garbage
         # collected
-        self._background_loop_unshielded = None
+        self._background_loop_unshielded: Optional[asyncio.Task] = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None
 
-    @classmethod
-    def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = True) -> "AsyncAphrodite":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        # Lazy initialized fields
+        self._request_tracker: RequestTracker
 
-        if engine_config.device_config.device_type == "neuron":
-            raise NotImplementedError("Neuron is not supported for "
-                                      "async engine yet.")
+    @classmethod
+    def _get_executor_cls(
+            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
+        distributed_executor_backend = (
+            engine_config.parallel_config.distributed_executor_backend)
+        if isinstance(distributed_executor_backend, type):
+            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
+            if distributed_executor_backend.uses_ray:  # type: ignore
+                initialize_ray_cluster(engine_config.parallel_config)
+            executor_class = distributed_executor_backend
+        elif engine_config.device_config.device_type == "neuron":
+            from aphrodite.executor.neuron_executor import NeuronExecutorAsync
+            executor_class = NeuronExecutorAsync
+        elif engine_config.device_config.device_type == "tpu":
+            if distributed_executor_backend == "ray":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from aphrodite.executor.ray_tpu_executor import (
+                    RayTPUExecutorAsync)
+                executor_class = RayTPUExecutorAsync
+            else:
+                assert distributed_executor_backend is None
+                from aphrodite.executor.tpu_executor import TPUExecutorAsync
+                executor_class = TPUExecutorAsync
         elif engine_config.device_config.device_type == "cpu":
-            from aphrodite.executor.cpu_executor import CPUExecutor
-            executor_class = CPUExecutor
-        elif engine_config.parallel_config.worker_use_ray:
+            from aphrodite.executor.cpu_executor import CPUExecutorAsync
+            executor_class = CPUExecutorAsync
+        elif engine_config.device_config.device_type == "openvino":
+            assert distributed_executor_backend is None, (
+                "Distributed execution is not supported with the OpenVINO "
+                "backend.")
+            from aphrodite.executor.openvino_executor import (
+                OpenVINOExecutorAsync)
+            executor_class = OpenVINOExecutorAsync
+        elif engine_config.device_config.device_type == "xpu":
+            if distributed_executor_backend is None:
+                from aphrodite.executor.xpu_executor import XPUExecutorAsync
+                executor_class = XPUExecutorAsync
+            elif distributed_executor_backend == "ray":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from aphrodite.executor.ray_xpu_executor import (
+                    RayXPUExecutorAsync)
+                executor_class = RayXPUExecutorAsync
+            else:
+                raise RuntimeError(
+                    "Unsupported distributed executor backend for XPU.")
+        elif distributed_executor_backend == "ray":
             initialize_ray_cluster(engine_config.parallel_config)
             from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
+        elif distributed_executor_backend == "mp":
+            from aphrodite.executor.multiproc_gpu_executor import (
+                MultiprocessingGPUExecutorAsync)
+            executor_class = MultiprocessingGPUExecutorAsync
         else:
-            assert engine_config.parallel_config.world_size == 1, (
-                "Ray is required if parallel_config.world_size > 1.")
             from aphrodite.executor.gpu_executor import GPUExecutorAsync
             executor_class = GPUExecutorAsync
+        return executor_class
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: AsyncEngineArgs,
+        start_engine_loop: bool = True,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+    ) -> "AsyncAphrodite":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_config = engine_args.create_engine_config()
+
+        if engine_args.engine_use_ray:
+            from aphrodite.executor import ray_utils
+            ray_utils.assert_ray_available()
+
+        executor_class = cls._get_executor_cls(engine_config)
         # Create the async LLM engine.
-        engine = cls(engine_config.parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     **engine_config.to_dict(),
-                     executor_class=executor_class,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
+        engine = cls(
+            executor_class.uses_ray,
+            engine_args.engine_use_ray,
+            **engine_config.to_dict(),
+            executor_class=executor_class,
+            log_requests=not engine_args.disable_log_requests,
+            log_stats=not engine_args.disable_log_stats,
+            start_engine_loop=start_engine_loop,
+            stat_loggers=stat_loggers,
+        )
         return engine
 
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
+                and self._background_loop_unshielded is not None
                 and not self._background_loop_unshielded.done())
 
     @property
     def is_stopped(self) -> bool:
-        return self.errored or (self.background_loop is not None
+        return self.errored or (self.background_loop is not None and
+                                self._background_loop_unshielded is not None
                                 and self._background_loop_unshielded.done())
 
     @property
@@ -386,11 +502,16 @@ class AsyncAphrodite:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()
-        else:
-            return self.engine.get_tokenizer()
+            return await self.engine.get_tokenizer.remote(  # type: ignore
+                lora_request)
+
+        return await (self.engine.get_tokenizer_group().
+                      get_lora_tokenizer_async(lora_request))
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
@@ -405,8 +526,7 @@ class AsyncAphrodite:
         self._background_loop_unshielded = asyncio.get_event_loop(
         ).create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
-            partial(_raise_exception_on_finish,
-                    error_callback=self._error_callback))
+            partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
 
     def _init_engine(self, *args,
@@ -420,7 +540,8 @@ class AsyncAphrodite:
             # order of the arguments.
             cache_config = kwargs["cache_config"]
             parallel_config = kwargs["parallel_config"]
-            if parallel_config.tensor_parallel_size == 1:
+            if (parallel_config.tensor_parallel_size == 1
+                    and parallel_config.pipeline_parallel_size == 1):
                 num_gpus = cache_config.gpu_memory_utilization
             else:
                 num_gpus = 1
@@ -428,7 +549,7 @@ class AsyncAphrodite:
                 self._engine_class).remote
         return engine_class(*args, **kwargs)
 
-    async def engine_step(self) -> bool:
+    async def engine_step(self, virtual_engine: int) -> bool:
         """Kick the engine to process the waiting requests.
 
         Returns True if there are in-progress requests."""
@@ -441,7 +562,8 @@ class AsyncAphrodite:
             # TODO: Maybe add add_request_batch to reduce Ray overhead
             try:
                 if self.engine_use_ray:
-                    await self.engine.add_request.remote(**new_request)
+                    await self.engine.add_request.remote(  # type: ignore
+                        **new_request)
                 else:
                     await self.engine.add_request_async(**new_request)
             except ValueError as e:
@@ -456,69 +578,101 @@ class AsyncAphrodite:
             await self._engine_abort(finished_requests)
 
         if self.engine_use_ray:
-            request_outputs = await self.engine.step.remote()
+            request_outputs = await self.engine.step.remote()  # type: ignore
         else:
-            request_outputs = await self.engine.step_async()
+            request_outputs = await self.engine.step_async(virtual_engine)
 
         # Put the outputs into the corresponding streams.
+        finished = True
         for request_output in request_outputs:
             self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
+            finished = finished and request_output.finished
 
-        return len(request_outputs) > 0
+        return not finished
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_ids)
+            await self.engine.abort_request.remote(request_ids)  # type: ignore
         else:
             self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
-        has_requests_in_progress = False
-        try:
-            while True:
-                if not has_requests_in_progress:
-                    logger.debug("Waiting for new requests...")
-                    await self._request_tracker.wait_for_new_requests()
-                    logger.debug("Got new requests!")
-
-                # Abort if iteration takes too long due to unrecoverable errors
-                # (eg. NCCL timeouts).
-                try:
-                    has_requests_in_progress = await asyncio.wait_for(
-                        self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
-                except asyncio.TimeoutError as exc:
-                    logger.error(
-                        "Engine iteration timed out. This should never happen!"
-                    )
-                    self.set_errored(exc)
-                    raise
-                await asyncio.sleep(0)
-        except KeyboardInterrupt:
-            logger.info("Engine loop interrupted. Exiting gracefully.")
+        if self.engine_use_ray:
+            pipeline_parallel_size = 1  # type: ignore
+        else:
+            pipeline_parallel_size = \
+                self.engine.parallel_config.pipeline_parallel_size
+        has_requests_in_progress = [False] * pipeline_parallel_size
+        while True:
+            if not any(has_requests_in_progress):
+                logger.debug("Waiting for new requests...")
+                # Stop the execute model loop in parallel workers until there
+                # are more requests to process. This avoids waiting
+                # indefinitely in torch.distributed ops which may otherwise
+                # timeout, and unblocks the RPC thread in the workers so that
+                # they can process any other queued control plane messages,
+                # such as add/remove lora adapters.
+                if self.engine_use_ray:
+                    await (self.engine.stop_remote_worker_execution_loop.
+                           remote()  # type: ignore
+                           )
+                else:
+                    await self.engine.stop_remote_worker_execution_loop_async()
+                await self._request_tracker.wait_for_new_requests()
+                logger.debug("Got new requests!")
+                requests_in_progress = [
+                    asyncio.create_task(self.engine_step(ve))
+                    for ve in range(pipeline_parallel_size)
+                ]
+                has_requests_in_progress = [True] * pipeline_parallel_size
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
+                    done, _ = await asyncio.wait(
+                        requests_in_progress,
+                        return_when=asyncio.FIRST_COMPLETED)
+                    for _ in range(pipeline_parallel_size):
+                        await asyncio.sleep(0)
+                for task in done:
+                    result = task.result()
+                    virtual_engine = requests_in_progress.index(task)
+                    if self.engine_use_ray:
+                        has_unfinished_requests = (
+                            await (self.engine.
+                                   has_unfinished_requests_for_virtual_engine.
+                                   remote(  # type: ignore
+                                       virtual_engine)))
+                    else:
+                        has_unfinished_requests = (
+                            self.engine.
+                            has_unfinished_requests_for_virtual_engine(
+                                virtual_engine))
+                    if result or has_unfinished_requests:
+                        requests_in_progress[virtual_engine] = (
+                            asyncio.create_task(
+                                self.engine_step(virtual_engine)))
+                        has_requests_in_progress[virtual_engine] = True
+                    else:
+                        has_requests_in_progress[virtual_engine] = False
+            except asyncio.TimeoutError as exc:
+                logger.error(
+                    "Engine iteration timed out. This should never happen!")
+                self.set_errored(exc)
+                raise
+            await asyncio.sleep(0)
 
     async def add_request(
         self,
         request_id: str,
-        prompt: Optional[str],
-        sampling_params: SamplingParams,
-        prompt_token_ids: Optional[List[int]] = None,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncStream:
-        if self.log_requests:
-            shortened_prompt = prompt
-            shortened_token_ids = prompt_token_ids
-            if self.max_log_len is not None:
-                if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:self.max_log_len]
-                if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:self.
-                                                              max_log_len]
-            logger.info(f"Received request {request_id}: "
-                        f"prompt: {shortened_prompt!r}, "
-                        f"sampling_params: {sampling_params}, "
-                        f"lora_request: {lora_request}.")
 
         if not self.is_running:
             if self.start_engine_loop:
@@ -533,36 +687,24 @@ class AsyncAphrodite:
         if arrival_time is None:
             arrival_time = time.time()
 
-        if self.engine_use_ray:
-            prompt_token_ids = await self.engine.encode_request_async.remote(
-                request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
-                lora_request=lora_request)
-        else:
-            prompt_token_ids = await self.engine.encode_request_async(
-                request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
-                lora_request=lora_request)
-
         stream = self._request_tracker.add_request(
             request_id,
-            prompt=prompt,
-            sampling_params=sampling_params,
-            prompt_token_ids=prompt_token_ids,
+            verbose=self.log_requests,
+            inputs=inputs,
+            params=params,
             arrival_time=arrival_time,
-            lora_request=lora_request)
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request)
 
         return stream
 
     async def generate(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
@@ -578,10 +720,12 @@ class AsyncAphrodite:
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
+            prompt_adapter_request: Prompt adapter request to use for
+                generation, if any.
 
         Yields:
-            The output `RequestOutput` objects from the AphroditeEngine for the
-            request.
+            The output `RequestOutput` objects from the AphroditeEngine
+            for the request.
 
         Details:
             - If the engine is not running, start the background loop,
@@ -627,29 +771,109 @@ class AsyncAphrodite:
             >>> # Process and return the final output
             >>> ...
         """
-        # Preprocess the request.
-        # This should not be used for logging, as it is monotonic time.
-        arrival_time = time.monotonic()
-
-        try:
-            stream = await self.add_request(
+        async for output in self._process_request(
                 request_id,
-                prompt,
+                inputs,
                 sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time,
                 lora_request=lora_request,
-            )
+                prompt_adapter_request=prompt_adapter_request,
+        ):
+            yield AphroditeEngine.validate_output(output, RequestOutput)
+
+    async def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AsyncIterator[EmbeddingRequestOutput]:
+        """Generate outputs for a request from an embedding model.
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the AphroditeEngine and streams the
+        outputs from the AphroditeEngine to the caller.
+        Args:
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi modal data per request.
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the
+            AphroditeEngine for the request.
+        Details:
+            - If the engine is not running, start the background loop,
+              which iteratively invokes
+              :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
+              to process the waiting requests.
+            - Add the request to the engine's `RequestTracker`.
+              On the next background loop, this request will be sent to
+              the underlying engine.
+              Also, a corresponding `AsyncStream` will be created.
+            - Wait for the request outputs from `AsyncStream` and yield them.
+        Example:
+            >>> # initialize the engine and the example input
+            >>> engine = AsyncAphrodite.from_engine_args(engine_args)
+            >>> example_input = {
+            >>>     "input": "What is LLM?",
+            >>>     "request_id": 0,
+            >>> }
+            >>>
+            >>> # start the generation
+            >>> results_generator = engine.encode(
+            >>>    example_input["input"],
+            >>>    PoolingParams(),
+            >>>    example_input["request_id"])
+            >>>
+            >>> # get the results
+            >>> final_output = None
+            >>> async for request_output in results_generator:
+            >>>     if await request.is_disconnected():
+            >>>         # Abort the request if the client disconnects.
+            >>>         await engine.abort(request_id)
+            >>>         # Return or raise an error
+            >>>         ...
+            >>>     final_output = request_output
+            >>>
+            >>> # Process and return the final output
+            >>> ...
+        """
+        async for output in self._process_request(
+                request_id,
+                inputs,
+                pooling_params,
+                lora_request=lora_request,
+        ):
+            yield AphroditeEngine.validate_output(output,
+                                                  EmbeddingRequestOutput)
 
+    async def _process_request(
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        *,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
+        """Common logic to process requests with SamplingParams or
+        PoolingParams."""
+        arrival_time = time.time()
+
+        stream = await self.add_request(
+            request_id,
+            inputs,
+            params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request)
+
+        try:
             async for request_output in stream:
                 yield request_output
-        except asyncio.exceptions.CancelledError:
-            logger.info(f"Request {request_id} cancelled.")
-            self._abort(request_id)
-            raise
         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
             self._abort(request_id)
             raise e
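For reference, a minimal consumer of the refactored `generate()` API might look like the sketch below; it assumes a plain prompt string is accepted as `PromptInputs` (as in the docstring examples above) and is not part of this diff:

    from aphrodite.common.sampling_params import SamplingParams

    async def consume(engine, prompt: str, request_id: str):
        # Stream RequestOutput objects as they are produced; the last one
        # yielded carries the final state of the request.
        final_output = None
        async for request_output in engine.generate(
                prompt,  # assumed: a plain string is a valid PromptInputs
                SamplingParams(max_tokens=64),
                request_id):
            final_output = request_output
        return final_output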
 
@@ -686,13 +910,49 @@ class AsyncAphrodite:
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the Aphrodite engine."""
         if self.engine_use_ray:
-            return await self.engine.get_model_config.remote()
+            return await self.engine.get_model_config.remote()  # type: ignore
         else:
             return self.engine.get_model_config()
 
-    async def do_log_stats(self) -> None:
+    async def get_parallel_config(self) -> ParallelConfig:
+        """Get the parallel configuration of the Aphrodite engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_parallel_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_parallel_config()
+
+    async def get_decoding_config(self) -> DecodingConfig:
+        """Get the decoding configuration of the Aphrodite engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_decoding_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_decoding_config()
+
+    async def get_scheduler_config(self) -> SchedulerConfig:
+        """Get the scheduling configuration of the Aphrodite engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_scheduler_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_scheduler_config()
+
+    async def get_lora_config(self) -> LoRAConfig:
+        """Get the lora configuration of the Aphrodite engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_lora_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_lora_config()
+
+    async def do_log_stats(
+            self,
+            scheduler_outputs: Optional[SchedulerOutputs] = None,
+            model_output: Optional[List[SamplerOutput]] = None) -> None:
         if self.engine_use_ray:
-            await self.engine.do_log_stats.remote()
+            await self.engine.do_log_stats.remote(  # type: ignore
+                scheduler_outputs, model_output)
         else:
             self.engine.do_log_stats()
 
@@ -705,7 +965,7 @@ class AsyncAphrodite:
 
         if self.engine_use_ray:
             try:
-                await self.engine.check_health.remote()
+                await self.engine.check_health.remote()  # type: ignore
             except ray.exceptions.RayActorError as e:
                 raise RuntimeError("Engine is dead.") from e
         else:

+ 189 - 0
aphrodite/engine/async_timeout.py

@@ -0,0 +1,189 @@
+# Workaround for https://github.com/python/cpython/issues/86296
+#
+# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
+# Licensed under the Apache License (Apache-2.0)
+
+import asyncio
+import enum
+import sys
+import warnings
+from types import TracebackType
+from typing import Any, Optional, Type
+
+if sys.version_info[:2] >= (3, 11):
+    from asyncio import timeout as asyncio_timeout
+else:
+
+    def asyncio_timeout(delay: Optional[float]) -> "Timeout":
+        """timeout context manager.
+        Useful in cases when you want to apply timeout logic around block
+        of code or in cases when asyncio.wait_for is not suitable. For example:
+        >>> async with timeout(0.001):
+        ...     async with aiohttp.get('https://github.com') as r:
+        ...         await r.text()
+        delay - value in seconds or None to disable timeout logic
+        """
+        loop = asyncio.get_running_loop()
+        deadline = loop.time() + delay if delay is not None else None
+        return Timeout(deadline, loop)
+
+    class _State(enum.Enum):
+        INIT = "INIT"
+        ENTER = "ENTER"
+        TIMEOUT = "TIMEOUT"
+        EXIT = "EXIT"
+
+    class Timeout:
+        # Internal class, please don't instantiate it directly
+        # Use timeout() and timeout_at() public factories instead.
+        #
+        # Implementation note: `async with timeout()` is preferred
+        # over `with timeout()`.
+        # While technically the Timeout class implementation
+        # doesn't need to be async at all,
+        # the `async with` statement explicitly points that
+        # the context manager should be used from async function context.
+        #
+        # This design allows to avoid many silly misusages.
+        #
+        # TimeoutError is raised immediately when scheduled
+        # if the deadline is passed.
+        # The purpose is to time out as soon as possible
+        # without waiting for the next await expression.
+
+        __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
+
+        def __init__(self, deadline: Optional[float],
+                     loop: asyncio.AbstractEventLoop) -> None:
+            self._loop = loop
+            self._state = _State.INIT
+
+            self._timeout_handler = None  # type: Optional[asyncio.Handle]
+            if deadline is None:
+                self._deadline = None  # type: Optional[float]
+            else:
+                self.update(deadline)
+
+        def __enter__(self) -> "Timeout":
+            warnings.warn(
+                "with timeout() is deprecated, use async with timeout()",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            self._do_enter()
+            return self
+
+        def __exit__(
+            self,
+            exc_type: Optional[Type[BaseException]],
+            exc_val: Optional[BaseException],
+            exc_tb: Optional[TracebackType],
+        ) -> Optional[bool]:
+            self._do_exit(exc_type)
+            return None
+
+        async def __aenter__(self) -> "Timeout":
+            self._do_enter()
+            return self
+
+        async def __aexit__(
+            self,
+            exc_type: Optional[Type[BaseException]],
+            exc_val: Optional[BaseException],
+            exc_tb: Optional[TracebackType],
+        ) -> Optional[bool]:
+            self._do_exit(exc_type)
+            return None
+
+        @property
+        def expired(self) -> bool:
+            """Is timeout expired during execution?"""
+            return self._state == _State.TIMEOUT
+
+        @property
+        def deadline(self) -> Optional[float]:
+            return self._deadline
+
+        def reject(self) -> None:
+            """Reject scheduled timeout if any."""
+            # cancel is maybe better name but
+            # task.cancel() raises CancelledError in asyncio world.
+            if self._state not in (_State.INIT, _State.ENTER):
+                raise RuntimeError(f"invalid state {self._state.value}")
+            self._reject()
+
+        def _reject(self) -> None:
+            if self._timeout_handler is not None:
+                self._timeout_handler.cancel()
+                self._timeout_handler = None
+
+        def shift(self, delay: float) -> None:
+            """Advance timeout on delay seconds.
+            The delay can be negative.
+            Raise RuntimeError if shift is called when deadline is not scheduled
+            """
+            deadline = self._deadline
+            if deadline is None:
+                raise RuntimeError(
+                    "cannot shift timeout if deadline is not scheduled")
+            self.update(deadline + delay)
+
+        def update(self, deadline: float) -> None:
+            """Set deadline to absolute value.
+            deadline argument points on the time in the same clock system
+            as loop.time().
+            If new deadline is in the past the timeout is raised immediately.
+            Please note: it is not POSIX time but a time with
+            undefined starting base, e.g. the time of the system power on.
+            """
+            if self._state == _State.EXIT:
+                raise RuntimeError(
+                    "cannot reschedule after exit from context manager")
+            if self._state == _State.TIMEOUT:
+                raise RuntimeError("cannot reschedule expired timeout")
+            if self._timeout_handler is not None:
+                self._timeout_handler.cancel()
+            self._deadline = deadline
+            if self._state != _State.INIT:
+                self._reschedule()
+
+        def _reschedule(self) -> None:
+            assert self._state == _State.ENTER
+            deadline = self._deadline
+            if deadline is None:
+                return
+
+            now = self._loop.time()
+            if self._timeout_handler is not None:
+                self._timeout_handler.cancel()
+
+            task = asyncio.current_task()
+            if deadline <= now:
+                self._timeout_handler = self._loop.call_soon(
+                    self._on_timeout, task)
+            else:
+                self._timeout_handler = self._loop.call_at(
+                    deadline, self._on_timeout, task)
+
+        def _do_enter(self) -> None:
+            if self._state != _State.INIT:
+                raise RuntimeError(f"invalid state {self._state.value}")
+            self._state = _State.ENTER
+            self._reschedule()
+
+        def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
+            if exc_type is asyncio.CancelledError and \
+                    self._state == _State.TIMEOUT:
+                self._timeout_handler = None
+                raise asyncio.TimeoutError
+            # timeout has not expired
+            self._state = _State.EXIT
+            self._reject()
+            return None
+
+        def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
+            if task:
+                task.cancel()
+            self._state = _State.TIMEOUT
+            # drop the reference early
+            self._timeout_handler = None
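The vendored helper above backfills `asyncio.timeout` for Python < 3.11. A minimal usage sketch mirroring the engine-iteration guard whose `TimeoutError` handler appears at the top of this section (the constant name and value are illustrative, not taken from this diff):

    from aphrodite.engine.async_timeout import asyncio_timeout

    ENGINE_ITERATION_TIMEOUT_S = 60  # hypothetical value for illustration

    async def guarded_step(engine_step_coro):
        # Raises asyncio.TimeoutError if a single iteration stalls; the caller
        # logs the error and re-raises, as shown earlier in async_aphrodite.py.
        async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
            return await engine_step_coro()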

+ 476 - 178
aphrodite/engine/metrics.py

@@ -1,227 +1,523 @@
 import time
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING
+from typing import Counter as CollectionsCounter
+from typing import Dict, List, Optional, Protocol, Union
 
 import numpy as np
+import prometheus_client
 from loguru import logger
-from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
-                               disable_created_metrics)
+
+from aphrodite.executor.ray_utils import ray
+
+if ray is not None:
+    from ray.util import metrics as ray_metrics
+else:
+    ray_metrics = None
 
 if TYPE_CHECKING:
     from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
-disable_created_metrics()
-
-# The begin-* and end* here are used by the documentation generator
-# to extract the metrics definitions.
+prometheus_client.disable_created_metrics()
 
 
 # begin-metrics-definitions
 class Metrics:
+    labelname_finish_reason = "finished_reason"
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram
 
-    def __init__(self, labelnames: List[str]):
+    def __init__(self, labelnames: List[str], max_model_len: int):
         # Unregister any existing Aphrodite collectors
-        for collector in list(REGISTRY._collector_to_names):
-            if hasattr(collector, "_name") and "aphrodite" in collector._name:
-                REGISTRY.unregister(collector)
+        self._unregister_aphrodite_metrics()
 
         # Config Information
-        self.info_cache_config = Info(
-            name="aphrodite:cache_config",
-            documentation="information of cache_config",
-        )
+        self._create_info_cache_config()
 
         # System stats
-        self.gauge_scheduler_running = Gauge(
+        #   Scheduler State
+        self.gauge_scheduler_running = self._gauge_cls(
             name="aphrodite:num_requests_running",
             documentation="Number of requests currently running on GPU.",
-            labelnames=labelnames,
-        )
-        self.gauge_scheduler_swapped = Gauge(
-            name="aphrodite:num_requests_swapped",
-            documentation="Number of requests swapped to CPU.",
-            labelnames=labelnames,
-        )
-        self.gauge_scheduler_waiting = Gauge(
+            labelnames=labelnames)
+        self.gauge_scheduler_waiting = self._gauge_cls(
             name="aphrodite:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
-            labelnames=labelnames,
-        )
-        self.gauge_gpu_cache_usage = Gauge(
+            labelnames=labelnames)
+        self.gauge_scheduler_swapped = self._gauge_cls(
+            name="aphrodite:num_requests_swapped",
+            documentation="Number of requests swapped to CPU.",
+            labelnames=labelnames)
+        #   KV Cache Usage in %
+        self.gauge_gpu_cache_usage = self._gauge_cls(
             name="aphrodite:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames,
-        )
-        self.gauge_cpu_cache_usage = Gauge(
+            labelnames=labelnames)
+        self.gauge_cpu_cache_usage = self._gauge_cls(
             name="aphrodite:cpu_cache_usage_perc",
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames,
-        )
-
-        # Raw stats from last model iteration
-        self.counter_prompt_tokens = Counter(
+            labelnames=labelnames)
+
+        # Iteration stats
+        self.counter_num_preemption = self._counter_cls(
+            name="aphrodite:num_preemptions_total",
+            documentation="Cumulative number of preemption from the engine.",
+            labelnames=labelnames)
+        self.counter_prompt_tokens = self._counter_cls(
             name="aphrodite:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
-            labelnames=labelnames,
-        )
-        self.counter_generation_tokens = Counter(
+            labelnames=labelnames)
+        self.counter_generation_tokens = self._counter_cls(
             name="aphrodite:generation_tokens_total",
             documentation="Number of generation tokens processed.",
-            labelnames=labelnames,
-        )
-        self.histogram_time_to_first_token = Histogram(
+            labelnames=labelnames)
+        self.histogram_time_to_first_token = self._histogram_cls(
             name="aphrodite:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
-                0.1,
-                0.25,
-                0.5,
-                0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-            ],
-        )
-        self.histogram_time_per_output_token = Histogram(
+                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
+                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
+            ])
+        self.histogram_time_per_output_token = self._histogram_cls(
             name="aphrodite:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.01,
-                0.025,
-                0.05,
-                0.075,
-                0.1,
-                0.15,
-                0.2,
-                0.3,
-                0.4,
-                0.5,
-                0.75,
-                1.0,
-                2.5,
-            ],
-        )
-        self.histogram_e2e_request_latency = Histogram(
+                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
+                1.0, 2.5
+            ])
+
+        # Request stats
+        #   Latency
+        self.histogram_e2e_time_request = self._histogram_cls(
             name="aphrodite:e2e_request_latency_seconds",
             documentation="Histogram of end to end request latency in seconds.",
             labelnames=labelnames,
-            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
+            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+        #   Metadata
+        self.histogram_num_prompt_tokens_request = self._histogram_cls(
+            name="aphrodite:request_prompt_tokens",
+            documentation="Number of prefill tokens processed.",
+            labelnames=labelnames,
+            buckets=build_1_2_5_buckets(max_model_len),
         )
-
-        # Legacy metrics
-        self.gauge_avg_prompt_throughput = Gauge(
+        self.histogram_num_generation_tokens_request = \
+            self._histogram_cls(
+                name="aphrodite:request_generation_tokens",
+                documentation="Number of generation tokens processed.",
+                labelnames=labelnames,
+                buckets=build_1_2_5_buckets(max_model_len),
+            )
+        self.histogram_best_of_request = self._histogram_cls(
+            name="aphrodite:request_params_best_of",
+            documentation="Histogram of the best_of request parameter.",
+            labelnames=labelnames,
+            buckets=[1, 2, 5, 10, 20],
+        )
+        self.histogram_n_request = self._histogram_cls(
+            name="aphrodite:request_params_n",
+            documentation="Histogram of the n request parameter.",
+            labelnames=labelnames,
+            buckets=[1, 2, 5, 10, 20],
+        )
+        self.counter_request_success = self._counter_cls(
+            name="aphrodite:request_success_total",
+            documentation="Count of successfully processed requests.",
+            labelnames=labelnames + [Metrics.labelname_finish_reason])
+
+        # Speculative decoding stats
+        self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
+            name="aphrodite:spec_decode_draft_acceptance_rate",
+            documentation="Speulative token acceptance rate.",
+            labelnames=labelnames)
+        self.gauge_spec_decode_efficiency = self._gauge_cls(
+            name="aphrodite:spec_decode_efficiency",
+            documentation="Speculative decoding system efficiency.",
+            labelnames=labelnames)
+        self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
+            name="aphrodite:spec_decode_num_accepted_tokens_total",
+            documentation="Number of accepted tokens.",
+            labelnames=labelnames))
+        self.counter_spec_decode_num_draft_tokens = self._counter_cls(
+            name="aphrodite:spec_decode_num_draft_tokens_total",
+            documentation="Number of draft tokens.",
+            labelnames=labelnames)
+        self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
+            name="aphrodite:spec_decode_num_emitted_tokens_total",
+            documentation="Number of emitted tokens.",
+            labelnames=labelnames))
+
+        # Deprecated in favor of aphrodite:prompt_tokens_total
+        self.gauge_avg_prompt_throughput = self._gauge_cls(
             name="aphrodite:avg_prompt_throughput_toks_per_s",
             documentation="Average prefill throughput in tokens/s.",
             labelnames=labelnames,
         )
-        self.gauge_avg_generation_throughput = Gauge(
+        # Deprecated in favor of aphrodite:generation_tokens_total
+        self.gauge_avg_generation_throughput = self._gauge_cls(
             name="aphrodite:avg_generation_throughput_toks_per_s",
             documentation="Average generation throughput in tokens/s.",
             labelnames=labelnames,
         )
 
+    def _create_info_cache_config(self) -> None:
+        # Config Information
+        self.info_cache_config = prometheus_client.Info(
+            name='aphrodite:cache_config',
+            documentation='information of cache_config')
+
+    def _unregister_aphrodite_metrics(self) -> None:
+        for collector in list(prometheus_client.REGISTRY._collector_to_names):
+            if hasattr(collector, "_name") and "aphrodite" in collector._name:
+                prometheus_client.REGISTRY.unregister(collector)
+
 
 # end-metrics-definitions
 
 
+class _RayGaugeWrapper:
+    """Wraps around ray.util.metrics.Gauge to provide same API as
+    prometheus_client.Gauge"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._gauge = ray_metrics.Gauge(name=name,
+                                        description=documentation,
+                                        tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._gauge.set_default_tags(labels)
+        return self
+
+    def set(self, value: Union[int, float]):
+        return self._gauge.set(value)
+
+
+class _RayCounterWrapper:
+    """Wraps around ray.util.metrics.Counter to provide same API as
+    prometheus_client.Counter"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._counter = ray_metrics.Counter(name=name,
+                                            description=documentation,
+                                            tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._counter.set_default_tags(labels)
+        return self
+
+    def inc(self, value: Union[int, float] = 1.0):
+        if value == 0:
+            return
+        return self._counter.inc(value)
+
+
+class _RayHistogramWrapper:
+    """Wraps around ray.util.metrics.Histogram to provide same API as
+    prometheus_client.Histogram"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None,
+                 buckets: Optional[List[float]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._histogram = ray_metrics.Histogram(name=name,
+                                                description=documentation,
+                                                tag_keys=labelnames_tuple,
+                                                boundaries=buckets)
+
+    def labels(self, **labels):
+        self._histogram.set_default_tags(labels)
+        return self
+
+    def observe(self, value: Union[int, float]):
+        return self._histogram.observe(value)
+
+
+class RayMetrics(Metrics):
+    """
+    RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
+    Provides the same metrics as Metrics but uses Ray's util.metrics library.
+    """
+    _gauge_cls = _RayGaugeWrapper
+    _counter_cls = _RayCounterWrapper
+    _histogram_cls = _RayHistogramWrapper
+
+    def __init__(self, labelnames: List[str], max_model_len: int):
+        if ray_metrics is None:
+            raise ImportError("RayMetrics requires Ray to be installed.")
+        super().__init__(labelnames, max_model_len)
+
+    def _unregister_aphrodite_metrics(self) -> None:
+        # No-op on purpose
+        pass
+
+    def _create_info_cache_config(self) -> None:
+        # No-op on purpose
+        pass
+
+
+def build_1_2_5_buckets(max_value: int) -> List[int]:
+    """
+    Builds a list of buckets with increasing powers of 10 multiplied by 
+    mantissa values (1, 2, 5) until the value exceeds the specified maximum.
+
+    Example:
+    >>> build_1_2_5_buckets(100)
+    [1, 2, 5, 10, 20, 50, 100]
+    """
+    mantissa_lst = [1, 2, 5]
+    exponent = 0
+    buckets: List[int] = []
+    while True:
+        for m in mantissa_lst:
+            value = m * 10**exponent
+            if value <= max_value:
+                buckets.append(value)
+            else:
+                return buckets
+        exponent += 1
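For instance, with a hypothetical `max_model_len` of 4096 the helper produces the following histogram boundaries, stopping before the first 1-2-5 value that would exceed the maximum:

    >>> build_1_2_5_buckets(4096)
    [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]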
+
+
 @dataclass
 class Stats:
     """Created by AphroditeEngine for use by StatLogger."""
-
     now: float
 
-    # System stats.
-    num_running: int
-    num_waiting: int
-    num_swapped: int
-    gpu_cache_usage: float
-    cpu_cache_usage: float
-
-    # Raw stats from last model iteration.
-    num_prompt_tokens: int
-    num_generation_tokens: int
-    time_to_first_tokens: List[float]
-    time_per_output_tokens: List[float]
+    # System stats (should have _sys suffix)
+    #   Scheduler State
+    num_running_sys: int
+    num_waiting_sys: int
+    num_swapped_sys: int
+    #   KV Cache Usage in %
+    gpu_cache_usage_sys: float
+    cpu_cache_usage_sys: float
+
+    # Iteration stats (should have _iter suffix)
+    num_prompt_tokens_iter: int
+    num_generation_tokens_iter: int
+    time_to_first_tokens_iter: List[float]
+    time_per_output_tokens_iter: List[float]
+    num_preemption_iter: int
+
+    # Request stats (should have _requests suffix)
+    #   Latency
     time_e2e_requests: List[float]
+    #   Metadata
+    num_prompt_tokens_requests: List[int]
+    num_generation_tokens_requests: List[int]
+    best_of_requests: List[int]
+    n_requests: List[int]
+    finished_reason_requests: List[str]
 
     spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
 
-class StatLogger:
-    """StatLogger is used AphroditeEngine to log to Promethus and Stdout."""
+class SupportsMetricsInfo(Protocol):
 
-    def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
-        # Metadata for logging locally.
-        self.last_local_log = time.monotonic()
-        self.local_interval = local_interval
+    def metrics_info(self) -> Dict[str, str]:
+        ...
+
+
+def local_interval_elapsed(now: float, last_log: float,
+                           local_interval: float) -> bool:
+    elapsed_time = now - last_log
+    return elapsed_time > local_interval
+
+
+def get_throughput(tracked_stats: List[int], now: float,
+                   last_log: float) -> float:
+    return float(np.sum(tracked_stats) / (now - last_log))
 
+
+class StatLoggerBase(ABC):
+    """Base class for StatLogger."""
+
+    def __init__(self, local_interval: float) -> None:
         # Tracked stats over current local logging interval.
         self.num_prompt_tokens: List[int] = []
         self.num_generation_tokens: List[int] = []
+        self.last_local_log = time.time()
+        self.local_interval = local_interval
+        self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    @abstractmethod
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def log(self, stats: Stats) -> None:
+        raise NotImplementedError
 
+    def maybe_update_spec_decode_metrics(self, stats: Stats):
+        """Save spec decode metrics (since they are unlikely
+        to be emitted at the same time as the log interval)."""
+        if stats.spec_decode_metrics is not None:
+            self.spec_decode_metrics = stats.spec_decode_metrics
+
+
+class LoggingStatLogger(StatLoggerBase):
+    """LoggingStatLogger is used in AphroditeEngine to log to Stdout."""
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        raise NotImplementedError
+
+    def log(self, stats: Stats) -> None:
+        """Called by AphroditeEngine.
+           Logs to Stdout every self.local_interval seconds."""
+
+        # Save tracked stats for token counters.
+        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
+        self.num_generation_tokens.append(stats.num_generation_tokens_iter)
+
+        # Update spec decode metrics
+        self.maybe_update_spec_decode_metrics(stats)
+
+        # Log locally every local_interval seconds.
+        if local_interval_elapsed(stats.now, self.last_local_log,
+                                  self.local_interval):
+            # Compute summary metrics for tracked stats (and log them
+            # to prometheus if applicable).
+            prompt_throughput = get_throughput(self.num_prompt_tokens,
+                                               now=stats.now,
+                                               last_log=self.last_local_log)
+            generation_throughput = get_throughput(
+                self.num_generation_tokens,
+                now=stats.now,
+                last_log=self.last_local_log)
+
+            # Log to stdout.
+            logger.info(
+                "Avg prompt throughput: "
+                f"{prompt_throughput:.1f} tokens/s, "
+                "Avg generation throughput: "
+                f"{generation_throughput:.1f} tokens/s, "
+                f"Running: {stats.num_running_sys} reqs, "
+                f"Swapped: {stats.num_swapped_sys} reqs, "
+                f"Pending: {stats.num_waiting_sys} reqs, "
+                f"GPU KV cache usage: {stats.gpu_cache_usage_sys * 100:.1f}%, "
+                f"CPU KV cache usage: {stats.cpu_cache_usage_sys * 100:.1f}%.")
+
+            if self.spec_decode_metrics is not None:
+                logger.info(
+                    self._format_spec_decode_metrics_str(
+                        self.spec_decode_metrics))
+
+            # Reset tracked stats for next interval.
+            self.num_prompt_tokens = []
+            self.num_generation_tokens = []
+            self.last_local_log = stats.now
+            self.spec_decode_metrics = None
+
+    def _format_spec_decode_metrics_str(
+            self, metrics: "SpecDecodeWorkerMetrics") -> str:
+
+        return ("Speculative metrics: "
+                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
+                f"System efficiency: {metrics.system_efficiency:.3f}, "
+                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
+                f"Number of accepted tokens: {metrics.accepted_tokens}, "
+                f"Number of draft tokens: {metrics.draft_tokens}, "
+                f"Number of emitted tokens: {metrics.emitted_tokens}.")
+
+
+class PrometheusStatLogger(StatLoggerBase):
+    """PrometheusStatLogger is used AphroditeEngine to log to Promethus."""
+    _metrics_cls = Metrics
+
+    def __init__(self, local_interval: float, labels: Dict[str, str],
+                 max_model_len: int) -> None:
+        super().__init__(local_interval)
         # Prometheus metrics
         self.labels = labels
-        self.metrics = Metrics(labelnames=list(labels.keys()))
+        self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
+                                         max_model_len=max_model_len)
 
-    def info(self, type: str, obj: object) -> None:
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         if type == "cache_config":
             self.metrics.info_cache_config.info(obj.metrics_info())
 
-    def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
-        return float(np.sum(tracked_stats) / (now - self.last_local_log))
+    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
+        # Convenience function for logging to gauge.
+        gauge.labels(**self.labels).set(data)
+
+    def _log_counter(self, counter, data: Union[int, float]) -> None:
+        # Convenience function for logging to counter.
+        counter.labels(**self.labels).inc(data)
 
-    def _local_interval_elapsed(self, now: float) -> bool:
-        elapsed_time = now - self.last_local_log
-        return elapsed_time > self.local_interval
+    def _log_counter_labels(self, counter, data: CollectionsCounter,
+                            label_key: str) -> None:
+        # Convenience function for collection counter of labels.
+        for label, count in data.items():
+            counter.labels(**{**self.labels, label_key: label}).inc(count)
+
+    def _log_histogram(self, histogram, data: Union[List[int],
+                                                    List[float]]) -> None:
+        # Convenience function for logging list to histogram.
+        for datum in data:
+            histogram.labels(**self.labels).observe(datum)
 
     def _log_prometheus(self, stats: Stats) -> None:
-        # Set system stat gauges.
-        self.metrics.gauge_scheduler_running.labels(**self.labels).set(
-            stats.num_running)
-        self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
-            stats.num_swapped)
-        self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
-            stats.num_waiting)
-        self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
-            stats.gpu_cache_usage)
-        self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
-            stats.cpu_cache_usage)
-
-        # Add to token counters.
-        self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
-            stats.num_prompt_tokens)
-        self.metrics.counter_generation_tokens.labels(**self.labels).inc(
-            stats.num_generation_tokens)
-
-        # Observe request level latencies in histograms.
-        for ttft in stats.time_to_first_tokens:
-            self.metrics.histogram_time_to_first_token.labels(
-                **self.labels).observe(ttft)
-        for tpot in stats.time_per_output_tokens:
-            self.metrics.histogram_time_per_output_token.labels(
-                **self.labels).observe(tpot)
-        for e2e in stats.time_e2e_requests:
-            self.metrics.histogram_e2e_request_latency.labels(
-                **self.labels).observe(e2e)
+        # System state data
+        self._log_gauge(self.metrics.gauge_scheduler_running,
+                        stats.num_running_sys)
+        self._log_gauge(self.metrics.gauge_scheduler_swapped,
+                        stats.num_swapped_sys)
+        self._log_gauge(self.metrics.gauge_scheduler_waiting,
+                        stats.num_waiting_sys)
+        self._log_gauge(self.metrics.gauge_gpu_cache_usage,
+                        stats.gpu_cache_usage_sys)
+        self._log_gauge(self.metrics.gauge_cpu_cache_usage,
+                        stats.cpu_cache_usage_sys)
+
+        # Iteration level data
+        self._log_counter(self.metrics.counter_num_preemption,
+                          stats.num_preemption_iter)
+        self._log_counter(self.metrics.counter_prompt_tokens,
+                          stats.num_prompt_tokens_iter)
+        self._log_counter(self.metrics.counter_generation_tokens,
+                          stats.num_generation_tokens_iter)
+        self._log_histogram(self.metrics.histogram_time_to_first_token,
+                            stats.time_to_first_tokens_iter)
+        self._log_histogram(self.metrics.histogram_time_per_output_token,
+                            stats.time_per_output_tokens_iter)
+
+        # Request level data
+        # Latency
+        self._log_histogram(self.metrics.histogram_e2e_time_request,
+                            stats.time_e2e_requests)
+        # Metadata
+        finished_reason_counter = CollectionsCounter(
+            stats.finished_reason_requests)
+        self._log_counter_labels(self.metrics.counter_request_success,
+                                 finished_reason_counter,
+                                 Metrics.labelname_finish_reason)
+        self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
+                            stats.num_prompt_tokens_requests)
+        self._log_histogram(
+            self.metrics.histogram_num_generation_tokens_request,
+            stats.num_generation_tokens_requests)
+        self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
+        self._log_histogram(self.metrics.histogram_best_of_request,
+                            stats.best_of_requests)
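The success counter above is the one metric logged per label value rather than per request: the finish reasons collected over the interval are aggregated with `collections.Counter`, and each distinct reason increments its own `finished_reason`-labelled child. A standalone sketch of that aggregation step (plain Python, not the engine's objects):

    from collections import Counter

    finished_reasons = ["stop", "length", "stop", "abort"]
    for reason, count in Counter(finished_reasons).items():
        # Each distinct reason maps to one label value on
        # aphrodite:request_success_total and is incremented by its count.
        print(reason, count)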
 
     def _log_prometheus_interval(self, prompt_throughput: float,
                                  generation_throughput: float) -> None:
         # Logs metrics to prometheus that are computed every logging_interval.
         # Support legacy gauge metrics that make throughput calculations on
-        # the Aphrodite side.
-        # Moving forward, we should use counters like counter_prompt_tokens,
-        # counter_generation_tokens
+        # the Aphrodite side. Moving forward, we should use counters like
+        # counter_prompt_tokens, counter_generation_tokens
         # Which log raw data and calculate summaries using rate() on the
         # grafana/prometheus side.
         self.metrics.gauge_avg_prompt_throughput.labels(
@@ -229,59 +525,61 @@ class StatLogger:
         self.metrics.gauge_avg_generation_throughput.labels(
             **self.labels).set(generation_throughput)
 
-    def log(self, stats: Stats) -> None:
-        """Called by AphroditeEngine.
-        Logs to prometheus and tracked stats every iteration.
-        Logs to Stdout every self.local_interval seconds."""
-
+    def log(self, stats: Stats):
+        """Logs to prometheus and tracked stats every iteration."""
         # Log to prometheus.
         self._log_prometheus(stats)
 
         # Save tracked stats for token counters.
-        self.num_prompt_tokens.append(stats.num_prompt_tokens)
-        self.num_generation_tokens.append(stats.num_generation_tokens)
+        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
+        self.num_generation_tokens.append(stats.num_generation_tokens_iter)
+
+        # Update spec decode metrics
+        self.maybe_update_spec_decode_metrics(stats)
 
         # Log locally every local_interval seconds.
-        if self._local_interval_elapsed(stats.now):
-            # Compute summary metrics for tracked stats (and log them to
-            # prometheus if applicable).
-            prompt_throughput = self._get_throughput(self.num_prompt_tokens,
-                                                     now=stats.now)
-            generation_throughput = self._get_throughput(
-                self.num_generation_tokens, now=stats.now)
+        if local_interval_elapsed(stats.now, self.last_local_log,
+                                  self.local_interval):
+            # Compute summary metrics for tracked stats (and log them
+            # to prometheus if applicable).
+            prompt_throughput = get_throughput(self.num_prompt_tokens,
+                                               now=stats.now,
+                                               last_log=self.last_local_log)
+            generation_throughput = get_throughput(
+                self.num_generation_tokens,
+                now=stats.now,
+                last_log=self.last_local_log)
+
             self._log_prometheus_interval(
                 prompt_throughput=prompt_throughput,
-                generation_throughput=generation_throughput,
-            )
-
-            # Log to stdout.
-            logger.info(
-                f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
-                f"Avg generation throughput: {generation_throughput:.1f} "
-                "tokens/s, "
-                f"Running: {stats.num_running} reqs, "
-                f"Swapped: {stats.num_swapped} reqs, "
-                f"Pending: {stats.num_waiting} reqs, "
-                f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, "
-                f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%")
+                generation_throughput=generation_throughput)
+
+            if self.spec_decode_metrics is not None:
+                self._log_gauge(
+                    self.metrics.gauge_spec_decode_draft_acceptance_rate,
+                    self.spec_decode_metrics.draft_acceptance_rate)
+                self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
+                                self.spec_decode_metrics.system_efficiency)
+                self._log_counter(
+                    self.metrics.counter_spec_decode_num_accepted_tokens,
+                    self.spec_decode_metrics.accepted_tokens)
+                self._log_counter(
+                    self.metrics.counter_spec_decode_num_draft_tokens,
+                    self.spec_decode_metrics.draft_tokens)
+                self._log_counter(
+                    self.metrics.counter_spec_decode_num_emitted_tokens,
+                    self.spec_decode_metrics.emitted_tokens)
 
             # Reset tracked stats for next interval.
             self.num_prompt_tokens = []
             self.num_generation_tokens = []
             self.last_local_log = stats.now
+            self.spec_decode_metrics = None
 
-            if stats.spec_decode_metrics is not None:
-                logger.info(
-                    self._format_spec_decode_metrics_str(
-                        stats.spec_decode_metrics))
 
-    def _format_spec_decode_metrics_str(
-            self, metrics: "SpecDecodeWorkerMetrics") -> str:
+class RayPrometheusStatLogger(PrometheusStatLogger):
+    """RayPrometheusStatLogger uses Ray metrics instead."""
+    _metrics_cls = RayMetrics
 
-        return ("Speculative metrics: "
-                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
-                f"System efficiency: {metrics.system_efficiency:.3f}, "
-                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
-                f"Number of accepted tokens: {metrics.accepted_tokens}, "
-                f"Number of draft tokens tokens: {metrics.draft_tokens}, "
-                f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        return None
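The backend swap in this file happens entirely through class attributes: `Metrics` instantiates whatever `_gauge_cls`/`_counter_cls`/`_histogram_cls` point to, `RayMetrics` only overrides those pointers with the wrappers above, and the stat loggers pick the metrics family via `_metrics_cls`. A stripped-down sketch of the same pattern (hypothetical names, not the engine's actual classes):

    class PromBackend:
        def set(self, value): print("prometheus", value)

    class RayBackend:
        def set(self, value): print("ray", value)

    class BaseMetrics:
        _gauge_cls = PromBackend          # swapped by subclasses

        def __init__(self):
            self.gauge = self._gauge_cls()

    class RayBackedMetrics(BaseMetrics):
        _gauge_cls = RayBackend           # only the pointer changes

    class StatsLogger:
        _metrics_cls = BaseMetrics        # loggers select the family

        def __init__(self):
            self.metrics = self._metrics_cls()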

+ 21 - 12
aphrodite/engine/output_processor/interfaces.py

@@ -1,11 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Callable, List
 
 from transformers import PreTrainedTokenizer
 
 from aphrodite.common.config import SchedulerConfig
 from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                        SequenceGroupOutput)
+from aphrodite.common.utils import Counter
 from aphrodite.engine.output_processor.stop_checker import StopChecker
 from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
@@ -15,30 +16,32 @@ class SequenceGroupOutputProcessor(ABC):
     """Interface for logic that processes new token ids in sequence groups,
     managing detokenization, stop checking, and freeing/forking sequences with
     the scheduler.
-    This is highly coupled with the LLMEngine and should be seen as an extension
-    of it. The logic is separated to simplify the LLMEngine class and allow
-    separate implementations for single-step decoding (which supports beam
-    search sequence forking) and multi-step decoding (which does not support
-    beam search, but does support speculative decoding).
+
+    This is highly coupled with the AphroditeEngine and should be seen as an
+    extension of it. The logic is separated to simplify the AphroditeEngine
+    class and allow separate implementations for single-step decoding
+    (which supports beam search sequence forking) and multi-step decoding
+    (which does not support beam search, but does support speculative decoding)
     """
 
     @staticmethod
     def create_output_processor(
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
-        scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        scheduler: List[Scheduler],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: "StopChecker",
     ):
         """Create an output processor.
+
         This returns a single-step output processor if num_lookahead_slots is
         zero, else returns a multi-step output processor.
         """
         if scheduler_config.num_lookahead_slots == 0:
             # Importing here to avoid cycle.
-            from aphrodite.engine.output_processor.single_step import \
-                SingleStepOutputProcessor
+            from aphrodite.engine.output_processor.single_step import (
+                SingleStepOutputProcessor)
             return SingleStepOutputProcessor(
                 scheduler_config,
                 detokenizer,
@@ -48,8 +51,8 @@ class SequenceGroupOutputProcessor(ABC):
             )
         else:
             # Importing here to avoid cycle.
-            from aphrodite.engine.output_processor.multi_step import \
-                MultiStepOutputProcessor
+            from aphrodite.engine.output_processor.multi_step import (
+                MultiStepOutputProcessor)
             return MultiStepOutputProcessor(
                 detokenizer,
                 scheduler,
@@ -66,3 +69,9 @@ class SequenceGroupOutputProcessor(ABC):
         scheduler.
         """
         pass
+
+    @abstractmethod
+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        """Update prompt logprobs received from outputs to seq_group."""
+        pass

+ 38 - 17
aphrodite/engine/output_processor/multi_step.py

@@ -1,15 +1,18 @@
-from typing import Callable, Iterable, List
+import functools
+from typing import Callable, List
 
 from transformers import PreTrainedTokenizer
 
-from aphrodite.processing.scheduler import Scheduler
-from aphrodite.engine.output_processor.interfaces import (
-    SequenceGroupOutputProcessor)
-from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.common.logger import log_once
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (Logprob, Sequence, SequenceGroup,
+from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                        SequenceGroupOutput, SequenceOutput,
                                        SequenceStatus)
+from aphrodite.common.utils import Counter
+from aphrodite.engine.output_processor.interfaces import (
+    SequenceGroupOutputProcessor)
+from aphrodite.engine.output_processor.stop_checker import StopChecker
+from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
 
 
@@ -17,10 +20,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
     """SequenceGroupOutputProcessor which handles logic related to
     detokenization and stopping conditions. It specializes to "multi-step
     decoding", where Aphrodite's worker may generate multiple tokens per
-    invocation.
-    This is currently mutually exclusive with advanced sampling techniques like
-    beam search, which motivates the separation of this logic from the single
-    step output processor.
+    invocation. This is currently mutually exclusive with advanced sampling
+    techniques like beam search, which motivates the separation of this logic
+    from the single step output processor.
+
     This class is responsible for things such as correctly appending all new
     token ids to their sequence, detokenizing new token ids, truncating new
     output tokens after an eos token, and correctly handling the case where the
@@ -30,8 +33,8 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
     def __init__(
         self,
         detokenizer: Detokenizer,
-        scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        scheduler: List[Scheduler],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: StopChecker,
     ):
@@ -41,11 +44,27 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         self.get_tokenizer_for_seq = get_tokenizer_for_seq
         self.stop_checker = stop_checker
 
+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        # TODO: Prompt logprob currently not implemented in multi step
+        # workers.
+        self._log_prompt_logprob_unsupported_warning_once()
+
+    @staticmethod
+    @functools.lru_cache()
+    def _log_prompt_logprob_unsupported_warning_once():
+        log_once(
+            level="WARNING",
+            message="Prompt logprob is not supported by multi step workers. "
+            "(e.g., speculative decode uses multi step workers).")
+
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput]) -> None:
         """Append new tokens in the outputs to sequences in the sequence group.
+
         This only supports sequence groups of size 1. It supports greater than
         one new token per sequence.
+
         This applies logic like stop condition checking and detokenization,
         including freeing finished sequences. It also handles cases where there
         are tokens emitted after the EOS token.
@@ -59,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
 
         # Since there's only one sequence per sequence group, we can take the
         # first sample.
-        samples = [outputs[step].samples[0] for step in range(len(outputs))]
+        samples = [output.samples[0] for output in outputs]
 
         # -1 means the output token is not valid (eg. due to spec decode
         # rejecting tokens).
@@ -75,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                              valid_samples: List[SequenceOutput],
                              sampling_params: SamplingParams) -> None:
         output_token_ids = [sample.output_token for sample in valid_samples]
+        output_logprobs = [sample.logprobs for sample in valid_samples]
 
         # Truncate to max_tokens if necessary.
         remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
@@ -99,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
 
         # Incrementally append tokens to the sequence, as if we had only one new
         # token.
-        for output_token_id in output_token_ids:
+        for output_token_id, output_logprob in zip(output_token_ids,
+                                                   output_logprobs):
             seq.append_token_id(
                 token_id=output_token_id,
-                # TODO emit logprobs in multi-step decoding.
-                logprobs={output_token_id: Logprob(0.0)},
+                logprobs=output_logprob,
             )
 
             new_char_count = 0
@@ -119,4 +139,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 break
 
         if seq.is_finished():
-            self.scheduler.free_seq(seq)
+            for scheduler in self.scheduler:
+                scheduler.free_seq(seq)

+ 84 - 31
aphrodite/engine/output_processor/single_step.py

@@ -1,12 +1,13 @@
-from typing import Iterable, List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 from aphrodite.common.config import SchedulerConfig
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (Sequence, SequenceGroup,
                                        SequenceGroupOutput, SequenceOutput,
                                        SequenceStatus)
-from aphrodite.engine.output_processor.interfaces import \
-    SequenceGroupOutputProcessor
+from aphrodite.common.utils import Counter
+from aphrodite.engine.output_processor.interfaces import (
+    SequenceGroupOutputProcessor)
 from aphrodite.engine.output_processor.stop_checker import StopChecker
 from aphrodite.processing.scheduler import Scheduler
 from aphrodite.transformers_utils.detokenizer import Detokenizer
@@ -18,6 +19,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
     scheduling of the next batch. Output processing logic includes
     detokenization, and determining if a sequence is finished (e.g. via max len
     or eos token).
+
     The SingleStepOutputProcessor is specialized to the case where the model
     emits at most a single token per invocation, which precludes configurations
     such as speculative decoding or multi-step decoding. This enables beam
@@ -29,8 +31,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         self,
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
-        scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        scheduler: List[Scheduler],
+        seq_counter: Counter,
         stop_checker: StopChecker,
     ):
         self.scheduler_config = scheduler_config
@@ -43,6 +45,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                         outputs: List[SequenceGroupOutput]) -> None:
         """Append all new tokens to sequences in the sequence group. Fork any
         surviving beam candidates; free any unsurviving ones.
+
         Invokes detokenizer to detokenize new tokens, and also marks sequences
         as finished if they meet stop conditions.
         """
@@ -50,26 +53,66 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 ), f"{type(self)} does not support multiple outputs per step"
         return self._process_sequence_group_outputs(sequence_group, outputs[0])
 
-    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
-                                        outputs: SequenceGroupOutput) -> None:
+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        assert len(outputs) == 1, ("Single step should only have 1 output.")
+        output = outputs[0]
+        prompt_logprobs = output.prompt_logprobs
 
-        # Process prompt logprobs
-        prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
-            self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
-            seq_group.prompt_logprobs = prompt_logprobs
+        # If this is the first (or only) "chunk" of the prefill, we need
+        # to prepend None to the list of prompt logprobs. The reason for this
+        # is that for N prompt tokens, the Sampler will generate N-1 total
+        # prompt logprobs during prefill since the token at idx 0 will not
+        # have a logprob associated with it.
+        if prompt_logprobs is not None:
+            if not seq_group.prompt_logprobs:
+                prompt_logprobs = [None] + prompt_logprobs
+                seq_group.prompt_logprobs = []
+            if seq_group.sampling_params.detokenize and self.detokenizer:
+                self.detokenizer.decode_prompt_logprobs_inplace(
+                    seq_group,
+                    prompt_logprobs,
+                    position_offset=len(seq_group.prompt_logprobs))
+            seq_group.prompt_logprobs.extend(prompt_logprobs)
 
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutput) -> None:
+        sampling_params = seq_group.sampling_params
+        if sampling_params.n == 1 and not sampling_params.use_beam_search:
+            # only have one output sample
+            sample = outputs.samples[0]
+            # only have one sequence
+            seq = seq_group.seqs[0]
+            seq.append_token_id(sample.output_token, sample.logprobs)
+            if sampling_params.detokenize and self.detokenizer:
+                new_char_count = self.detokenizer.decode_sequence_inplace(
+                    seq, sampling_params)
+            else:
+                new_char_count = 0
+            self.stop_checker.maybe_stop_sequence(
+                seq,
+                new_char_count,
+                sampling_params,
+                lora_req=seq_group.lora_request,
+            )
+            if seq.is_finished():
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(seq)
+            return
         # Process samples
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
-        parent_child_dict = {
+        parent_child_dict: Dict[int, List[SequenceOutput]] = {
             parent_seq.seq_id: []
             for parent_seq in parent_seqs
         }
         for sample in samples:
-            parent_child_dict[sample.parent_seq_id].append(sample)
+            # Guard against a KeyError which can occur if the request was
+            # aborted while the output was generated
+            if (child_list :=
+                    parent_child_dict.get(sample.parent_seq_id)) is not None:
+                child_list.append(sample)
         # List of (child, parent)
         child_seqs: List[Tuple[Sequence, Sequence]] = []
 
@@ -83,11 +126,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 # not be used in the future iterations.
                 parent.status = SequenceStatus.FINISHED_ABORTED
                 seq_group.remove(parent.seq_id)
-                self.scheduler.free_seq(parent)
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(parent)
                 continue
             # Fork the parent sequence if there are multiple child samples.
             for child_sample in child_samples[:-1]:
-                new_child_seq_id = next(self.seq_counter)
+                new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
                                       child_sample.logprobs)
@@ -101,23 +145,28 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            if seq_group.sampling_params.detokenize:
+            if sampling_params.detokenize and self.detokenizer:
                 new_char_count = self.detokenizer.decode_sequence_inplace(
-                    seq, seq_group.sampling_params)
+                    seq, sampling_params)
             else:
                 new_char_count = 0
-            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
-                                                  seq_group.sampling_params)
+            self.stop_checker.maybe_stop_sequence(
+                seq,
+                new_char_count,
+                sampling_params,
+                lora_req=seq_group.lora_request,
+            )
 
         # Non-beam search case
-        if not seq_group.sampling_params.use_beam_search:
+        if not sampling_params.use_beam_search:
             # For newly created child sequences, add them to the sequence group
             # and fork them in block manager if they are not finished.
             for seq, parent in child_seqs:
                 if seq is not parent:
                     seq_group.add(seq)
                     if not seq.is_finished():
-                        self.scheduler.fork_seq(parent, seq)
+                        for scheduler in self.scheduler:
+                            scheduler.fork_seq(parent, seq)
 
             # Free the finished and selected parent sequences' memory in block
             # manager. Keep them in the sequence group as candidate output.
@@ -125,15 +174,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             # old sequences.
             for seq, parent in child_seqs:
                 if seq is parent and seq.is_finished():
-                    self.scheduler.free_seq(seq)
+                    for scheduler in self.scheduler:
+                        scheduler.free_seq(seq)
             return
 
         # Beam search case
         # Select the child sequences to keep in the sequence group.
         selected_child_seqs = []
         unselected_child_seqs = []
-        beam_width = seq_group.sampling_params.best_of
-        length_penalty = seq_group.sampling_params.length_penalty
+        beam_width = sampling_params.best_of
+        length_penalty = sampling_params.length_penalty
 
         # Select the newly finished sequences with the highest scores
         # to replace existing finished sequences.
@@ -187,8 +237,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             best_running_seq = running_child_seqs[0][0]
             current_worst_seq = all_finished_seqs[beam_width - 1][0]
             stop_beam_search = self._check_beam_search_early_stopping(
-                seq_group.sampling_params.early_stopping,
-                seq_group.sampling_params, best_running_seq, current_worst_seq)
+                sampling_params.early_stopping, sampling_params,
+                best_running_seq, current_worst_seq)
 
         if stop_beam_search:
             # Stop the beam search and remove all the running sequences from
@@ -210,13 +260,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             if seq is not parent:
                 seq_group.add(seq)
                 if not seq.is_finished():
-                    self.scheduler.fork_seq(parent, seq)
+                    for scheduler in self.scheduler:
+                        scheduler.fork_seq(parent, seq)
 
         # Free the finished and selected parent sequences' memory in block
         # manager. Keep them in the sequence group as candidate output.
         for seq, parent in selected_child_seqs:
             if seq is parent and seq.is_finished():
-                self.scheduler.free_seq(seq)
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(seq)
 
         # Remove the unselected parent sequences from the sequence group and
         # free their memory in block manager.
@@ -225,7 +277,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 # Remove the parent sequence if it is not selected for next
                 # iteration
                 seq_group.remove(seq.seq_id)
-                self.scheduler.free_seq(seq)
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(seq)
 
     def _check_beam_search_early_stopping(
         self,

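The chunked-prefill branch of `process_prompt_logprob` above accumulates prompt logprobs across prefill chunks, prepending a single `None` on the first chunk because the sampler emits only N-1 logprobs for an N-token prompt (the token at index 0 has nothing conditioned before it). A rough sketch of just that accumulation, with plain floats standing in for `Logprob` objects and the detokenization step omitted:

```python
from typing import List, Optional


def accumulate_prompt_logprobs(
        existing: List[Optional[float]],
        chunk_logprobs: List[float]) -> List[Optional[float]]:
    """Mimic the diff's accumulation: the first prefill chunk gets a
    leading None because prompt token 0 has no logprob; later chunks
    are appended as-is."""
    if not existing:
        # First (or only) chunk of the prefill.
        return [None] + list(chunk_logprobs)
    return existing + list(chunk_logprobs)


if __name__ == "__main__":
    acc: List[Optional[float]] = []
    # Two prefill chunks of a 5-token prompt: the sampler emits 4 logprobs.
    acc = accumulate_prompt_logprobs(acc, [-1.2, -0.3])
    acc = accumulate_prompt_logprobs(acc, [-2.0, -0.7])
    print(acc)  # [None, -1.2, -0.3, -2.0, -0.7] -> one entry per prompt token
```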
+ 24 - 4
aphrodite/engine/output_processor/stop_checker.py

@@ -4,6 +4,7 @@ from transformers import PreTrainedTokenizer
 
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import Sequence, SequenceStatus
+from aphrodite.lora.request import LoRARequest
 
 
 class StopChecker:
@@ -16,12 +17,25 @@ class StopChecker:
     def __init__(self, max_model_len: int,
                  get_tokenizer_for_seq: Callable[[Sequence],
                                                  PreTrainedTokenizer]):
-        self.max_model_len = max_model_len
+        # Do not use it directly, but use `self._get_max_model_len`.
+        self._max_model_len = max_model_len
         self.get_tokenizer_for_seq = get_tokenizer_for_seq
 
-    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
-                            sampling_params: SamplingParams) -> None:
+    def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
+        if lora_req and lora_req.long_lora_max_len:
+            return lora_req.long_lora_max_len
+        else:
+            return self._max_model_len
+
+    def maybe_stop_sequence(
+        self,
+        seq: Sequence,
+        new_char_count: int,
+        sampling_params: SamplingParams,
+        lora_req: Optional[LoRARequest] = None,
+    ) -> None:
         """Stop the finished sequences.
+
        new_char_count is the number of chars added to the
            sequence's output text for the newly generated token
         """
@@ -34,6 +48,11 @@ class StopChecker:
         # Check if the sequence has generated the EOS token.
         if ((not sampling_params.ignore_eos)
                 and seq.get_last_token_id() == seq.eos_token_id):
+            # Remove the last EOS token unless explicitly specified
+            # This prevents unintended exposure of the EOS token
+            if new_char_count and (
+                    not sampling_params.include_stop_str_in_output):
+                seq.output_text = seq.output_text[:-new_char_count]
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
@@ -58,7 +77,7 @@ class StopChecker:
             return
 
         # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.max_model_len:
+        if seq.get_len() > self._get_max_model_len(lora_req):
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return
 
@@ -72,6 +91,7 @@ class StopChecker:
                             sampling_params: SamplingParams) -> Optional[str]:
         """Check if any stop strings are matched and truncate sequence
         output text accordingly.
+
         Returns the stop string if matched or else None.
         """
         if not new_char_count:

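With the new `lora_req` parameter, the length cap prefers a long-context LoRA's `long_lora_max_len` when one is attached and falls back to the engine-wide limit otherwise. A self-contained sketch of that selection; the dataclass below is only a stand-in for `aphrodite.lora.request.LoRARequest`, which carries more fields in the real code:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRARequestStub:
    """Minimal stand-in for aphrodite.lora.request.LoRARequest."""
    lora_name: str
    long_lora_max_len: Optional[int] = None


def effective_max_model_len(base_max_len: int,
                            lora_req: Optional[LoRARequestStub]) -> int:
    # Mirrors StopChecker._get_max_model_len: a long-context LoRA
    # overrides the engine-wide model length.
    if lora_req and lora_req.long_lora_max_len:
        return lora_req.long_lora_max_len
    return base_max_len


if __name__ == "__main__":
    print(effective_max_model_len(4096, None))                         # 4096
    print(effective_max_model_len(4096, LoRARequestStub("base")))      # 4096
    print(effective_max_model_len(
        4096, LoRARequestStub("long-ctx", long_lora_max_len=32768)))   # 32768
```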
+ 11 - 5
aphrodite/engine/output_processor/util.py

@@ -1,15 +1,21 @@
 from typing import List
+from typing import Sequence as GenericSequence
+from typing import Union
 
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import (PoolerOutput, SamplerOutput,
+                                       SequenceGroupOutput)
 
 
-def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
-                                    num_seq_groups: int):
+def create_output_by_sequence_group(
+        outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
+        num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group = [[] for _ in range(num_seq_groups)]
-    for step in sampler_outputs:
+    output_by_sequence_group: List[List[SequenceGroupOutput]] = [
+        [] for _ in range(num_seq_groups)
+    ]
+    for step in outputs:
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)
 

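`create_output_by_sequence_group` is effectively a transpose from `[step][sequence group]` to `[sequence group][step]`. A small sketch with plain strings in place of `SamplerOutput`/`PoolerOutput` makes the indexing easy to verify:

```python
from typing import List, Sequence as GenericSequence, TypeVar

T = TypeVar("T")


def by_sequence_group(outputs: GenericSequence[GenericSequence[T]],
                      num_seq_groups: int) -> List[List[T]]:
    """Transpose [step][sequence group] -> [sequence group][step]."""
    result: List[List[T]] = [[] for _ in range(num_seq_groups)]
    for step in outputs:
        for i, seq_group_output in enumerate(step):
            result[i].append(seq_group_output)
    return result


if __name__ == "__main__":
    # Two decode steps, three sequence groups per step.
    steps = [["s0-g0", "s0-g1", "s0-g2"],
             ["s1-g0", "s1-g1", "s1-g2"]]
    print(by_sequence_group(steps, num_seq_groups=3))
    # [['s0-g0', 's1-g0'], ['s0-g1', 's1-g1'], ['s0-g2', 's1-g2']]
```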
+ 83 - 0
aphrodite/engine/protocol.py

@@ -0,0 +1,83 @@
+from typing import AsyncIterator, List, Optional, Protocol, runtime_checkable
+
+from transformers import PreTrainedTokenizer
+
+from aphrodite.common.config import DecodingConfig, ModelConfig
+from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
+from aphrodite.common.pooling_params import PoolingParams
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.inputs.data import PromptInputs
+from aphrodite.lora.request import LoRARequest
+from aphrodite.processing.scheduler import SchedulerOutputs
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
+
+
+@runtime_checkable
+class AsyncEngineClient(Protocol):
+    """Protocol class for Clients to AsyncAphrodite"""
+
+    @property
+    def is_running(self) -> bool:
+        ...
+
+    @property
+    def is_stopped(self) -> bool:
+        ...
+
+    @property
+    def errored(self) -> bool:
+        ...
+
+    async def generate(
+        self,
+        inputs: PromptInputs,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> AsyncIterator[RequestOutput]:
+        """Generates outputs for a request"""
+        ...
+
+    async def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AsyncIterator[EmbeddingRequestOutput]:
+        """Generate outputs for a request from an embedding model."""
+        ...
+
+    async def abort(self, request_id: str) -> None:
+        """Abort a request.
+        Args:
+            request_id: The unique id of the request.
+        """
+        ...
+
+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the Aphrodite engine."""
+        ...
+
+    async def get_decoding_config(self) -> DecodingConfig:
+        """Get the decoding configuration of the Aphrodite engine."""
+        ...
+
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> PreTrainedTokenizer:
+        """Get the appropriate Tokenizer for the request"""
+        ...
+
+    async def do_log_stats(
+        self,
+        scheduler_outputs: Optional[SchedulerOutputs] = None,
+        model_output: Optional[List[SamplerOutput]] = None,
+    ) -> None:
+        pass
+
+    async def check_health(self) -> None:
+        """Raise if unhealthy"""

Some files were not shown because too many files changed in this diff