# serving_embedding.py
import asyncio
import base64
import time
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
                    Union, cast)

import numpy as np
from fastapi import Request
from loguru import logger

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import EmbeddingRequestOutput
from aphrodite.common.utils import merge_async_iterators, random_uuid
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (EmbeddingRequest,
                                                 EmbeddingResponse,
                                                 EmbeddingResponseData,
                                                 ErrorResponse, UsageInfo)
from aphrodite.endpoints.openai.serving_engine import (BaseModelPath,
                                                       OpenAIServing)
from aphrodite.engine.protocol import EngineClient

TypeTokenIDs = List[int]


def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: str) -> EmbeddingResponse:
    """Convert a batch of engine outputs into an OpenAI-compatible
    EmbeddingResponse, encoding vectors as base64 when requested."""
    data: List[EmbeddingResponseData] = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        prompt_token_ids = final_res.prompt_token_ids
        embedding = final_res.outputs.embedding
        if encoding_format == "base64":
            # Force to use float32 for base64 encoding
            # to match the OpenAI python client behavior
            embedding_bytes = np.array(embedding, dtype="float32").tobytes()
            embedding = base64.b64encode(embedding_bytes).decode("utf-8")

        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
        data.append(embedding_data)

        num_prompt_tokens += len(prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )
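

# A minimal sketch (illustrative, not part of the server): the client-side
# inverse of the base64 path above. The helper name is hypothetical; it
# simply mirrors the float32 encoding used in
# request_output_to_embedding_response.
def _decode_base64_embedding_example(embedding_b64: str) -> List[float]:
    raw = base64.b64decode(embedding_b64)
    return np.frombuffer(raw, dtype=np.float32).tolist()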


class OpenAIServingEmbedding(OpenAIServing):

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        request_logger: Optional[RequestLogger],
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)
        self._enabled = self._check_embedding_mode(
            model_config.embedding_mode)

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, EmbeddingResponse]:
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        if not self._enabled:
            return self.create_error_response("Embedding API disabled")

        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = (request.encoding_format
                           if request.encoding_format else "float")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{random_uuid()}"
        # Use wall-clock time: the OpenAI `created` field is a Unix
        # timestamp, which time.monotonic() does not provide.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            pooling_params = request.to_pooling_params()

            prompts = list(
                self._tokenize_prompt_input_or_inputs(
                    request,
                    tokenizer,
                    request.input,
                ))
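            # Each tokenized prompt becomes its own engine request; only
            # its "prompt_token_ids" entry is consumed when scheduling.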
            for i, prompt_inputs in enumerate(prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if prompt_adapter_request is not None:
                    raise NotImplementedError(
                        "Prompt adapter is not supported "
                        "for embedding models")

                generator = self.engine_client.encode(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))
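
        # merge_async_iterators interleaves the per-prompt generators and
        # yields (index, output) pairs, so completions may arrive in any
        # order yet still land in their original batch slot below.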
        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(
                *generators,
                is_cancelled=raw_request.is_disconnected
                if raw_request else None)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for final_res in final_res_batch:
                assert final_res is not None

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def _check_embedding_mode(self, embedding_mode: bool) -> bool:
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info(
                "Activating the server engine with embedding enabled.")
        return embedding_mode
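

# A minimal sketch (hypothetical wiring, not part of this module): how the
# handler is typically mounted on a FastAPI app. `app` and `serving` are
# assumed to be built during server startup; the route path mirrors the
# OpenAI Embedding API.
def _register_embeddings_route_example(app,
                                       serving: OpenAIServingEmbedding):
    @app.post("/v1/embeddings")
    async def _create_embedding(request: EmbeddingRequest,
                                raw_request: Request):
        return await serving.create_embedding(request, raw_request)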