Parcourir la source

chore: estimated input speed for tqdm

AlpinDale il y a 7 mois
Parent
commit
458c8b5e33
1 fichiers modifiés avec 12 ajouts et 5 suppressions
  1. 12 5
      aphrodite/endpoints/llm.py

+ 12 - 5
aphrodite/endpoints/llm.py

@@ -530,11 +530,13 @@ class LLM:
                 total=num_requests,
                 desc="Processed prompts",
                 dynamic_ncols=True,
-                postfix=f"Generation Speed: {0:.2f} toks/s",
+                postfix=(f"estimated speed input: {0:.2f} toks/s, "
+                         f"output: {0:.2f} toks/s"),
             )
         # Run the engine.
         outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
-        total_toks = 0
+        total_in_toks = 0
+        total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
@@ -543,10 +545,15 @@ class LLM:
                     if use_tqdm:
                         if isinstance(output, RequestOutput):
                             # Calculate tokens only for RequestOutput
-                            total_toks += sum(
+                            total_in_toks += len(output.prompt_token_ids)
+                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
+                            total_out_toks += sum(
                                 len(stp.token_ids) for stp in output.outputs)
-                            spd = total_toks / pbar.format_dict["elapsed"]
-                            pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
+                            out_spd = total_out_toks / pbar.format_dict[
+                                "elapsed"]
+                            pbar.postfix = (
+                                f"estimated speed input: {in_spd:.2f} toks/s, "
+                                f"output: {out_spd:.2f} toks/s")
                         pbar.update(1)
         if use_tqdm:
             pbar.close()