msgspec_utils.py 898 B

123456789101112131415161718192021222324
  1. from array import array
  2. from typing import Any, Type
  3. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  4. def encode_hook(obj: Any) -> Any:
  5. """Custom msgspec enc hook that supports array types.
  6. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
  7. """
  8. if isinstance(obj, array):
  9. assert obj.typecode == APHRODITE_TOKEN_ID_ARRAY_TYPE, (
  10. f"Aphrodite array type should use '{APHRODITE_TOKEN_ID_ARRAY_TYPE}'"
  11. f" type. Given array has a type code of {obj.typecode}.")
  12. return obj.tobytes()
  13. def decode_hook(type: Type, obj: Any) -> Any:
  14. """Custom msgspec dec hook that supports array types.
  15. See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
  16. """
  17. if type is array:
  18. deserialized = array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
  19. deserialized.frombytes(obj)
  20. return deserialized