// hip_float8.h (3.9 KB)
  1. #pragma once
  2. #ifdef __HIPCC__
  3. #include <hip/hip_runtime.h>
  4. #else
  5. #include <type_traits>
  6. #include <stdint.h>
  7. #include <math.h>
  8. #include <iostream>
  9. #endif
  10. #include "hip_float8_impl.h"
  11. struct alignas(1) hip_fp8
  12. {
  13. struct from_bits_t
  14. {
  15. };
  16. HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
  17. uint8_t data;
  18. hip_fp8() = default;
  19. HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
  20. HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
  21. explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
  22. : data(v)
  23. {
  24. }
  25. #ifdef __HIP__MI300__
  26. // NOTE: ON-DEVICE... always optimal bias
  27. explicit HIP_FP8_DEVICE hip_fp8(float v)
  28. : data(hip_fp8_impl::to_fp8_from_fp32(v))
  29. {
  30. }
  31. explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
  32. : hip_fp8(static_cast<float>(v))
  33. {
  34. }
  35. // Host only implementation using s/w simulation
  36. explicit HIP_FP8_HOST
  37. #else // __HIP__MI300__
  38. // both Host and DEVICE for non-MI300 using s/w simulation
  39. explicit HIP_FP8_HOST_DEVICE
  40. #endif // __HIP__MI300__
  41. hip_fp8(float v)
  42. {
  43. data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
  44. }
  45. explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
  46. : hip_fp8(static_cast<float>(v))
  47. {
  48. }
  49. #ifdef __HIP__MI300__
  50. // upcast using device specific intrinsic
  51. explicit inline HIP_FP8_DEVICE operator float() const
  52. {
  53. float fval;
  54. uint32_t i32val = static_cast<uint32_t>(data);
  55. // upcast
  56. asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
  57. return fval;
  58. }
  59. explicit inline HIP_FP8_HOST operator float() const
  60. #else // __HIP__MI300__
  61. explicit inline HIP_FP8_HOST_DEVICE operator float() const
  62. #endif // __HIP__MI300__
  63. {
  64. return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
  65. }
  66. };
  67. namespace std
  68. {
  69. inline hip_fp8 sin(hip_fp8 a)
  70. {
  71. return hip_fp8(sinf(float(a)));
  72. }
  73. inline hip_fp8 cos(hip_fp8 a)
  74. {
  75. return hip_fp8(cosf(float(a)));
  76. }
  77. HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
  78. {
  79. return a;
  80. }
  81. } // namespace std
  82. // Special operator overloading
  83. inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
  84. {
  85. return os << float(f8);
  86. }
// operator+ overloads, including mixed float / hip_fp8 operands.
// Mixed-type overloads convert to f32, compute in f32, and return float;
// the hip_fp8 + hip_fp8 overload converts the f32 sum back to hip_fp8.
  89. inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
  90. {
  91. return (fa + float(b));
  92. }
  93. inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
  94. {
  95. return (float(a) + fb);
  96. }
  97. inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
  98. {
  99. return hip_fp8(float(a) + float(b));
  100. }
  101. inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
  102. {
  103. return a = hip_fp8(float(a) + float(b));
  104. }
// operator* overloads; every overload returns float.
  106. inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
  107. {
  108. return float(a) * float(b);
  109. }
  110. inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
  111. {
  112. return (a * float(b));
  113. }
  114. inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
  115. {
  116. return (float(a) * b);
  117. }
  118. inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
  119. {
  120. return ((float)a * float(b));
  121. }
  122. inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
  123. {
  124. return ((float)a * float(b));
  125. }
  126. // overloading for compare
  127. inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
  128. {
  129. return (a.data == b.data);
  130. }
  131. inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
  132. {
  133. return (a.data != b.data);
  134. }
  135. inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
  136. {
  137. return static_cast<float>(a) >= static_cast<float>(b);
  138. }
  139. inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
  140. {
  141. return static_cast<float>(a) > static_cast<float>(b);
  142. }