dtype_fp8.cuh 616 B

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #pragma once
  2. #include "attention_generic.cuh"
  3. #include <stdint.h>
  4. #ifdef ENABLE_FP8
  5. #ifndef USE_ROCM
  6. #include <cuda_fp8.h>
  7. #endif // USE_ROCM
  8. #endif // ENABLE_FP8
  namespace aphrodite {

  // Selects the storage format of the KV cache.
  //   kAuto    (0) - NOTE(review): presumably "no FP8 conversion, keep the
  //                  cache's native dtype"; confirm against the kernels that
  //                  switch on this enum.
  //   kFp8E4M3 (1) - FP8 variant named E4M3.
  //   kFp8E5M2 (2) - FP8 variant named E5M2.
  // The numeric values are spelled out explicitly. The underlying type is
  // written as `int`, which is what a scoped enum uses by default anyway.
  enum class Fp8KVCacheDataType : int {
    kAuto = 0,
    kFp8E4M3 = 1,
    kFp8E5M2 = 2,
  };

  // Vec<uint8_t, N> specializations used for the quantized (fp8) KV cache.
  // Each lane is one byte, so `Type` is a single N-byte word that can hold
  // all N lanes at once.
  // NOTE(review): the primary Vec template is expected to come from
  // attention_generic.cuh; these are only the uint8_t specializations.

  // 1 lane -> one byte.
  template <>
  struct Vec<uint8_t, 1> {
    typedef uint8_t Type;
  };

  // 2 lanes -> one 16-bit word.
  template <>
  struct Vec<uint8_t, 2> {
    typedef uint16_t Type;
  };

  // 4 lanes -> one 32-bit word.
  template <>
  struct Vec<uint8_t, 4> {
    typedef uint32_t Type;
  };

  // 8 lanes -> uint2, the built-in vector type of two 32-bit unsigned words.
  template <>
  struct Vec<uint8_t, 8> {
    typedef uint2 Type;
  };

  }  // namespace aphrodite