// base.h
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  17. #pragma once
  18. namespace marlin_24 {
  // Ceiling division: returns the smallest integer `q` such that `q * b >= a`.
  //
  // Preconditions: `b > 0` (a zero divisor is undefined behaviour, and the
  // rounding direction is only meaningful for a positive divisor).
  //
  // The result is computed as the truncated quotient plus one when a positive
  // remainder is left over. Unlike the textbook `(a + b - 1) / b`, this never
  // forms the intermediate sum `a + b - 1`, so it cannot overflow `int` when
  // `a` is close to INT_MAX. For `a >= 0` both formulations give the same
  // value; for negative `a` the truncated quotient already is the ceiling, so
  // nothing is added.
  constexpr int ceildiv(int a, int b) {
    return a / b + ((a % b > 0) ? 1 : 0);
  }
  20. // Instances of `Vec` are used to organize groups of >>registers<<, as needed
  21. // for instance as inputs to tensor core operations. Consequently, all
  22. // corresponding index accesses must be compile-time constants, which is why we
  23. // extensively use `#pragma unroll` throughout the kernel code to guarantee
  24. // this.
  25. template <typename T, int n>
  26. struct Vec {
  27. T elems[n];
  28. __device__ T& operator[](int i) { return elems[i]; }
  29. };
  // Compile-time carrier for a triple of integer dimensions.
  //
  // The template arguments `M_`, `N_` and `K_` are re-exported unchanged as the
  // static constants `M`, `N` and `K`, so a shape can be passed around as a
  // type and queried with `Shape::M`, `Shape::N`, `Shape::K`.
  template <int M_, int N_, int K_>
  struct ShapeBase {
    static constexpr int M = M_;
    static constexpr int N = N_;
    static constexpr int K = K_;
  };
  34. using I4 = Vec<int, 4>;
  35. // Matrix fragments for tensor core instructions; their precise layout is
  36. // documented here:
  37. // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  38. using FragA = Vec<half2, 4>;
  39. using FragB = Vec<half2, 2>;
  40. using FragM = Vec<unsigned int, 1>;
  41. using FragC = Vec<float, 4>;
  42. using FragS = Vec<half2, 1>; // quantization scales
  43. } // namespace marlin_24