// hip_compat.cuh
  1. // Adapted from turboderp exllama: https://github.com/turboderp/exllama
  2. #ifndef _hip_compat_cuh
  3. #define _hip_compat_cuh
  4. // Workaround for a bug in hipamd, backported from upstream.
  5. __device__ __forceinline__ __half __compat_hrcp(__half x) {
  6. return __half_raw{
  7. static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
  8. }
  9. __device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
  10. return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
  11. static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
  12. }
  13. #define hrcp __compat_hrcp
  14. #define h2rcp __compat_h2rcp
  15. // Workaround for hipify_python using rocblas instead of hipblas.
  16. __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
  17. hipblasOperation_t transA,
  18. hipblasOperation_t transB,
  19. int m,
  20. int n,
  21. int k,
  22. const half* alpha,
  23. const half* AP,
  24. int lda,
  25. const half* BP,
  26. int ldb,
  27. const half* beta,
  28. half* CP,
  29. int ldc) {
  30. return hipblasHgemm(handle, transA, transB, m, n, k,
  31. reinterpret_cast<const hipblasHalf *>(alpha),
  32. reinterpret_cast<const hipblasHalf *>(AP), lda,
  33. reinterpret_cast<const hipblasHalf *>(BP), ldb,
  34. reinterpret_cast<const hipblasHalf *>(beta),
  35. reinterpret_cast<hipblasHalf *>(CP), ldc);
  36. }
  37. #define rocblas_handle hipblasHandle_t
  38. #define rocblas_operation_none HIPBLAS_OP_N
  39. #define rocblas_get_stream hipblasGetStream
  40. #define rocblas_set_stream hipblasSetStream
  41. #define rocblas_hgemm __compat_hipblasHgemm
  42. #endif