
chore: reformats (#90)

AlpinDale, 1 year ago
parent commit efc6f7fbec
44 changed files with 408 additions and 289 deletions
  1. aphrodite/common/block.py (+1, -0)
  2. aphrodite/common/config.py (+10, -8)
  3. aphrodite/common/logits_processor.py (+5, -3)
  4. aphrodite/common/outputs.py (+3, -4)
  5. aphrodite/common/sampling_params.py (+10, -5)
  6. aphrodite/common/sequence.py (+3, -1)
  7. aphrodite/common/utils.py (+2, -1)
  8. aphrodite/endpoints/api_server_kobold.py (+59, -26)
  9. aphrodite/endpoints/api_server_ooba.py (+14, -8)
  10. aphrodite/endpoints/protocol.py (+11, -5)
  11. aphrodite/engine/aphrodite_engine.py (+7, -6)
  12. aphrodite/engine/args_tools.py (+6, -6)
  13. aphrodite/engine/async_aphrodite.py (+2, -1)
  14. aphrodite/modeling/hf_downloader.py (+5, -5)
  15. aphrodite/modeling/layers/activation.py (+7, -3)
  16. aphrodite/modeling/layers/layernorm.py (+5, -4)
  17. aphrodite/modeling/layers/quantized_linear/__init__.py (+4, -4)
  18. aphrodite/modeling/layers/quantized_linear/awq.py (+3, -2)
  19. aphrodite/modeling/layers/quantized_linear/gptq.py (+2, -2)
  20. aphrodite/modeling/layers/quantized_linear/utils.py (+7, -6)
  21. aphrodite/modeling/layers/rotary_embedding.py (+28, -21)
  22. aphrodite/modeling/layers/sampler.py (+71, -56)
  23. aphrodite/modeling/loader.py (+6, -6)
  24. aphrodite/modeling/megatron/parallel_state.py (+1, -2)
  25. aphrodite/modeling/metadata.py (+3, -2)
  26. aphrodite/modeling/models/__init__.py (+1, -2)
  27. aphrodite/modeling/models/gpt_j.py (+2, -3)
  28. aphrodite/modeling/models/gpt_neox.py (+3, -3)
  29. aphrodite/modeling/models/llama.py (+5, -6)
  30. aphrodite/modeling/models/mistral.py (+3, -3)
  31. aphrodite/modeling/quantization_utils/awq.py (+2, -2)
  32. aphrodite/modeling/quantization_utils/base.py (+3, -3)
  33. aphrodite/modeling/quantization_utils/gptq.py (+14, -13)
  34. aphrodite/modeling/utils.py (+1, -0)
  35. aphrodite/processing/policy.py (+4, -1)
  36. aphrodite/processing/scheduler.py (+2, -2)
  37. aphrodite/task_handler/worker.py (+7, -4)
  38. aphrodite/transformers_utils/config.py (+3, -1)
  39. tests/kernels/conftest.py (+9, -8)
  40. tests/kernels/test_activation.py (+4, -2)
  41. tests/models/test_models.py (+7, -3)
  42. tests/samplers/test_samplers.py (+17, -11)
  43. tests/serving.py (+40, -31)
  44. tests/throughput.py (+6, -4)

+ 1 - 0
aphrodite/common/block.py

@@ -49,6 +49,7 @@ class LogicalTokenBlock:
 
 class PhysicalTokenBlock:
     """Represents the state of a block in the KV cache."""
+
     def __init__(
         self,
         device: Device,

+ 10 - 8
aphrodite/common/config.py

@@ -261,11 +261,11 @@ class SchedulerConfig:
     """
 
     def __init__(
-            self,
-            max_num_batched_tokens: Optional[int],
-            max_num_seqs: int,
-            max_model_len: int,
-            max_paddings: int,
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+        max_paddings: int,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
@@ -288,7 +288,8 @@ class SchedulerConfig:
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
-                f"be greater than or equal to max_num_seqs ({self.max_num_seqs}).")
+                f"be greater than or equal to max_num_seqs ({self.max_num_seqs})."
+            )
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -358,7 +359,8 @@ def _get_and_verify_max_len(
     if derived_max_model_len == float('inf'):
         raise ValueError(
             "The model's config.json must contain one of the following keys "
-            f"to determine the original maximum length of the model: {possible_keys}")
+            f"to determine the original maximum length of the model: {possible_keys}"
+        )
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
@@ -375,4 +377,4 @@ def _get_and_verify_max_len(
             " in model's config.json). This may lead to incorrect model "
             "outputs or CUDA errors. Make sure the value is correct and "
             "within the model context size.")
-    return int(max_model_len)
+    return int(max_model_len)

+ 5 - 3
aphrodite/common/logits_processor.py

@@ -6,7 +6,8 @@ from typing import Dict
 class LogitsProcessor(ABC):
 
     @abstractmethod
-    def __call__(self, logits: torch.Tensor, output_tokens: list[list[int]]) -> None:
+    def __call__(self, logits: torch.Tensor,
+                 output_tokens: list[list[int]]) -> None:
         """Logits are edited in-place"""
         pass
 
@@ -42,14 +43,15 @@ class BiasLogitsProcessor(LogitsProcessor):
                                      1 / (1 - (values / 100)))
         logits[0, keys] *= update_factors
 
-        
+
 class BanEOSUntil(LogitsProcessor):
     """Bans the EOS token until a certain condition is met.
     In this case, 'number of output tokens'.
 
     With this condition, both 'min_tokens' and 'ignore_eos'
     parameters can be handled gracefully."""
-    def __init__(self, min_tokens:int, eos_token_id:int):
+
+    def __init__(self, min_tokens: int, eos_token_id: int):
         self._min_tokens = min_tokens
         self._eos_token_id = eos_token_id
 

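Note: the hunk above only shows BanEOSUntil's docstring and constructor. As a rough sketch of how a `__call__` consistent with that docstring and with the `LogitsProcessor` interface at the top of the file could look (illustrative only, not code from this commit; the actual implementation is not shown in the diff):

```python
import torch


class BanEOSUntil:
    """Sketch: suppress the EOS logit until each sequence has min_tokens outputs."""

    def __init__(self, min_tokens: int, eos_token_id: int):
        self._min_tokens = min_tokens
        self._eos_token_id = eos_token_id

    def __call__(self, logits: torch.Tensor,
                 output_tokens: list[list[int]]) -> None:
        # Logits are edited in-place; assume one row of logits per sequence.
        for i, seq_tokens in enumerate(output_tokens):
            if len(seq_tokens) < self._min_tokens:
                logits[i][self._eos_token_id] = -float("inf")
```
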
+ 3 - 4
aphrodite/common/outputs.py

@@ -1,8 +1,7 @@
 from typing import List, Optional
 
-from aphrodite.common.sequence import (
-    PromptLogprobs, SampleLogprobs, SequenceGroup,
-    SequenceStatus)
+from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
+                                       SequenceGroup, SequenceStatus)
 
 
 class CompletionOutput:
@@ -117,4 +116,4 @@ class RequestOutput:
                 f"prompt_token_ids={self.prompt_token_ids}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
-                f"finished={self.finished})")
+                f"finished={self.finished})")

+ 10 - 5
aphrodite/common/sampling_params.py

@@ -194,11 +194,15 @@ class SamplingParams:
         if not 0.0 < self.tfs <= 1.0:
             raise ValueError(f"tfs must be in (0, 1], got {self.tfs}.")
         if not 0.0 <= self.epsilon_cutoff <= 1000.0:
-            raise ValueError(f"epsilon_cutoff must be in [0, 1000], got {self.epsilon_cutoff}.")
+            raise ValueError(
+                f"epsilon_cutoff must be in [0, 1000], got {self.epsilon_cutoff}."
+            )
         if not self.eta_cutoff >= 0:
-            raise ValueError(f"eta_cutoff must be non negative, got {self.eta_cutoff}.")
+            raise ValueError(
+                f"eta_cutoff must be non negative, got {self.eta_cutoff}.")
         if not 0.0 <= self.typical_p <= 1.0:
-            raise ValueError(f"typical_p must be in (0, 1], got {self.typical_p}.")
+            raise ValueError(
+                f"typical_p must be in (0, 1], got {self.typical_p}.")
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -207,7 +211,8 @@ class SamplingParams:
                 f"logprobs must be non-negative, got {self.logprobs}.")
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(
-                f"prompt_logprobs must be non-negative, got {self.prompt_logprobs}.")
+                f"prompt_logprobs must be non-negative, got {self.prompt_logprobs}."
+            )
 
     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
@@ -274,4 +279,4 @@ class SamplingParams:
                 f"custom_token_bans={self.custom_token_bans}, "
                 f"logprobs={self.logprobs}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
-                f"skip_special_tokens={self.skip_special_tokens})")
+                f"skip_special_tokens={self.skip_special_tokens})")

+ 3 - 1
aphrodite/common/sequence.py

@@ -382,6 +382,7 @@ class SequenceOutputs:
                 and self.output_token == other.output_token
                 and self.logprobs == other.logprobs)
 
+
 class SequenceGroupOutputs:
     """The model outputs associated with a sequence group."""
 
@@ -403,6 +404,7 @@ class SequenceGroupOutputs:
         return (self.samples == other.samples
                 and self.prompt_logprobs == other.prompt_logprobs)
 
+
 # For each sequence group, we generate a list of SequenceOutputs object,
 # each of which contains one possible candidate for the next token.
-SamplerOutput = List[SequenceGroupOutputs]
+SamplerOutput = List[SequenceGroupOutputs]

+ 2 - 1
aphrodite/common/utils.py

@@ -8,6 +8,7 @@ import torch
 
 from aphrodite import cuda_utils
 
+
 class Device(enum.Enum):
     GPU = enum.auto()
     CPU = enum.auto()
@@ -52,4 +53,4 @@ def random_uuid() -> str:
 
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
-    return "microsoft" in " ".join(uname()).lower()
+    return "microsoft" in " ".join(uname()).lower()

+ 59 - 26
aphrodite/endpoints/api_server_kobold.py

@@ -30,15 +30,17 @@ engine: AsyncAphrodite = None
 
 badwordsids: List[int] = []
 
+
 def _set_badwords(tokenizer, hf_config):
     global badwordsids
     if hf_config.bad_words_ids is not None:
         badwordsids = hf_config.bad_words_ids
         return
-    
-    badwordsids = [ v for k, v in tokenizer.get_vocab().items()
-                    if any(c in str(k) for c in "[]")
-                  ]
+
+    badwordsids = [
+        v for k, v in tokenizer.get_vocab().items()
+        if any(c in str(k) for c in "[]")
+    ]
     if tokenizer.pad_token_id in badwordsids:
         badwordsids.remove(tokenizer.pad_token_id)
     badwordsids.append(tokenizer.eos_token_id)
@@ -57,22 +59,30 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
-    return JSONResponse({"msg": message, "type": "invalid_request_error"},
+
+def create_error_response(status_code: HTTPStatus,
+                          message: str) -> JSONResponse:
+    return JSONResponse({
+        "msg": message,
+        "type": "invalid_request_error"
+    },
                         status_code=status_code.value)
 
+
 @app.exception_handler(ValueError)
 def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
     return create_error_response(HTTPStatus.UNPROCESSABLE_ENTITY, str(exc))
 
-def prepare_engine_payload(kai_payload: KAIGenerationInputSchema) -> Tuple[SamplingParams, List[int]]:
+
+def prepare_engine_payload(
+        kai_payload: KAIGenerationInputSchema
+) -> Tuple[SamplingParams, List[int]]:
     """Create SamplingParams and truncated input tokens for AsyncEngine"""
 
     if kai_payload.max_context_length > max_model_len:
         raise ValueError(
             f"max_context_length ({kai_payload.max_context_length}) must be less than or equal to "
-            f"max_model_len ({max_model_len})"
-        )
+            f"max_model_len ({max_model_len})")
 
     sampling_params = SamplingParams(max_tokens=kai_payload.max_length)
 
@@ -86,7 +96,6 @@ def prepare_engine_payload(kai_payload: KAIGenerationInputSchema) -> Tuple[Sampl
         kai_payload.top_p = 1.0
         kai_payload.top_k = -1
 
-
     sampling_params = SamplingParams(
         n=kai_payload.n,
         best_of=kai_payload.n,
@@ -100,38 +109,47 @@ def prepare_engine_payload(kai_payload: KAIGenerationInputSchema) -> Tuple[Sampl
         eta_cutoff=kai_payload.eta_cutoff,
         epsilon_cutoff=kai_payload.eps_cutoff,
         stop=kai_payload.stop_sequence,
-        custom_token_bans=badwordsids if kai_payload.use_default_badwordsids else [],
+        custom_token_bans=badwordsids
+        if kai_payload.use_default_badwordsids else [],
         max_tokens=kai_payload.max_length,
     )
 
-    max_input_tokens = max(1, kai_payload.max_context_length - kai_payload.max_length)
+    max_input_tokens = max(
+        1, kai_payload.max_context_length - kai_payload.max_length)
     input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]
 
     return sampling_params, input_tokens
 
+
 @kai_api.post("/generate")
 async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
     """ Generate text """
 
     req_id = f"kai-{random_uuid()}"
     sampling_params, input_tokens = prepare_engine_payload(kai_payload)
-    result_generator = engine.generate(None, sampling_params, req_id, input_tokens)
+    result_generator = engine.generate(None, sampling_params, req_id,
+                                       input_tokens)
 
     final_res: RequestOutput = None
     async for res in result_generator:
         final_res = res
     assert final_res is not None
 
-    return JSONResponse({"results": [{"text": output.text} for output in final_res.outputs]})
+    return JSONResponse(
+        {"results": [{
+            "text": output.text
+        } for output in final_res.outputs]})
 
 
 @extra_api.post("/generate/stream")
-async def generate_stream(kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
+async def generate_stream(
+        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
     """ Generate text SSE streaming """
 
     req_id = f"kai-{random_uuid()}"
     sampling_params, input_tokens = prepare_engine_payload(kai_payload)
-    results_generator = engine.generate(None, sampling_params, req_id, input_tokens)
+    results_generator = engine.generate(None, sampling_params, req_id,
+                                        input_tokens)
 
     async def stream_kobold() -> AsyncGenerator[bytes, None]:
         previous_output = ""
@@ -142,44 +160,55 @@ async def generate_stream(kai_payload: KAIGenerationInputSchema) -> StreamingRes
             yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()
 
     return StreamingResponse(stream_kobold(),
-                             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
+                             headers={
+                                 "Cache-Control": "no-cache",
+                                 "Connection": "keep-alive"
+                             },
                              media_type='text/event-stream')
 
+
 @extra_api.post("/generate/check")
 async def check_generation():
     """ stub for compatibility """
     return JSONResponse({"results": [{"text": ""}]})
 
+
 @kai_api.get("/info/version")
 async def get_version():
     """ Impersonate KAI """
     return JSONResponse({"result": "1.2.4"})
 
+
 @kai_api.get("/model")
 async def get_model():
     """ Get current model """
     return JSONResponse({"result": f"aphrodite/{served_model}"})
 
+
 @kai_api.get("/config/soft_prompts_list")
 async def get_available_softprompts():
     """ stub for compatibility """
-    return JSONResponse({"values":[]})
+    return JSONResponse({"values": []})
+
 
 @kai_api.get("/config/soft_prompt")
 async def get_current_softprompt():
     """ stub for compatibility """
     return JSONResponse({"value": ""})
 
+
 @kai_api.put("/config/soft_prompt")
 async def set_current_softprompt():
     """ stub for compatibility """
     return JSONResponse({})
 
+
 @app.get("/api/latest/config/max_context_length")
 async def get_max_context_length() -> JSONResponse:
     """Return the max context length based on the EngineArgs configuration."""
     max_context_length = engine_model_config.max_model_len
-    return JSONResponse({"value": max_context_length })
+    return JSONResponse({"value": max_context_length})
+
 
 @app.get("/api/latest/config/max_length")
 async def get_max_length() -> JSONResponse:
@@ -187,22 +216,25 @@ async def get_max_length() -> JSONResponse:
     max_length = args.max_length
     return JSONResponse({"value": max_length})
 
+
 @extra_api.post("/abort")
 async def abort_generation():
     """ stub for compatibility """
     return JSONResponse({})
 
+
 @extra_api.get("/version")
 async def get_extra_version():
     """ Impersonate KoboldCpp with streaming support """
     return JSONResponse({"result": "KoboldCpp", "version": "1.30"})
 
+
 @app.get("/")
 async def get_kobold_lite_ui():
     """Serves a cached copy of the Kobold Lite UI, loading it from disk on demand if needed."""
     #read and return embedded kobold lite
     global kobold_lite_ui
-    if kobold_lite_ui=="":
+    if kobold_lite_ui == "":
         scriptpath = os.path.dirname(os.path.abspath(__file__))
         klitepath = os.path.join(scriptpath, "klite.embd")
         if os.path.exists(klitepath):
@@ -212,6 +244,7 @@ async def get_kobold_lite_ui():
             print("Embedded Kobold Lite not found")
     return HTMLResponse(content=kobold_lite_ui)
 
+
 app.include_router(kai_api, prefix="/api/v1")
 app.include_router(kai_api, prefix="/api/latest", include_in_schema=False)
 app.include_router(extra_api, prefix="/api/extra")
@@ -231,10 +264,10 @@ if __name__ == "__main__":
                         "specified, the model name will be the same as "
                         "the huggingface name.")
     parser.add_argument("--max-length",
-                    type=int,
-                    default=256,
-                    help="The maximum length of the generated text. "
-                    "For use with Kobold Horde.")
+                        type=int,
+                        default=256,
+                        help="The maximum length of the generated text. "
+                        "For use with Kobold Horde.")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     global args
@@ -256,11 +289,11 @@ if __name__ == "__main__":
     tokenizer = get_tokenizer(engine_args.tokenizer,
                               tokenizer_mode=engine_args.tokenizer_mode,
                               trust_remote_code=engine_args.trust_remote_code)
-    
+
     _set_badwords(tokenizer, engine_model_config.hf_config)
 
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
                 log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

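For context on the rewrapped lines in prepare_engine_payload above: the prompt is truncated from the left so that the prompt plus the generation budget still fits in max_context_length. A tiny numeric illustration (values invented for the example, not taken from the commit):

```python
# Illustrative values only; mirrors the truncation logic in
# prepare_engine_payload above.
max_context_length = 2048        # kai_payload.max_context_length
max_length = 512                 # kai_payload.max_length (tokens to generate)
prompt_ids = list(range(3000))   # stand-in for tokenizer(prompt).input_ids

max_input_tokens = max(1, max_context_length - max_length)  # 1536
input_tokens = prompt_ids[-max_input_tokens:]                # keep only the tail

assert len(input_tokens) == 1536
assert input_tokens[0] == 1464   # the first 1464 prompt tokens were dropped
```
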
+ 14 - 8
aphrodite/endpoints/api_server_ooba.py

@@ -31,8 +31,10 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 @app.post("/api/v1/generate")
-async def generate(request: Request, x_api_key: str = Header(None)) -> Response:
+async def generate(
+    request: Request, x_api_key: str = Header(None)) -> Response:
     """Generate completion for the request.
 
     The request should be a JSON object with the following fields:
@@ -41,12 +43,13 @@ async def generate(request: Request, x_api_key: str = Header(None)) -> Response:
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     if x_api_key is None or x_api_key != valid_api_key:
-        raise HTTPException(status_code=401, detail="Unauthorized. Please acquire an API key.")
+        raise HTTPException(status_code=401,
+                            detail="Unauthorized. Please acquire an API key.")
 
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
-    
+
     if 'stopping_strings' in request_dict:
         request_dict['stop'] = request_dict.pop('stopping_strings')
     if 'max_new_tokens' in request_dict:
@@ -61,11 +64,14 @@ async def generate(request: Request, x_api_key: str = Header(None)) -> Response:
     request_dict['logits_processors'] = []
 
     min_length = request_dict.pop('min_tokens', 0)
-    if request_dict.get('ignore_eos', False):  # ignore_eos/ban_eos_token is functionally equivalent to `min_tokens = max_tokens`
+    if request_dict.get(
+            'ignore_eos', False
+    ):  # ignore_eos/ban_eos_token is functionally equivalent to `min_tokens = max_tokens`
         min_length = request_dict.get('max_tokens', 16)
 
     if min_length:
-        request_dict['logits_processors'].append(BanEOSUntil(min_length, engine.engine.tokenizer.eos_token_id))
+        request_dict['logits_processors'].append(
+            BanEOSUntil(min_length, engine.engine.tokenizer.eos_token_id))
 
     sampling_params = SamplingParams()
     for key, value in request_dict.items():
@@ -80,9 +86,9 @@ async def generate(request: Request, x_api_key: str = Header(None)) -> Response:
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
             prompt = request_output.prompt
-            text_outputs = [
-                {"text": output.text} for output in request_output.outputs
-            ]
+            text_outputs = [{
+                "text": output.text
+            } for output in request_output.outputs]
             ret = {"results": text_outputs}
             yield (json.dumps(ret) + "\n\n").encode("utf-8")
 

+ 11 - 5
aphrodite/endpoints/protocol.py

@@ -1,6 +1,7 @@
 from typing import List, Optional, Union
 from pydantic import BaseModel, Field, root_validator, conint, confloat, conlist, NonNegativeFloat, NonNegativeInt, PositiveInt
 
+
 class SamplingParams(BaseModel):
     n: int = Field(1, alias="n")
     best_of: Optional[int] = Field(None, alias="best_of")
@@ -20,16 +21,19 @@ class SamplingParams(BaseModel):
     ignore_eos: bool = Field(False, alias="ignore_eos")
     max_tokens: int = Field(16, alias="max_length")
     logprobs: Optional[int] = Field(None, alias="logprobs")
-    custom_token_bans: Optional[List[int]] = Field(None, alias="custom_token_bans")
+    custom_token_bans: Optional[List[int]] = Field(None,
+                                                   alias="custom_token_bans")
 
     @root_validator
     def validate_best_of(cls, values):
         best_of = values.get("best_of")
         n = values.get("n")
         if best_of is not None and (best_of <= 0 or best_of > n):
-            raise ValueError("best_of must be a positive integer less than or equal to n")
+            raise ValueError(
+                "best_of must be a positive integer less than or equal to n")
         return values
 
+
 class KAIGenerationInputSchema(BaseModel):
     prompt: str
     n: Optional[conint(ge=1, le=5)] = 1
@@ -42,7 +46,7 @@ class KAIGenerationInputSchema(BaseModel):
     top_a: Optional[NonNegativeFloat] = 0.0
     top_p: Optional[confloat(ge=0, le=1)] = 1.0
     tfs: Optional[confloat(ge=0, le=1)] = 1.0
-    eps_cutoff: Optional[confloat(ge=0,le=1000)] = 0.0
+    eps_cutoff: Optional[confloat(ge=0, le=1000)] = 0.0
     eta_cutoff: Optional[NonNegativeFloat] = 0.0
     typical: Optional[confloat(ge=0, le=1)] = 1.0
     temperature: Optional[NonNegativeFloat] = 1.0
@@ -67,5 +71,7 @@ class KAIGenerationInputSchema(BaseModel):
 
     @root_validator
     def check_context(cls, values):
-        assert values.get("max_length") <= values.get("max_context_length"), f"max_length must not be larger than max_context_length"
-        return values
+        assert values.get("max_length") <= values.get(
+            "max_context_length"
+        ), f"max_length must not be larger than max_context_length"
+        return values

+ 7 - 6
aphrodite/engine/aphrodite_engine.py

@@ -4,18 +4,19 @@ from functools import partial
 from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                                     SchedulerConfig)
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.ray_tools import RayWorker, initialize_cluster, ray
 from aphrodite.common.logger import init_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (
-    SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata,
-    SequenceGroupOutputs, SequenceOutputs, SequenceStatus)
+from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
+                                       SequenceGroupMetadata,
+                                       SequenceGroupOutputs, SequenceOutputs,
+                                       SequenceStatus)
 from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
-                                               get_tokenizer)
+                                                    get_tokenizer)
 from aphrodite.common.utils import Counter
 
 if ray:
@@ -707,4 +708,4 @@ class AphroditeEngine:
         output = all_outputs[0]
         for other_output in all_outputs[1:]:
             assert output == other_output
-        return output
+        return output

+ 6 - 6
aphrodite/engine/args_tools.py

@@ -3,7 +3,8 @@ import dataclasses
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig)
+from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
+                                     SchedulerConfig)
 
 
 @dataclass
@@ -180,10 +181,9 @@ class EngineArgs:
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
                                    self.max_model_len, self.quantization)
-        cache_config = CacheConfig(self.block_size,
-                                   self.gpu_memory_utilization,
-                                   self.swap_space, getattr(model_config.hf_config,
-                                                            'sliding_window', None))
+        cache_config = CacheConfig(
+            self.block_size, self.gpu_memory_utilization, self.swap_space,
+            getattr(model_config.hf_config, 'sliding_window', None))
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
@@ -218,4 +218,4 @@ class AsyncEngineArgs(EngineArgs):
                             help='max number of prompt characters or prompt '
                             'ID numbers being printed in log. '
                             'Default: unlimited.')
-        return parser
+        return parser

+ 2 - 1
aphrodite/engine/async_aphrodite.py

@@ -168,6 +168,7 @@ class RequestTracker:
     async def wait_for_new_requests(self):
         await self.new_requests_event.wait()
 
+
 class _AsyncAphrodite(AphroditeEngine):
     """Extension of AphroditeEngine to add async methods."""
 
@@ -490,4 +491,4 @@ class AsyncAphrodite:
                      log_stats=not engine_args.disable_log_stats,
                      max_log_len=engine_args.max_log_len,
                      start_engine_loop=start_engine_loop)
-        return engine
+        return engine

+ 5 - 5
aphrodite/modeling/hf_downloader.py

@@ -17,7 +17,6 @@ from aphrodite.common.logger import init_logger
 from aphrodite.modeling.quantization_utils import get_quant_class
 from aphrodite.modeling.quantization_utils.base import QuantizationConfig
 
-
 logger = init_logger(__name__)
 
 
@@ -92,7 +91,7 @@ def get_quant_config(
     if quantization == "gptq" and hasattr(hf_config, "quantization_config"):
         config = hf_config.quantization_config
         return get_quant_class(quantization).from_config(config)
-    
+
     is_local = os.path.isdir(model_name_or_path)
     if not is_local:
         # Download the config files.
@@ -307,7 +306,7 @@ def load_tensor_parallel_weights(
             else:
                 index = [slice(None)] * (len(loaded_weight.get_shape()) -
                                          1) + [slice(start_idx, end_idx)]
-                loaded_weight = loaded_weight[index] 
+                loaded_weight = loaded_weight[index]
             break
 
     loaded_weight = convert_pyslice_to_tensor(loaded_weight)
@@ -340,7 +339,8 @@ def get_parallel_weight(model: torch.nn.Module):
         row_weight_suffixes = ["weight"]
         ignore_weight_suffixes = []
     else:
-        column_weight_suffixes = model.quant_config.get_column_tp_tensor_names()
+        column_weight_suffixes = model.quant_config.get_column_tp_tensor_names(
+        )
         row_weight_suffixes = model.quant_config.get_row_tp_tensor_names()
         ignore_weight_suffixes = model.quant_config.get_ignore_tensor_names()
 
@@ -357,4 +357,4 @@ def get_parallel_weight(model: torch.nn.Module):
         for layer in model.parallel_vocab_layers:
             for suffix in ["weight", "bias"]:
                 column_parallel_weights.append(f"{layer}.{suffix}")
-    return column_parallel_weights, row_parallel_weights, ignore_weight_suffixes
+    return column_parallel_weights, row_parallel_weights, ignore_weight_suffixes

+ 7 - 3
aphrodite/modeling/layers/activation.py

@@ -4,7 +4,6 @@ import torch.nn as nn
 from aphrodite import activation_ops
 
 
-
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
 
@@ -27,13 +26,15 @@ class SiluAndMul(nn.Module):
         activation_ops.silu_and_mul(out, x)
         return out
 
+
 class NewGELU(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         activation_ops.gelu_new(out, x)
         return out
-    
+
+
 class FastGELU(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -41,6 +42,7 @@ class FastGELU(nn.Module):
         activation_ops.gelu_fast(out, x)
         return out
 
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_new": NewGELU(),
@@ -49,9 +51,11 @@ _ACTIVATION_REGISTRY = {
     "relu": nn.ReLU(),
 }
 
+
 def get_act_fn(act_fn: str) -> nn.Module:
     """Get an activation function by name."""
     act_fn = act_fn.lower()
     if act_fn in _ACTIVATION_REGISTRY:
         return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is currently not supported.")
+    raise ValueError(
+        f"Activation function {act_fn!r} is currently not supported.")

+ 5 - 4
aphrodite/modeling/layers/layernorm.py

@@ -4,6 +4,7 @@ import torch.nn as nn
 
 from aphrodite import layernorm_ops
 
+
 class RMSNorm(nn.Module):
     """Root mean square normalization.
 
@@ -12,9 +13,9 @@ class RMSNorm(nn.Module):
     """
 
     def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6, # the epsilon value used by llama models
+            self,
+            hidden_size: int,
+            eps: float = 1e-6,  # the epsilon value used by llama models
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -28,4 +29,4 @@ class RMSNorm(nn.Module):
             self.weight.data,
             self.variance_epsilon,
         )
-        return out
+        return out

+ 4 - 4
aphrodite/modeling/layers/quantized_linear/__init__.py

@@ -2,10 +2,10 @@ from torch import nn
 
 from aphrodite.modeling.layers.quantized_linear.awq import (
     AWQColumnParallelLinear, AWQRowParallelLinear)
-from aphrodite.modeling.layers.quantized_linear.gptq import(
+from aphrodite.modeling.layers.quantized_linear.gptq import (
     GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)
-from aphrodite.modeling.megatron.layers import (
-    ColumnParallelLinear, RowParallelLinear)
+from aphrodite.modeling.megatron.layers import (ColumnParallelLinear,
+                                                RowParallelLinear)
 
 _QUANTIZED_LINEAR_REGISTRY = {
     "awq": (AWQColumnParallelLinear, AWQRowParallelLinear, None),
@@ -57,4 +57,4 @@ class ParallelLinear:
             raise ValueError(f"No quantized linear is found for {name}")
 
         quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
-        return quant_linear_cls(*args, **kwargs)
+        return quant_linear_cls(*args, **kwargs)

+ 3 - 2
aphrodite/modeling/layers/quantized_linear/awq.py

@@ -5,8 +5,9 @@ import torch
 from torch.nn.parameter import Parameter
 
 from aphrodite import quantization_ops
-from aphrodite.modeling.megatron.layers import (
-    ColumnParallelLinear, RowParallelLinear)
+from aphrodite.modeling.megatron.layers import (ColumnParallelLinear,
+                                                RowParallelLinear)
+
 
 class AWQColumnParallelLinear(ColumnParallelLinear):
 

+ 2 - 2
aphrodite/modeling/layers/quantized_linear/gptq.py

@@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter
 
 from aphrodite import quantization_ops
 from aphrodite.modeling.megatron.layers import (ColumnParallelLinear,
-                                                       RowParallelLinear)
+                                                RowParallelLinear)
 
 
 class GPTQLinear(torch.nn.Module):
@@ -262,4 +262,4 @@ class GPTQRowParallelLinear(RowParallelLinear):
                                                  self.scales.float(),
                                                  self.qzeros, self.g_idx)
             output = output.half()
-        return output.reshape(out_shape)
+        return output.reshape(out_shape)

+ 7 - 6
aphrodite/modeling/layers/quantized_linear/utils.py

@@ -6,6 +6,7 @@ from aphrodite import quantization_ops
 from aphrodite.modeling.layers.quantized_linear.gptq import (
     GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)
 
+
 def quant_post_init(model, max_input_length: Optional[int] = None):
     device_to_buffers_size = {}
 
@@ -26,7 +27,7 @@ def quant_post_init(model, max_input_length: Optional[int] = None):
             device_to_buffers_size[device]["max_dq_buffer_size"] = max(
                 device_to_buffers_size[device]["max_dq_buffer_size"],
                 submodule.qweight.numel() * 8)
-            
+
             in_features = submodule.input_size_per_partition if isinstance(
                 submodule, GPTQRowParallelLinear) else submodule.input_size
             out_features = submodule.output_size_per_partition if isinstance(
@@ -36,7 +37,7 @@ def quant_post_init(model, max_input_length: Optional[int] = None):
                 device_to_buffers_size[device]["max_inner_outer_dim"] = max(
                     device_to_buffers_size[device]["max_inner_outer_dim"],
                     in_features, out_features)
-    
+
     if model_uses_exllama:
         device_to_buffers = {}
         max_input_len = max_input_length if use_act_order else 1
@@ -64,20 +65,20 @@ def quant_post_init(model, max_input_length: Optional[int] = None):
             quantization_ops.gptq_prepare_buffers(device,
                                                   buffers["temp_state"],
                                                   buffers["temp_dq"])
-            
+
         matmul_recons_thd = 8
         matmul_fused_remap = False
         matmul_no_half2 = False
         quantization_ops.gptq_set_tuning_params(matmul_recons_thd,
                                                 matmul_fused_remap,
                                                 matmul_no_half2)
-        
+
         # the buffers need to have been initialized first before calling make_q4
         for _, submodule in model.named_modules():
             if isinstance(
-                submodule,
+                    submodule,
                 (GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)):
                 submodule.post_init()
 
         torch.cuda.empty_cache()
-    return model
+    return model

+ 28 - 21
aphrodite/modeling/layers/rotary_embedding.py

@@ -169,7 +169,8 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
         return cache
-    
+
+
 def _yarn_find_correction_dim(num_rotations: int,
                               dim: int,
                               base: float = 10000,
@@ -178,6 +179,7 @@ def _yarn_find_correction_dim(num_rotations: int,
                            (num_rotations * 2 * math.pi))) / (2 *
                                                               math.log(base))
 
+
 def _yarn_find_correction_range(low_rot: int,
                                 high_rot: int,
                                 dim: int,
@@ -186,8 +188,10 @@ def _yarn_find_correction_range(low_rot: int,
     low = math.floor(
         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
     high = math.ceil(
-        _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
-    return max(low, 0), min(high, dim - 1) # clamp values just in case
+        _yarn_find_correction_dim(high_rot, dim, base,
+                                  max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # clamp values just in case
+
 
 def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                            dtype: torch.dtype,
@@ -200,6 +204,7 @@ def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func
 
+
 def _yarn_get_mscale(scale: float = 1) -> float:
     if scale <= 1:
         return 1.0
@@ -210,18 +215,18 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
     """Rotary embedding extended with YaRN method (Peng et al.)"""
 
     def __init__(
-            self,
-            head_size: int,
-            rotary_dim: int,
-            max_position_embeddings: int,
-            base: int,
-            is_neox_style: bool,
-            scaling_factor: float,
-            *,
-            extrapolation_factor: float = 1,
-            attn_factor: float = 1,
-            beta_fast: float = 32,
-            beta_slow: float = 1,
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: float = 32,
+        beta_slow: float = 1,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -229,9 +234,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
         self.mscale = float(
-            _yarn_get_mscale(self.scaling_factor) * attn_factor) # get n-d magnitude scaling corrected for interpolation
-        super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style)
-    
+            _yarn_get_mscale(self.scaling_factor) * attn_factor
+        )  # get n-d magnitude scaling corrected for interpolation
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base**(torch.arange(
             0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
@@ -242,14 +249,14 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                 self.rotary_dim, self.base,
                                                 self.max_position_embeddings)
-        
+
         inv_freq_mask = (1 - _yarn_linear_ramp_mask(
             low, high, self.rotary_dim // 2, dtype=torch.float,
             device="cuda")) * self.extrapolation_factor
         inv_freq = inv_freq_interpolation * (
             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
         return inv_freq
-    
+
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
@@ -259,4 +266,4 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         cos = (freqs.cos() * self.mscale)
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
-        return cache
+        return cache

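A worked reading of the rewrapped YaRN helpers above (my derivation from the code shown, not text from the commit): with the parameterization in _compute_inv_freq, the k-th rotary pair has inverse frequency b^(-2k/D) and so completes L / (2π b^(2k/D)) rotations over a context of length L. Solving for the pair index at which exactly r rotations occur gives the expression _yarn_find_correction_dim computes:

$$k(r) \;=\; \frac{D\,\ln\!\bigl(L/(2\pi r)\bigr)}{2\ln b}$$

_yarn_find_correction_range then takes the floor of k(beta_fast) and the ceiling of k(beta_slow), clamped to valid indices, to decide which frequencies are interpolated and which are extrapolated.
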
+ 71 - 56
aphrodite/modeling/layers/sampler.py

@@ -7,12 +7,11 @@ import torch.nn as nn
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_gather
-)
+    tensor_model_parallel_all_gather)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
-from aphrodite.common.sequence import (
-    PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData,
-    SequenceGroupOutputs, SequenceOutputs)
+from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
+                                       SamplerOutput, SequenceData,
+                                       SequenceGroupOutputs, SequenceOutputs)
 from aphrodite.common.sequence import SamplerOutput, SequenceOutputs, SequenceData
 
 _SAMPLING_EPS = 1e-5
@@ -54,18 +53,20 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = _get_penalties(
+            input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens,
-                                  presence_penalties, frequency_penalties, repetition_penalties,
+        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+                                  frequency_penalties, repetition_penalties,
                                   self.vocab_size)
-        
+
         banned_tokens = _get_custom_token_bans(input_metadata)
         assert len(banned_tokens) == logits.shape[0]
         logits = _apply_token_bans(logits, banned_tokens)
-        
-        logits = _apply_logits_processors(input_metadata, logits, output_tokens)
+
+        logits = _apply_logits_processors(input_metadata, logits,
+                                          output_tokens)
 
         # Apply Eta sampling, as described in https://arxiv.org/abs/2210.15191
         eta_cutoffs = _get_eta_cutoffs(input_metadata)
@@ -101,7 +102,8 @@ class Sampler(nn.Module):
             logits.div_(t.unsqueeze(dim=1))
 
         # Apply top-p, top-k, and top-a truncation.
-        top_ps, top_ks, top_as = _get_top_a_top_p_top_k(input_metadata, self.vocab_size)
+        top_ps, top_ks, top_as = _get_top_a_top_p_top_k(
+            input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
@@ -141,7 +143,7 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
-) -> torch.Tensor:   
+) -> torch.Tensor:
     selected_token_indices: List[int] = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
@@ -166,6 +168,7 @@ def _prune_hidden_states(
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     return hidden_states.index_select(0, selected_token_indices)
 
+
 def _get_penalties(
         input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
     # Collect the presence and frequency penalties.
@@ -181,8 +184,10 @@ def _get_penalties(
             frequency_penalties += [0] * (prompt_len - 1)
             repetition_penalties += [0] * (prompt_len - 1)
         presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
-        frequency_penalties += [sampling_params.frequency_penalty] * len(seq_ids)
-        repetition_penalties += [sampling_params.repetition_penalty] * len(seq_ids)
+        frequency_penalties += [sampling_params.frequency_penalty
+                                ] * len(seq_ids)
+        repetition_penalties += [sampling_params.repetition_penalty
+                                 ] * len(seq_ids)
     return presence_penalties, frequency_penalties, repetition_penalties
 
 
@@ -215,14 +220,12 @@ def _get_custom_token_bans(input_metadata: InputMetadata) -> List[List[int]]:
     return banned_tokens
 
 
-def _apply_logits_processors(
-    input_metadata: InputMetadata,
-    logits: torch.Tensor,
-    output_tokens: List[List[int]]
-) -> torch.Tensor:
+def _apply_logits_processors(input_metadata: InputMetadata,
+                             logits: torch.Tensor,
+                             output_tokens: List[List[int]]) -> torch.Tensor:
     seq_offset = 0
 
-    for seq_ids,sampling_params in input_metadata.seq_groups:
+    for seq_ids, sampling_params in input_metadata.seq_groups:
         seq_end = seq_offset + len(seq_ids)
 
         for proc in sampling_params.logits_processors:
@@ -232,6 +235,7 @@ def _apply_logits_processors(
 
     return logits
 
+
 def _apply_penalties(
     logits: torch.Tensor,
     output_tokens: List[List[int]],
@@ -244,9 +248,9 @@ def _apply_penalties(
     for i in range(num_seqs):
         if not output_tokens[i]:
             continue
-        if (abs(presence_penalties[i]) < _SAMPLING_EPS and
-            abs(frequency_penalties[i]) < _SAMPLING_EPS and
-            repetition_penalties[i] < 1.0 + _SAMPLING_EPS):
+        if (abs(presence_penalties[i]) < _SAMPLING_EPS
+                and abs(frequency_penalties[i]) < _SAMPLING_EPS
+                and repetition_penalties[i] < 1.0 + _SAMPLING_EPS):
             continue
         break
     else:
@@ -278,8 +282,8 @@ def _apply_penalties(
                                       dtype=logits.dtype,
                                       device=logits.device)
     repetition_penalties = torch.tensor(repetition_penalties,
-                                      dtype=logits.dtype,
-                                      device=logits.device)
+                                        dtype=logits.dtype,
+                                        device=logits.device)
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
@@ -289,13 +293,16 @@ def _apply_penalties(
 
     # Effectively: If token is present and logit is positive, divide logit by rep_pen.
     #              If token is present and logit is negative, multiply logit by rep_pen.
-    logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits > 0)
-    logits += logits * (repetition_penalties.unsqueeze(dim=1) - 1) * presence_mask * (logits < 0)
+    logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) -
+                        1) * presence_mask * (logits > 0)
+    logits += logits * (repetition_penalties.unsqueeze(dim=1) -
+                        1) * presence_mask * (logits < 0)
 
     return logits
 
 
-def _apply_token_bans(logits: torch.Tensor, banned_tokens: List[List[int]]) -> torch.Tensor:
+def _apply_token_bans(logits: torch.Tensor,
+                      banned_tokens: List[List[int]]) -> torch.Tensor:
     for i, banned_token_ids in enumerate(banned_tokens):
         if not banned_token_ids:
             continue
@@ -340,7 +347,7 @@ def _get_top_a_top_p_top_k(
             prompt_len = input_metadata.prompt_lens[i]
             top_ps += [sampling_params.top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
-            top_as += [sampling_params.top_a] * (prompt_len - 1) 
+            top_as += [sampling_params.top_a] * (prompt_len - 1)
         top_ps += [sampling_params.top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
         top_as += [sampling_params.top_a] * len(seq_ids)
@@ -415,11 +422,13 @@ def _apply_top_a_top_p_top_k(
     probs_sort = logits_sort.softmax(dim=-1)
     probs_sum = probs_sort.cumsum(dim=-1)
     top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * ts_a
-    top_ap_mask = (probs_sort < top_a_thresholds.unsqueeze(1)) # Cull logits below the top-a threshold
-    top_ap_mask.logical_or_(probs_sum > ts_p.unsqueeze(dim=1)) # Cull logits above the top-p summation threshold
-    top_ap_mask[:, 0] = False # Guarantee at least one token is pickable
+    top_ap_mask = (probs_sort < top_a_thresholds.unsqueeze(1)
+                   )  # Cull logits below the top-a threshold
+    top_ap_mask.logical_or_(probs_sum > ts_p.unsqueeze(
+        dim=1))  # Cull logits above the top-p summation threshold
+    top_ap_mask[:, 0] = False  # Guarantee at least one token is pickable
     logits_sort[top_ap_mask] = -float("inf")
-    
+
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
@@ -433,6 +442,7 @@ def _apply_top_a_top_p_top_k(
                           index=torch.argsort(logits_idx, dim=-1))
     return logits
 
+
 def _apply_tfs(
     logits: torch.Tensor,
     tfss: List[float],
@@ -446,14 +456,16 @@ def _apply_tfs(
     tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)
 
     tfs_mask = torch.cat(
-            (
-                torch.zeros(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
-                tfs_mask,
-                torch.ones(logits.shape[0], 1, dtype=torch.bool, device=logits.device),
-            ),
-            dim=-1,
-        )
-    
+        (
+            torch.zeros(
+                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
+            tfs_mask,
+            torch.ones(
+                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
+        ),
+        dim=-1,
+    )
+
     logits_sort[tfs_mask] = -float("inf")
     logits = torch.gather(logits_sort,
                           dim=-1,
@@ -462,21 +474,22 @@ def _apply_tfs(
     return logits
 
 
-
 def _apply_eta_cutoff(
     logits: torch.Tensor,
     eta_cutoffs: List[float],
 ) -> torch.Tensor:
-    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype, device=logits.device) * 1e-4
+    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype,
+                       device=logits.device) * 1e-4
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = shifted_logits.exp()
 
     neg_entropy = (probs * shifted_logits).nansum(dim=-1)
-    eps = torch.min(eta, torch.sqrt(eta)*torch.exp(neg_entropy)).unsqueeze(dim=1)
+    eps = torch.min(eta,
+                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)
 
     eta_mask = probs < eps
 
-    if(torch.all(eta_mask)): # guard against nulling out all the logits
+    if (torch.all(eta_mask)):  # guard against nulling out all the logits
         topk_prob, _ = torch.max(probs, dim=-1)
         eta_mask = probs < topk_prob
 
@@ -488,12 +501,14 @@ def _apply_epsilon_cutoff(
     logits: torch.Tensor,
     epsilon_cutoffs: List[float],
 ) -> torch.Tensor:
-    eps = torch.tensor(epsilon_cutoffs, dtype=logits.dtype, device=logits.device).unsqueeze(dim=1)
+    eps = torch.tensor(epsilon_cutoffs,
+                       dtype=logits.dtype,
+                       device=logits.device).unsqueeze(dim=1)
     probs = logits.softmax(dim=-1)
 
     eps_mask = probs < (eps * 1e-4)
 
-    if(torch.all(eps_mask)): # guard against nulling out all the logits
+    if (torch.all(eps_mask)):  # guard against nulling out all the logits
         topk_prob, _ = torch.max(probs, dim=-1)
         eps_mask = probs < topk_prob
 
@@ -515,17 +530,16 @@ def _apply_typical_sampling(
     _, indices = torch.sort(surprisal_deviations)
     reordered_probs = probs.gather(-1, indices)
     typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
-    
+
     min_tokens_to_keep = 1
     # Keep at least min_tokens_to_keep
     typ_mask_sorted[..., :min_tokens_to_keep] = 0
 
-    typ_mask = typ_mask_sorted.scatter(
-        1, indices, typ_mask_sorted
-    )
+    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
     logits[typ_mask] = -float("inf")
     return logits
 
+
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
@@ -680,12 +694,13 @@ def _sample(
                                                  category_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
-        
+
         sample_results_dict.update(zip(seq_group_ids, sample_results))
 
         sample_results = [
-        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
-    ]
+            sample_results_dict[i]
+            for i in range(len(input_metadata.seq_groups))
+        ]
     return sample_results
 
 
@@ -822,4 +837,4 @@ def _build_sampler_output(
                 SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
         sampler_output.append(
             SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
-    return sampler_output
+    return sampler_output

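The largest hunks in this file only rewrap the penalty and truncation code; behaviour is unchanged. For the repetition-penalty lines in _apply_penalties, the inline comment's claim (divide a positive logit by the penalty, multiply a negative one by it) follows from the algebra: x + x(1/r - 1) = x/r and x + x(r - 1) = x·r. A small self-contained check with made-up numbers (illustrative only, not code from the commit):

```python
import torch

rep_pen = torch.tensor([2.0])                    # repetition penalty, one sequence
logits = torch.tensor([[3.0, -3.0, 1.0]])        # three vocab entries
presence_mask = torch.tensor([[1.0, 1.0, 0.0]])  # first two tokens already generated

# Same update as the rewrapped lines in _apply_penalties.
logits += logits * (1 / rep_pen.unsqueeze(dim=1) - 1) * presence_mask * (logits > 0)
logits += logits * (rep_pen.unsqueeze(dim=1) - 1) * presence_mask * (logits < 0)

# 3.0 -> 1.5 (divided by 2), -3.0 -> -6.0 (multiplied by 2), 1.0 unchanged.
print(logits)  # tensor([[ 1.5000, -6.0000,  1.0000]])
```
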
+ 6 - 6
aphrodite/modeling/loader.py

@@ -19,11 +19,10 @@ _MODEL_REGISTRY = {
 }
 
 _MODEL_CLASSES_SUPPORT_QUANTIZATION = {
-    "awq": [
-        LlamaForCausalLM, MistralForCausalLM
-    ],
+    "awq": [LlamaForCausalLM, MistralForCausalLM],
     "gptq": [
-        LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM, MistralForCausalLM
+        LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM,
+        MistralForCausalLM
     ],
 }
 
@@ -53,7 +52,8 @@ def get_model(model_config: ModelConfig, max_tokens: int) -> nn.Module:
     # Get the quantization config.
     quant_config = None
     if model_config.quantization is not None:
-        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION[model_config.quantization]:
+        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION[
+                model_config.quantization]:
             raise ValueError(
                 f"Quantization is not supported for {model_class}.")
         quant_config = get_quant_config(model_config.quantization,
@@ -96,4 +96,4 @@ def get_model(model_config: ModelConfig, max_tokens: int) -> nn.Module:
             model = model.cuda()
         if model_config.quantization is not None:
             quant_post_init(model, max_tokens)
-    return model.eval()
+    return model.eval()

+ 1 - 2
aphrodite/modeling/megatron/parallel_state.py

@@ -2,7 +2,6 @@
 # Copyright 2023 The vLLM team.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
 """Model and data parallel groups."""
 
 import torch
@@ -177,4 +176,4 @@ def destroy_model_parallel():
     global _PIPELINE_MODEL_PARALLEL_GROUP
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
-    _PIPELINE_GLOBAL_RANKS = None
+    _PIPELINE_GLOBAL_RANKS = None

+ 3 - 2
aphrodite/modeling/metadata.py

@@ -5,6 +5,7 @@ from xformers.ops import AttentionBias
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import SequenceData
 
+
 class InputMetadata:
     """Metadata for input sequences. Used for PagedAttention.
 
@@ -41,7 +42,7 @@ class InputMetadata:
         self.to_cache = None
         if sliding_window is not None:
             # We need to keep the positions of sliding windows within
-            # the key/value tables, this is helpful to know which 
+            # the key/value tables, this is helpful to know which
             # elements we need to cache and where.
             to_cache, start_idx = [], 0
             for prompt_len in self.prompt_lens:
@@ -80,4 +81,4 @@ class InputMetadata:
                 f'max_context_len={self.max_context_len}), '
                 f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                 f'block_tables={self.block_tables}), '
-                f'slot_mapping={self.slot_mapping}')
+                f'slot_mapping={self.slot_mapping}')

+ 1 - 2
aphrodite/modeling/models/__init__.py

@@ -3,10 +3,9 @@ from aphrodite.modeling.models.mistral import MistralForCausalLM
 from aphrodite.modeling.models.gpt_j import GPTJForCausalLM
 from aphrodite.modeling.models.gpt_neox import GPTNeoXForCausalLM
 
-
 __all__ = [
     "LlamaForCausalLM",
     "GPTJForCausalLM",
     "GPTNeoXForCausalLM",
     "MistralForCausalLM",
-]
+]

+ 2 - 3
aphrodite/modeling/models/gpt_j.py

@@ -35,8 +35,7 @@ from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from aphrodite.modeling.megatron.layers import (
-    VocabParallelEmbedding)
+from aphrodite.modeling.megatron.layers import (VocabParallelEmbedding)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -265,4 +264,4 @@ class GPTJForCausalLM(nn.Module):
             param = state_dict[name]
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
-                                         self._row_parallel_weights, tp_rank)
+                                         self._row_parallel_weights, tp_rank)

+ 3 - 3
aphrodite/modeling/models/gpt_neox.py

@@ -36,8 +36,8 @@ from aphrodite.modeling.hf_downloader import (hf_model_weights_iterator,
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from aphrodite.modeling.megatron.layers import (VocabParallelEmbedding,
-                                                       ColumnParallelLinear,
-                                                       RowParallelLinear)
+                                                ColumnParallelLinear,
+                                                RowParallelLinear)
 from aphrodite.common.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -283,4 +283,4 @@ class GPTNeoXForCausalLM(nn.Module):
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         tensor_model_parallel_rank)

+ 5 - 6
aphrodite/modeling/models/llama.py

@@ -39,8 +39,7 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.quantized_linear import ParallelLinear
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from aphrodite.modeling.megatron.layers import (
-    VocabParallelEmbedding)
+from aphrodite.modeling.megatron.layers import (VocabParallelEmbedding)
 from aphrodite.modeling.quantization_utils import QuantizationConfig
 from aphrodite.modeling.hf_downloader import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,
@@ -132,7 +131,7 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.scaling,
             base=self.rope_theta,
-            rope_scaling = self.rope_scaling,
+            rope_scaling=self.rope_scaling,
             max_position=self.max_position_embeddings,
             rotary_dim=self.head_dim,
             num_kv_heads=self.num_kv_heads)
@@ -229,8 +228,8 @@ class LlamaModel(nn.Module):
         self.vocab_size = config.vocab_size
 
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.embed_tokens = VocabParallelEmbedding(
-            vocab_size, config.hidden_size)
+        self.embed_tokens = VocabParallelEmbedding(vocab_size,
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
@@ -403,4 +402,4 @@ class LlamaForCausalLM(nn.Module):
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          column_parallel_weights,
                                          row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         tensor_model_parallel_rank)
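
For orientation, the padded vocabulary size computed in LlamaModel above simply rounds config.vocab_size up to the next multiple of 64. A quick arithmetic sketch (the vocab sizes here are illustrative, not taken from a real config):

vocab_size = 32000
assert ((vocab_size + 63) // 64) * 64 == 32000   # already a multiple of 64, unchanged
assert ((32003 + 63) // 64) * 64 == 32064        # otherwise padded up to the next multiple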

+ 3 - 3
aphrodite/modeling/models/mistral.py

@@ -294,7 +294,7 @@ class MistralForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
             if any(name.endswith(suffix) for suffix in ignore_weight_suffixes):
-                continue            
+                continue
 
             is_packed = False
             is_transposed = False
@@ -355,7 +355,7 @@ class MistralForCausalLM(nn.Module):
                 break
             if is_gate_up_weight:
                 continue
-            
+
             if name not in state_dict:
                 continue
             param = state_dict[name]
@@ -370,4 +370,4 @@ class MistralForCausalLM(nn.Module):
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          column_parallel_weights,
                                          row_parallel_weights,
-                                         tensor_model_parallel_rank)
+                                         tensor_model_parallel_rank)

+ 2 - 2
aphrodite/modeling/quantization_utils/awq.py

@@ -69,6 +69,6 @@ class AWQConfig(QuantizationConfig):
 
     def get_row_tp_tensor_names(self) -> List[str]:
         return ["qweight", "qzeros", "scales"]
-    
+
     def get_column_tp_tensor_names(self) -> List[str]:
-        return ["qweight", "qzeros", "scales", "bias"]
+        return ["qweight", "qzeros", "scales", "bias"]

+ 3 - 3
aphrodite/modeling/quantization_utils/base.py

@@ -73,9 +73,9 @@ class QuantizationConfig:
     @classmethod
     def get_row_tp_tensor_names(self) -> List[str]:
         raise NotImplementedError
-    
+
     def get_column_tp_tensor_names(self) -> List[str]:
         raise NotImplementedError
-    
+
     def get_ignore_tensor_names(self) -> List[str]:
-        return []
+        return []

+ 14 - 13
aphrodite/modeling/quantization_utils/gptq.py

@@ -6,6 +6,7 @@ from aphrodite.modeling.quantization_utils.base import QuantizationConfig
 
 
 class GPTQConfig(QuantizationConfig):
+
     def __init__(
         self,
         weight_bits: int,
@@ -18,58 +19,58 @@ class GPTQConfig(QuantizationConfig):
         self.pack_factor = 32 // self.weight_bits
         if self.weight_bits != 4:
             raise ValueError(
-                f"Currently only 4-bit quant is supported for GPTQ, you passed {self.weight_bits} bits.")
-    
+                f"Currently only 4-bit quant is supported for GPTQ, you passed {self.weight_bits} bits."
+            )
+
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}), "
                 f"group_size={self.group_size}, "
                 f"desc_act={self.desc_act}")
-    
+
     @classmethod
     def get_name(cls) -> str:
         return "gptq"
-    
+
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half]
-    
+
     @classmethod
     def get_min_capability(cls) -> int:
         return 60
-    
+
     @classmethod
     def get_config_filenames(cls) -> List[str]:
         return [
             "quantize_config.json",
         ]
-    
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
         return cls(weight_bits, group_size, desc_act)
-    
+
     @classmethod
     def get_packed_tensor_names(cls) -> List[str]:
         return ["qzeros"]
-    
+
     @classmethod
     def get_transposed_tensor_names(cls) -> List[str]:
         return ["qweight", "qzeros", "scales"]
-    
+
     def get_row_tp_tensor_names(self) -> List[str]:
         if self.desc_act and self.group_size != -1:
             return ["qweight", "g_idx"]
         if self.group_size == -1:
             return ["qweight"]
         return ["qweight", "qzeros", "scales"]
-    
+
     def get_column_tp_tensor_names(self) -> List[str]:
         return ["qweight", "qzeros", "scales", "bias"]
-    
+
     def get_ignore_tensor_names(self) -> List[str]:
         if self.desc_act and self.group_size != -1:
             return []
         return ["g_idx"]
-        
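
A minimal sketch of what the GPTQConfig hunks above encode, assuming an illustrative quantize_config.json dict (the values are invented for the example; a real file ships alongside the checkpoint):

quantize_config = {"bits": 4, "group_size": 128, "desc_act": False}

weight_bits = quantize_config["bits"]   # only 4-bit is accepted, per the ValueError above
pack_factor = 32 // weight_bits         # 8 int4 weights packed into each int32
assert pack_factor == 8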

+ 1 - 0
aphrodite/modeling/utils.py

@@ -3,6 +3,7 @@ import random
 import numpy as np
 import torch
 
+
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
     np.random.seed(seed)

+ 4 - 1
aphrodite/processing/policy.py

@@ -2,6 +2,7 @@ from typing import List
 
 from aphrodite.common.sequence import SequenceGroup
 
+
 class Policy:
 
     def get_priority(
@@ -22,6 +23,7 @@ class Policy:
             reverse=True,
         )
 
+
 class FCFS(Policy):
 
     def get_priority(
@@ -31,6 +33,7 @@ class FCFS(Policy):
     ) -> float:
         return now - seq_group.arrival_time
 
+
 class PolicyFactory:
 
     _POLICY_REGISTRY = {
@@ -39,4 +42,4 @@ class PolicyFactory:
 
     @classmethod
     def get_policy(cls, policy_name: str, **kwargs) -> Policy:
-        return cls._POLICY_REGISTRY[policy_name](**kwargs)
+        return cls._POLICY_REGISTRY[policy_name](**kwargs)
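
A self-contained sketch of the FCFS ordering implied above: priority is now - arrival_time, so with reverse=True the oldest request sorts first (the request names and timestamps are made up):

arrival_times = {"req-a": 100.0, "req-b": 103.5, "req-c": 101.2}
now = 110.0
ordered = sorted(arrival_times, key=lambda r: now - arrival_times[r], reverse=True)
assert ordered == ["req-a", "req-c", "req-b"]   # earliest arrival gets the highest priority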

+ 2 - 2
aphrodite/processing/scheduler.py

@@ -7,7 +7,7 @@ from aphrodite.processing.block_manager import BlockSpaceManager
 from aphrodite.processing.policy import PolicyFactory
 from aphrodite.common.logger import init_logger
 from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceStatus)
+                                       SequenceGroupMetadata, SequenceStatus)
 
 logger = init_logger(__name__)
 
@@ -397,4 +397,4 @@ class Scheduler:
         mapping = self.block_manager.swap_out(seq_group)
         blocks_to_swap_out.update(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            seq.status = SequenceStatus.SWAPPED
+            seq.status = SequenceStatus.SWAPPED

+ 7 - 4
aphrodite/task_handler/worker.py

@@ -6,7 +6,7 @@ import torch
 import torch.distributed
 
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                                     SchedulerConfig)
 from aphrodite.modeling import get_model, InputMetadata, set_random_seed
 from aphrodite.modeling.megatron.parallel_state import (
     initialize_model_parallel)
@@ -67,7 +67,8 @@ class Worker:
 
         # Initialize the model.
         set_random_seed(self.model_config.seed)
-        self.model = get_model(self.model_config, self.scheduler_config.max_num_batched_tokens)
+        self.model = get_model(self.model_config,
+                               self.scheduler_config.max_num_batched_tokens)
 
     @torch.inference_mode()
     def profile_num_available_blocks(
@@ -146,7 +147,7 @@ class Worker:
         else:
             max_seq_len = min(self.scheduler_config.max_model_len,
                               self.sliding_window)
-        
+
         _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
 
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
@@ -401,6 +402,7 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
             f"available shared memory {max_shared_mem}). "
             "This will be fixed in a future release.")
 
+
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     if torch_dtype == torch.bfloat16:
         compute_capability = torch.cuda.get_device_capability()
@@ -410,4 +412,5 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
                 "Bfloat16 is only supported on GPUs with compute capability "
                 f"of at least 8.0. You {gpu_name} GPU has compute capability "
                 f"{compute_capability[0]}.{compute_capability[1]}. Please "
-                "use the `--dtype float16` argument when launching the engine.")
+                "use the `--dtype float16` argument when launching the engine."
+            )
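
A small sketch of the bfloat16 gate in _check_if_can_support / _check_if_gpu_supports_dtype above, assuming torch is importable and a GPU is visible (the float16 fallback mirrors the error message in the hunk):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    bf16_ok = major >= 8   # Ampere (compute capability 8.0) or newer
    if not bf16_ok:
        print(f"compute capability {major}.{minor}: launch with --dtype float16 instead")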

+ 3 - 1
aphrodite/transformers_utils/config.py

@@ -1,7 +1,9 @@
 from typing import Optional
 from transformers import AutoConfig, PretrainedConfig
 
-def get_config(model: str, trust_remote_code: bool,
+
+def get_config(model: str,
+               trust_remote_code: bool,
                revision: Optional[str] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(

+ 9 - 8
tests/kernels/conftest.py

@@ -5,13 +5,13 @@ import torch
 
 
 def create_kv_caches(
-        num_blocks: int,
-        block_size: int,
-        num_layers: int,
-        num_heads: int,
-        head_size: int,
-        dtype: torch.dtype,
-        seed: int,
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -37,6 +37,7 @@ def create_kv_caches(
         values_caches.append(values_cache)
     return key_caches, values_caches
 
+
 @pytest.fixture()
 def kv_cache_factory():
-    return create_kv_caches
+    return create_kv_caches

+ 4 - 2
tests/kernels/test_activation.py

@@ -8,13 +8,15 @@ from aphrodite import activation_ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 38, 2048]
-D = [512, 4096, 5120, 13824] # arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # arbitrary values for testing
 SEEDS = [0]
 
+
 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2
 
+
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -71,4 +73,4 @@ def test_gelu_fast(
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
-    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
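
For context, the ref_silu_and_mul helper above splits an input of width 2*d in half along the feature dimension and gates one half with the other. A tiny CPU-only shape check, with sizes borrowed from NUM_TOKENS and D:

import torch
import torch.nn.functional as F

x = torch.randn(7, 2 * 512)              # 7 tokens, d = 512
x1, x2 = x.chunk(chunks=2, dim=1)
ref = F.silu(x1) * x2
assert ref.shape == (7, 512)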

+ 7 - 3
tests/models/test_models.py

@@ -4,6 +4,7 @@ MODELS = [
     "EleutherAI/pythia-70m-deduped",
 ]
 
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
@@ -20,13 +21,16 @@ def test_models(
     del hf_model
 
     aphrodite_model = aphrodite_runner(model, dtype=dtype)
-    aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts, max_tokens)
+    aphrodite_outputs = aphrodite_model.generate_greedy(
+        example_prompts, max_tokens)
     del aphrodite_model
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]
         aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
         assert hf_output_str == aphrodite_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nAphrodite: {aphrodite_output_str!r}")
+            f"Test{i}:\nHF: {hf_output_str!r}\nAphrodite: {aphrodite_output_str!r}"
+        )
         assert hf_output_ids == aphrodite_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nAphrodite: {aphrodite_output_ids}")
+            f"Test{i}:\nHF: {hf_output_ids}\nAphrodite: {aphrodite_output_ids}"
+        )

+ 17 - 11
tests/samplers/test_samplers.py

@@ -19,18 +19,19 @@ class MockLogitsSampler(Sampler):
 
     def forward(self, *args, **kwargs):
         with patch("aphrodite.modeling.layers.sampler._prune_hidden_states",
-                    lambda x, y: x):
+                   lambda x, y: x):
             with patch("aphrodite.modeling.layers.sampler._get_logits",
-                      lambda *args, **kwargs: self.fake_logits):
+                       lambda *args, **kwargs: self.fake_logits):
                 return super().forward(*args, **kwargs)
-            
+
+
 def _prepare_test(
     batch_size: int
 ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024),
-                             device='cuda',
-                             dtype=torch.float16)
+                              device='cuda',
+                              dtype=torch.float16)
     fake_logits = torch.full((batch_size, vocab_size),
                              1e-2,
                              device=input_tensor.device,
@@ -40,8 +41,10 @@ def _prepare_test(
     worker.block_size = 16
     return input_tensor, fake_logits, sampler, worker
 
+
 RANDOM_SEEDS = list(range(128))
 
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_all_greedy(seed: int):
     set_random_seed(seed)
@@ -58,16 +61,17 @@ def test_sampler_all_greedy(seed: int):
                 sampling_params=SamplingParams(temperature=0, ),
                 block_tables={0: [1]},
             ))
-    
+
     _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
     sampler_output = sampler(embedding=None,
-                            hidden_states=input_tensor,
-                            input_metadata=input_metadata)
+                             hidden_states=input_tensor,
+                             input_metadata=input_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output:
             assert nth_output.output_token == expected[i].item()
 
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_all_random(seed: int):
     set_random_seed(seed)
@@ -76,7 +80,7 @@ def test_sampler_all_random(seed: int):
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
-    
+
     seq_group_metadata_list = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
@@ -98,6 +102,7 @@ def test_sampler_all_random(seed: int):
         for nth_output in sequence_output:
             assert nth_output.output_token == i
 
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
@@ -123,6 +128,7 @@ def test_sampler_all_beam(seed: int):
             hidden_states=input_tensor,
             input_metadata=input_metadata)
 
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 def test_sampler_mixed(seed: int):
     set_random_seed(seed)
@@ -156,7 +162,7 @@ def test_sampler_mixed(seed: int):
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        
+
         for idx in range(n):
             fake_logits[i, i + idx] = 1e2
             expected_tokens.append(i + idx)
@@ -168,7 +174,7 @@ def test_sampler_mixed(seed: int):
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
-        
+
     _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,

+ 40 - 31
tests/serving.py

@@ -23,15 +23,10 @@ def sample_requests(
     with open(dataset_path) as f:
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
-    dataset = [
-        data for data in dataset
-        if len(data["conversations"]) >= 2
-    ]
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]
 
     # Tokenize the prompts and completions.
     prompts = [prompt for prompt, _ in dataset]
@@ -120,7 +115,8 @@ async def send_request(
     timeout = aiohttp.ClientTimeout(total=3 * 3600)
     async with aiohttp.ClientSession(timeout=timeout) as session:
         while True:
-            async with session.post(api_url, headers=headers, json=pload) as response:
+            async with session.post(api_url, headers=headers,
+                                    json=pload) as response:
                 chunks = []
                 async for chunk, _ in response.content.iter_chunks():
                     chunks.append(chunk)
@@ -147,9 +143,9 @@ async def benchmark(
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(backend, api_url, prompt,
-                                                prompt_len, output_len,
-                                                best_of, use_beam_search))
+        task = asyncio.create_task(
+            send_request(backend, api_url, prompt, prompt_len, output_len,
+                         best_of, use_beam_search))
         tasks.append(task)
     await asyncio.gather(*tasks)
 
@@ -160,12 +156,14 @@ def main(args: argparse.Namespace):
     np.random.seed(args.seed)
 
     api_url = f"http://{args.host}:{args.port}/api/v1/generate"
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    tokenizer = get_tokenizer(args.tokenizer,
+                              trust_remote_code=args.trust_remote_code)
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
     benchmark_start_time = time.perf_counter()
-    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
-                          args.use_beam_search, args.request_rate))
+    asyncio.run(
+        benchmark(args.backend, api_url, input_requests, args.best_of,
+                  args.use_beam_search, args.request_rate))
     benchmark_end_time = time.perf_counter()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
@@ -179,10 +177,8 @@ def main(args: argparse.Namespace):
         for prompt_len, output_len, latency in REQUEST_LATENCY
     ])
     print(f"Average latency per token: {avg_per_token_latency:.2f} s")
-    avg_per_output_token_latency = np.mean([
-        latency / output_len
-        for _, output_len, latency in REQUEST_LATENCY
-    ])
+    avg_per_output_token_latency = np.mean(
+        [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
     print("Average latency per output token: "
           f"{avg_per_output_token_latency:.2f} s")
 
@@ -190,27 +186,40 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Benchmark the online serving throughput.")
-    parser.add_argument("--backend", type=str, default="aphrodite",
+    parser.add_argument("--backend",
+                        type=str,
+                        default="aphrodite",
                         choices=["aphrodite", "tgi"])
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=2242)
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                         help="Path to the dataset.")
-    parser.add_argument("--tokenizer", type=str, required=True,
+    parser.add_argument("--tokenizer",
+                        type=str,
+                        required=True,
                         help="Name or path of the tokenizer.")
-    parser.add_argument("--best-of", type=int, default=1,
+    parser.add_argument("--best-of",
+                        type=int,
+                        default=1,
                         help="Generates `best_of` sequences per prompt and "
-                             "returns the best one.")
+                        "returns the best one.")
     parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                         help="Number of prompts to process.")
-    parser.add_argument("--request-rate", type=float, default=float("inf"),
+    parser.add_argument("--request-rate",
+                        type=float,
+                        default=float("inf"),
                         help="Number of requests per second. If this is inf, "
-                             "then all the requests are sent at time 0. "
-                             "Otherwise, we use Poisson process to synthesize "
-                             "the request arrival times.")
+                        "then all the requests are sent at time 0. "
+                        "Otherwise, we use Poisson process to synthesize "
+                        "the request arrival times.")
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument('--trust-remote-code', action='store_true',
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
                         help='trust remote code from huggingface')
     args = parser.parse_args()
-    main(args)
+    main(args)
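
The --request-rate help text above refers to synthesizing request arrivals from a Poisson process; under that assumption the inter-arrival gaps are exponential with mean 1/rate. A small sketch with an invented rate:

import numpy as np

request_rate = 4.0                                    # requests per second (illustrative)
rng = np.random.default_rng(0)
gaps = rng.exponential(1.0 / request_rate, size=5)    # i.i.d. inter-arrival times
arrival_times = np.cumsum(gaps)                       # seconds after the first request is queued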

+ 6 - 4
tests/throughput.py

@@ -12,6 +12,7 @@ from tqdm import tqdm
 from aphrodite import LLM, SamplingParams
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
 
+
 def sample_requests(
     dataset_path: str,
     num_requests: int,
@@ -170,9 +171,10 @@ def main(args: argparse.Namespace):
 
     if args.backend == "aphrodite":
         elapsed_time = run_aphrodite(requests, args.model, args.tokenizer,
-                                args.quantization, args.tensor_parallel_size,
-                                args.seed, args.n, args.use_beam_search,
-                                args.trust_remote_code, args.dtype)
+                                     args.quantization,
+                                     args.tensor_parallel_size, args.seed,
+                                     args.n, args.use_beam_search,
+                                     args.trust_remote_code, args.dtype)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -245,4 +247,4 @@ if __name__ == "__main__":
     if args.tokenizer is None:
         args.tokenizer = args.model
 
-    main(args)
+    main(args)