test_embedding.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import base64
  2. import numpy as np
  3. import openai
  4. import pytest
  5. from ...utils import RemoteOpenAIServer
  6. EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
  7. @pytest.fixture(scope="module")
  8. def embedding_server():
  9. args = [
  10. # use half precision for speed and memory savings in CI environment
  11. "--dtype",
  12. "bfloat16",
  13. "--enforce-eager",
  14. "--max-model-len",
  15. "8192",
  16. ]
  17. with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
  18. yield remote_server
  19. @pytest.mark.asyncio
  20. @pytest.fixture(scope="module")
  21. def embedding_client(embedding_server):
  22. return embedding_server.get_async_client()
  23. @pytest.mark.asyncio
  24. @pytest.mark.parametrize(
  25. "model_name",
  26. [EMBEDDING_MODEL_NAME],
  27. )
  28. async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
  29. model_name: str):
  30. input_texts = [
  31. "The chef prepared a delicious meal.",
  32. ]
  33. # test single embedding
  34. embeddings = await embedding_client.embeddings.create(
  35. model=model_name,
  36. input=input_texts,
  37. encoding_format="float",
  38. )
  39. assert embeddings.id is not None
  40. assert len(embeddings.data) == 1
  41. assert len(embeddings.data[0].embedding) == 4096
  42. assert embeddings.usage.completion_tokens == 0
  43. assert embeddings.usage.prompt_tokens == 9
  44. assert embeddings.usage.total_tokens == 9
  45. # test using token IDs
  46. input_tokens = [1, 1, 1, 1, 1]
  47. embeddings = await embedding_client.embeddings.create(
  48. model=model_name,
  49. input=input_tokens,
  50. encoding_format="float",
  51. )
  52. assert embeddings.id is not None
  53. assert len(embeddings.data) == 1
  54. assert len(embeddings.data[0].embedding) == 4096
  55. assert embeddings.usage.completion_tokens == 0
  56. assert embeddings.usage.prompt_tokens == 5
  57. assert embeddings.usage.total_tokens == 5
  58. @pytest.mark.asyncio
  59. @pytest.mark.parametrize(
  60. "model_name",
  61. [EMBEDDING_MODEL_NAME],
  62. )
  63. async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
  64. model_name: str):
  65. # test List[str]
  66. input_texts = [
  67. "The cat sat on the mat.", "A feline was resting on a rug.",
  68. "Stars twinkle brightly in the night sky."
  69. ]
  70. embeddings = await embedding_client.embeddings.create(
  71. model=model_name,
  72. input=input_texts,
  73. encoding_format="float",
  74. )
  75. assert embeddings.id is not None
  76. assert len(embeddings.data) == 3
  77. assert len(embeddings.data[0].embedding) == 4096
  78. # test List[List[int]]
  79. input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
  80. [25, 32, 64, 77]]
  81. embeddings = await embedding_client.embeddings.create(
  82. model=model_name,
  83. input=input_tokens,
  84. encoding_format="float",
  85. )
  86. assert embeddings.id is not None
  87. assert len(embeddings.data) == 4
  88. assert len(embeddings.data[0].embedding) == 4096
  89. assert embeddings.usage.completion_tokens == 0
  90. assert embeddings.usage.prompt_tokens == 17
  91. assert embeddings.usage.total_tokens == 17
  92. @pytest.mark.asyncio
  93. @pytest.mark.parametrize(
  94. "model_name",
  95. [EMBEDDING_MODEL_NAME],
  96. )
  97. async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
  98. model_name: str):
  99. input_texts = [
  100. "Hello my name is",
  101. "The best thing about Aphrodite is that it supports many different models" # noqa: E501
  102. ]
  103. responses_float = await embedding_client.embeddings.create(
  104. input=input_texts, model=model_name, encoding_format="float")
  105. responses_base64 = await embedding_client.embeddings.create(
  106. input=input_texts, model=model_name, encoding_format="base64")
  107. decoded_responses_base64_data = []
  108. for data in responses_base64.data:
  109. decoded_responses_base64_data.append(
  110. np.frombuffer(base64.b64decode(data.embedding),
  111. dtype="float").tolist())
  112. assert responses_float.data[0].embedding == decoded_responses_base64_data[
  113. 0]
  114. assert responses_float.data[1].embedding == decoded_responses_base64_data[
  115. 1]