serving_embedding.py

import base64
import time
from typing import AsyncIterator, List, Optional, Tuple

import numpy as np
from fastapi import Request

from aphrodite.common.outputs import EmbeddingRequestOutput
from aphrodite.common.utils import merge_async_iterators, random_uuid
from aphrodite.endpoints.openai.protocol import (EmbeddingRequest,
                                                 EmbeddingResponse,
                                                 EmbeddingResponseData,
                                                 UsageInfo)
from aphrodite.endpoints.openai.serving_completions import parse_prompt_format
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       OpenAIServing)
from aphrodite.engine.async_aphrodite import AsyncAphrodite

TypeTokenIDs = List[int]


def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: str) -> EmbeddingResponse:
    data = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        assert final_res is not None
        prompt_token_ids = final_res.prompt_token_ids
        embedding = final_res.outputs.embedding
        if encoding_format == "base64":
            # Encode the raw float buffer and decode to a UTF-8 string so
            # the field is JSON-serializable (b64encode alone returns bytes).
            embedding = base64.b64encode(
                np.array(embedding).tobytes()).decode("utf-8")
        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
        data.append(embedding_data)

        num_prompt_tokens += len(prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )
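

# A hypothetical client-side helper (an illustration, not part of the
# original module): the "base64" encoding_format above sends the raw bytes
# of np.array(embedding), which defaults to float64 for a list of Python
# floats, so a client can round-trip the payload like this.
def decode_base64_embedding(b64_embedding: str) -> np.ndarray:
    """Decode a base64-encoded embedding back into a float array."""
    return np.frombuffer(base64.b64decode(b64_embedding), dtype=np.float64)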


class OpenAIServingEmbedding(OpenAIServing):

    def __init__(self,
                 engine: AsyncAphrodite,
                 served_model_names: List[str],
                 lora_modules: Optional[List[LoRAModulePath]] = None):
        super().__init__(engine=engine,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules)

    async def create_embedding(self, request: EmbeddingRequest,
                               raw_request: Request):
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        encoding_format = (request.encoding_format
                           if request.encoding_format else "float")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"cmpl-{random_uuid()}"
        # Use wall-clock time: the OpenAI API specifies `created` as a Unix
        # timestamp, which time.monotonic() does not provide.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators = []
        try:
            prompt_is_tokens, prompts = parse_prompt_format(request.input)
            pooling_params = request.to_pooling_params()

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt_ids=prompt)
                else:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt=prompt)

                prompt_ids, prompt_text = prompt_formats

                generator = self.engine.encode(
                    {
                        "prompt": prompt_text,
                        "prompt_token_ids": prompt_ids
                    },
                    pooling_params,
                    f"{request_id}-{i}",
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))
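        # merge_async_iterators interleaves the per-prompt generators into a
        # single stream of (index, output) pairs, yielding each embedding as
        # soon as it is ready, regardless of prompt order.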
        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.engine.abort(f"{request_id}-{i}")
                    # TODO: Use an aphrodite-specific Validation Error
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res

            response = request_output_to_embedding_response(
                final_res_batch, request_id, created_time, model_name,
                encoding_format)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        return response
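
# Usage sketch (an assumption, not part of the original file): with an
# Aphrodite OpenAI-compatible server running and an embedding model loaded,
# the endpoint served by OpenAIServingEmbedding can be exercised with the
# official `openai` client. The base URL, port, and model name below are
# placeholders.
#
#     from openai import OpenAI
#
#     client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")
#     resp = client.embeddings.create(
#         model="intfloat/e5-mistral-7b-instruct",
#         input=["The food was delicious.", "The service was slow."],
#         encoding_format="float",  # or "base64", handled above
#     )
#     print(len(resp.data[0].embedding), resp.usage.prompt_tokens)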