embeddings.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """
This file is derived from the
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
module and modified.
The changes introduced are: suppression of the progress bar,
typing/pydantic classes moved into this file,
and the embeddings function declared async.
  8. """
  9. import os
  10. import base64
  11. import numpy as np
  12. from transformers import AutoModel
  13. embeddings_params_initialized = False
  14. def initialize_embedding_params():
  15. '''
  16. using 'lazy loading' to avoid circular import
  17. so this function will be executed only once
  18. '''
  19. global embeddings_params_initialized
  20. if not embeddings_params_initialized:
  21. global st_model, embeddings_model, embeddings_device
  22. st_model = os.environ.get("OPENAI_EMBEDDING_MODEL",
  23. 'all-mpnet-base-v2')
  24. embeddings_model = None
  25. # OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
  26. # cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
  27. # hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
  28. # hpu, mtia, privateuseone
  29. embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu')
  30. if embeddings_device.lower() == 'auto':
  31. embeddings_device = None
  32. embeddings_params_initialized = True
  33. def load_embedding_model(model: str):
  34. try:
  35. from sentence_transformers import SentenceTransformer
  36. except ModuleNotFoundError:
  37. print("The sentence_transformers module has not been found. " +
  38. "Please install it manually with " +
  39. "pip install -U sentence-transformers.")
  40. raise ModuleNotFoundError from None
  41. initialize_embedding_params()
  42. global embeddings_device, embeddings_model
  43. try:
  44. print(f"Try embedding model: {model} on {embeddings_device}")
  45. if 'jina-embeddings' in model:
  46. # trust_remote_code is needed to use the encode method
  47. embeddings_model = AutoModel.from_pretrained(
  48. model, trust_remote_code=True)
  49. embeddings_model = embeddings_model.to(embeddings_device)
  50. else:
  51. embeddings_model = SentenceTransformer(
  52. model,
  53. device=embeddings_device,
  54. )
  55. print(f"Loaded embedding model: {model}")
  56. except Exception as e:
  57. embeddings_model = None
  58. raise Exception(f"Error: Failed to load embedding model: {model}",
  59. internal_message=repr(e)) from None
  60. def get_embeddings_model():
  61. initialize_embedding_params()
  62. global embeddings_model, st_model
  63. if st_model and not embeddings_model:
  64. load_embedding_model(st_model) # lazy load the model
  65. return embeddings_model
  66. def get_embeddings_model_name() -> str:
  67. initialize_embedding_params()
  68. global st_model
  69. return st_model
  70. def get_embeddings(input: list) -> np.ndarray:
  71. model = get_embeddings_model()
  72. embedding = model.encode(input,
  73. convert_to_numpy=True,
  74. normalize_embeddings=True,
  75. convert_to_tensor=False,
  76. show_progress_bar=False)
  77. return embedding
  78. async def embeddings(input: list,
  79. encoding_format: str,
  80. model: str = None) -> dict:
  81. if model is None:
  82. model = st_model
  83. else:
  84. load_embedding_model(model)
  85. embeddings = get_embeddings(input)
  86. if encoding_format == "base64":
  87. data = [{
  88. "object": "embedding",
  89. "embedding": float_list_to_base64(emb),
  90. "index": n
  91. } for n, emb in enumerate(embeddings)]
  92. else:
  93. data = [{
  94. "object": "embedding",
  95. "embedding": emb.tolist(),
  96. "index": n
  97. } for n, emb in enumerate(embeddings)]
  98. response = {
  99. "object": "list",
  100. "data": data,
  101. "model": st_model if model is None else model,
  102. "usage": {
  103. "prompt_tokens": 0,
  104. "total_tokens": 0,
  105. }
  106. }
  107. return response
  108. def float_list_to_base64(float_array: np.ndarray) -> str:
  109. # Convert the list to a float32 array that the OpenAPI client expects
  110. # float_array = np.array(float_list, dtype="float32")
  111. # Get raw bytes
  112. bytes_array = float_array.tobytes()
  113. # Encode bytes into base64
  114. encoded_bytes = base64.b64encode(bytes_array)
  115. # Turn raw base64 encoded bytes into ASCII
  116. ascii_string = encoded_bytes.decode('ascii')
  117. return ascii_string