# serving_embedding.py
  1. import time
  2. from typing import AsyncIterator, List, Tuple, Optional
  3. from fastapi import Request
  4. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  5. from aphrodite.endpoints.openai.protocol import (EmbeddingRequest,
  6. EmbeddingResponse,
  7. EmbeddingResponseData,
  8. UsageInfo)
  9. from aphrodite.endpoints.openai.serving_completions import parse_prompt_format
  10. from aphrodite.endpoints.openai.serving_engine import OpenAIServing, LoRA
  11. from aphrodite.common.outputs import EmbeddingRequestOutput
  12. from aphrodite.common.utils import merge_async_iterators, random_uuid
  13. TypeTokenIDs = List[int]
  14. def request_output_to_embedding_response(
  15. final_res_batch: List[EmbeddingRequestOutput],
  16. request_id: str,
  17. created_time: int,
  18. model_name: str,
  19. ) -> EmbeddingResponse:
  20. data = []
  21. num_prompt_tokens = 0
  22. for idx, final_res in enumerate(final_res_batch):
  23. assert final_res is not None
  24. prompt_token_ids = final_res.prompt_token_ids
  25. embedding_data = EmbeddingResponseData(
  26. index=idx, embedding=final_res.outputs.embedding)
  27. data.append(embedding_data)
  28. num_prompt_tokens += len(prompt_token_ids)
  29. usage = UsageInfo(
  30. prompt_tokens=num_prompt_tokens,
  31. total_tokens=num_prompt_tokens,
  32. )
  33. return EmbeddingResponse(
  34. id=request_id,
  35. created=created_time,
  36. model=model_name,
  37. data=data,
  38. usage=usage,
  39. )
  40. class OpenAIServingEmbedding(OpenAIServing):
  41. def __init__(self,
  42. engine: AsyncAphrodite,
  43. served_model_names: List[str],
  44. lora_modules: Optional[List[LoRA]] = None):
  45. super().__init__(engine=engine,
  46. served_model_names=served_model_names,
  47. lora_modules=lora_modules)
  48. async def create_embedding(self, request: EmbeddingRequest,
  49. raw_request: Request):
  50. """Completion API similar to OpenAI's API.
  51. See https://platform.openai.com/docs/api-reference/embeddings/create
  52. for the API specification. This API mimics the OpenAI Embedding API.
  53. """
  54. error_check_ret = await self._check_model(request)
  55. if error_check_ret is not None:
  56. return error_check_ret
  57. # Return error for unsupported features.
  58. if request.encoding_format == "base64":
  59. return self.create_error_response(
  60. "base64 encoding is not currently supported")
  61. if request.dimensions is not None:
  62. return self.create_error_response(
  63. "dimensions is currently not supported")
  64. model_name = request.model
  65. request_id = f"cmpl-{random_uuid()}"
  66. created_time = int(time.monotonic())
  67. # Schedule the request and get the result generator.
  68. generators = []
  69. try:
  70. prompt_is_tokens, prompts = parse_prompt_format(request.input)
  71. pooling_params = request.to_pooling_params()
  72. for i, prompt in enumerate(prompts):
  73. if prompt_is_tokens:
  74. prompt_formats = self._validate_prompt_and_tokenize(
  75. request, prompt_ids=prompt)
  76. else:
  77. prompt_formats = self._validate_prompt_and_tokenize(
  78. request, prompt=prompt)
  79. prompt_ids, prompt_text = prompt_formats
  80. generators.append(
  81. self.engine.generate(prompt_text,
  82. pooling_params,
  83. f"{request_id}-{i}",
  84. prompt_token_ids=prompt_ids))
  85. except ValueError as e:
  86. # TODO: Use a aphrodite-specific Validation Error
  87. return self.create_error_response(str(e))
  88. result_generator: AsyncIterator[Tuple[
  89. int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
  90. # Non-streaming response
  91. final_res_batch: EmbeddingRequestOutput = [None] * len(prompts)
  92. async for i, res in result_generator:
  93. if await raw_request.is_disconnected():
  94. # Abort the request if the client disconnects.
  95. await self.engine.abort(f"{request_id}-{i}")
  96. # TODO: Use a aphrodite-specific Validation Error
  97. return self.create_error_response("Client disconnected")
  98. final_res_batch[i] = res
  99. response = request_output_to_embedding_response(
  100. final_res_batch, request_id, created_time, model_name)
  101. return response