torch_utils.h 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #pragma once
  2. #include "torch/csrc/cuda/Stream.h"
  3. #include "torch/extension.h"
  4. #include <ATen/cuda/CUDAContext.h>
  5. #include <cstdio>
  6. #include <cuda_fp16.h>
  7. #include <cuda_runtime.h>
  8. #include <iostream>
  9. #include <nvToolsExt.h>
  10. #include <torch/custom_class.h>
  11. #include <torch/script.h>
  12. #include <vector>
  13. #define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
  14. #define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
  15. #define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
  16. #define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x)
  17. #define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
  18. #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
  19. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  20. #define CHECK_INPUT(x, st) \
  21. CHECK_TH_CUDA(x); \
  22. CHECK_CONTIGUOUS(x); \
  23. CHECK_TYPE(x, st)
  24. #define CHECK_CPU_INPUT(x, st) \
  25. CHECK_CPU(x); \
  26. CHECK_CONTIGUOUS(x); \
  27. CHECK_TYPE(x, st)
  28. #define CHECK_OPTIONAL_INPUT(x, st) \
  29. if (x.has_value()) { \
  30. CHECK_INPUT(x.value(), st); \
  31. }
  32. #define CHECK_OPTIONAL_CPU_INPUT(x, st) \
  33. if (x.has_value()) { \
  34. CHECK_CPU_INPUT(x.value(), st); \
  35. }
  36. #define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl
  37. #define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl
  38. namespace fastertransformer {
  39. template<typename T>
  40. inline T* get_ptr(torch::Tensor& t)
  41. {
  42. return reinterpret_cast<T*>(t.data_ptr());
  43. }
  44. std::vector<size_t> convert_shape(torch::Tensor tensor);
  45. size_t sizeBytes(torch::Tensor tensor);
  46. QuantType get_ft_quant_type(torch::ScalarType quant_type)
  47. {
  48. if (quant_type == torch::kInt8) {
  49. return QuantType::INT8_WEIGHT_ONLY;
  50. }
  51. else if (quant_type == at::ScalarType::QUInt4x2) {
  52. return QuantType::PACKED_INT4_WEIGHT_ONLY;
  53. }
  54. else {
  55. TORCH_CHECK(false, "Invalid quantization type");
  56. }
  57. }
  58. } // namespace fastertransformer