1
0

cta_iterator.h 19 KB


  1. /*
  2. * Adapted from https://github.com/InternLM/lmdeploy
  3. * Copyright (c) OpenMMLab. All rights reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #pragma once
  18. #include <cstdint>
  19. #include "common.h"
  20. namespace aphrodite {
  21. namespace autoquant {
  22. #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
  23. #define L2_CACHEHINT(size) ".L2::" #size "B"
  24. #else
  25. #define L2_CACHEHINT(size)
  26. #endif
  27. template<typename T>
  28. __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
  29. {
  30. #if APHRODITE_ARCH_SM80
  31. constexpr int cp_size = sizeof(T);
  32. static_assert(cp_size == 16, "cp.async.cg requreis cp_size == 16");
  33. // clang-format off
  34. asm volatile("{\n"
  35. " .reg .pred p;\n"
  36. " setp.ne.b32 p, %0, 0;\n"
  37. " @p cp.async.cg.shared.global" L2_CACHEHINT(256) " [%1], [%2], %3;\n"
  38. "}\n" ::"r"((int)mask),
  39. "r"(smem_int_ptr),
  40. "l"(src),
  41. "n"(cp_size));
  42. // clang-format on
  43. #else
  44. assert(APHRODITE_ARCH_SM80);
  45. #endif
  46. }
  47. template<typename T>
  48. __inline__ __device__ void cp_async_cg_B(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
  49. {
  50. #if APHRODITE_ARCH_SM80
  51. constexpr int cp_size = sizeof(T);
  52. static_assert(cp_size == 16, "cp.async.cg requreis cp_size == 16");
  53. // clang-format off
  54. asm volatile("{\n"
  55. " .reg .pred p;\n"
  56. " setp.ne.b32 p, %0, 0;\n"
  57. " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
  58. "}\n" ::"r"((int)mask),
  59. "r"(smem_int_ptr),
  60. "l"(src),
  61. "n"(cp_size));
  62. // clang-format on
  63. #else
  64. assert(APHRODITE_ARCH_SM80);
  65. #endif
  66. }
  67. template<typename T>
  68. __inline__ __device__ void cp_async_ca(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
  69. {
  70. #if APHRODITE_ARCH_SM80
  71. constexpr int cp_size = sizeof(T);
  72. // clang-format off
  73. asm volatile("{\n"
  74. " .reg .pred p;\n"
  75. " setp.ne.b32 p, %0, 0;\n"
  76. " @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
  77. "}\n" ::"r"((int)mask),
  78. "r"(smem_int_ptr),
  79. "l"(src),
  80. "n"(cp_size));
  81. // clang-format on
  82. #else
  83. assert(APHRODITE_ARCH_SM80);
  84. #endif
  85. }
/// Cooperative global -> shared memory loader for the A operand.
///
/// A is expected to be pre-formatted in global memory as
/// [K/32, M/32, WARP_SIZE] uint4 (see the layout comment below). Consequently:
///   * one source "row" is one 32-deep k-tile holding `m` uint4 values (the row
///     pitch used in the constructor is m_);
///   * every lane copies exactly one uint4 per access.
///
/// Each prefetch_stage() call works in two steps:
///   1. It copies SLICE_K/32 k-tiles x CTA_M uint4 into the current stage of a
///      STAGES-deep shared-memory ring buffer, using cp_async_cg_A.
///   2. It then advances the source by CTA_K/32 k-tiles.
/// NOTE(review): with SLICES > 1, presumably each slice builds its own iterator
/// with a slice-specific cta_k and smem pointer. Confirm against the caller.
///
/// Units: src_* offsets and steps count uint4 elements. dst_* offsets and steps
/// are converted to bytes at the end of the constructor, because they are
/// added to the 32-bit shared-memory address smem_int_ptr_.
///
/// Template parameters: WARPS warps share the copy (warp_id is split into an
/// M and a K coordinate). CTA_M/CTA_N/CTA_K give the CTA tile (CTA_N is not
/// used here). STAGES is the number of ring-buffer stages. SLICES splits CTA_K
/// into slices of SLICE_K.
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES>
struct IteratorA {
    static constexpr int SLICE_K = CTA_K / SLICES;
    using AccessType = uint4;
    static constexpr int kAccessSize = sizeof(AccessType);
    static_assert(CTA_M % 32 == 0 && CTA_K % 32 == 0, "A is pre-formatted as 32x32 tiles");
    // A is [K/32, M/32, WARP_SIZE] uint4
    // Per-stage tile, in uint4: kShapeK k-tiles, each holding kShapeM uint4
    // (CTA_M/32 tiles x 32 lanes == CTA_M).
    static constexpr int kShapeM = CTA_M;
    static constexpr int kShapeK = SLICE_K / 32;
    // thread access shape
    static constexpr int kAccessM = 1;
    static constexpr int kAccessK = 1;
    // warp thread arrangement
    static constexpr int kWarpThreadC = 32;
    static constexpr int kWarpThreadS = 1;
    // warp shape per access
    static constexpr int kWarpAccessM = kWarpThreadC * kAccessM; // 32
    static constexpr int kWarpAccessK = kWarpThreadS * kAccessK; // 1
    // warp access iterations
    static constexpr int kWarpIterM = kShapeM / kWarpAccessM;
    static constexpr int kWarpIterK = kShapeK / kWarpAccessK;
    // warp arrangement
    // Warps are laid out along M first. Warps left over once M is covered are
    // stacked along K.
    static constexpr int kWarpM = kWarpIterM >= WARPS ? WARPS : kWarpIterM;
    static constexpr int kWarpK = WARPS > kWarpIterM ? (WARPS / kWarpM) : 1;
    // iterations
    // NOTE(review): only kIterCount > 0 is checked below. Two kinds of
    // configuration would misbehave:
    //   * kWarpIterM/kWarpIterK not divisible by kWarpM/kWarpK: some tiles are
    //     never loaded;
    //   * WARPS not a multiple of kWarpM: extra warps get warp_offset_k >=
    //     kWarpK and copy past the stage.
    // Presumably the tile configs exclude both.
    static constexpr int kIterM = kWarpIterM / kWarpM;
    static constexpr int kIterK = kWarpIterK / kWarpK;
    static constexpr int kIterCount = kIterM * kIterK;
    static_assert(kIterCount > 0);
    // warp footprint
    static constexpr int kWarpFootprintM = kWarpAccessM * kIterM;
    static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
    // Shared-memory footprint: kSizePerStage uint4 per stage, STAGES stages (bytes).
    static constexpr int kSizePerStage = kShapeK * kShapeM;
    static constexpr int kSmemByteSize = kAccessSize * STAGES * kSizePerStage;
    const uint* src_;
    AccessType* smem_;       // not read within this struct
    uint32_t smem_int_ptr_;  // shared-memory base address passed to cp.async
    const int m_;            // source row pitch in uint4 (one row per 32-deep k-tile)
    const int k_;            // not read within this struct
    const int warp_id_;
    const int lane_id_;
    int src_offset_;  // current source offset, uint4 units
    int dst_offset_;  // current destination offset, bytes
    int src_step_m_;  // step to the next M access
    int src_step_k_;  // step from the end of an M row to the next k-tile
    int src_step_s_;  // step from the end of a stage to the next CTA_K block
    int dst_step_m_;
    int dst_step_k_;
    int dst_step_s_;
    int iter_m_{0};   // position within the current kIterM run
    // NOTE(review): because m_, k_, warp_id_ and lane_id_ are const without
    // initializers, this defaulted constructor is implicitly deleted.
    IteratorA() = default;
    /// @param src      base of the pre-formatted A tensor (see layout above)
    /// @param smem     shared-memory buffer of at least kSmemByteSize bytes
    /// @param m        source row pitch in uint4 per k-tile
    /// @param k        stored in k_ only
    /// @param cta_m    M offset of this CTA, in the same units as m
    /// @param cta_k    K offset of this CTA in elements (divided by 32 to get the k-tile index)
    /// @param warp_id  warp index within the cooperating group of WARPS warps
    /// @param lane_id  lane index within the warp
    __device__ IteratorA(const uint* src, void* smem, int m, int k, int cta_m, int cta_k, int warp_id, int lane_id):
        src_(src),
        smem_((AccessType*)smem),
        smem_int_ptr_(cast_smem_ptr_to_uint(smem)),
        m_(m),
        k_(k),
        warp_id_(warp_id),
        lane_id_(lane_id)
    {
        // Position of this warp in the kWarpM x kWarpK warp grid.
        const int warp_offset_m = warp_id_ % kWarpM;
        const int warp_offset_k = warp_id_ / kWarpM;
        // Position of this lane within the warp (all 32 lanes lie along M).
        const int warp_thread_offset_m = lane_id_ % kWarpThreadC;
        const int warp_thread_offset_k = lane_id_ / kWarpThreadC;
        const int cta_thread_offset_m = kWarpFootprintM * warp_offset_m + warp_thread_offset_m * kAccessM;
        const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
        const int src_offset_m = cta_thread_offset_m + cta_m;
        const int src_offset_k = cta_thread_offset_k + cta_k / 32;
        src_offset_ = src_offset_k * m_ + src_offset_m;
        src_step_m_ = kWarpAccessM;
        // Undo the kIterM M-steps and move down kWarpAccessK k-tile rows.
        src_step_k_ = kWarpAccessK * m_ - kIterM * kWarpAccessM;
        // Undo the kIterK K-steps and jump to the next CTA_K block.
        src_step_s_ = CTA_K / 32 * m_ - kIterK * kWarpAccessK * m_;
        const int dst_offset_m = cta_thread_offset_m;
        const int dst_offset_k = cta_thread_offset_k;
        dst_offset_ = dst_offset_k * kShapeM + dst_offset_m;
        dst_step_m_ = kWarpAccessM;
        dst_step_k_ = kWarpAccessK * kShapeM - kIterM * kWarpAccessM;
        dst_step_s_ = SLICE_K / 32 * kShapeM - kIterK * kWarpAccessK * kShapeM;
        // Destination offsets are added to a byte address, so convert to bytes.
        dst_offset_ *= kAccessSize;
        dst_step_m_ *= kAccessSize;
        dst_step_k_ *= kAccessSize;
        dst_step_s_ *= kAccessSize;
    }
    /// Issues all kIterCount copies for one stage, then advances to the next stage.
    __device__ void prefetch_stage(bool mask)
    {
        PRAGMA_UNROLL
        for (int i = 0; i < kIterCount; ++i) {
            prefetch(mask);
            ++(*this);
        }
        next_stage();
    }
    /// Issues the batch_idx-th group of batch_size copies of the current stage,
    /// skipping indices >= kIterCount. It does not call next_stage().
    /// NOTE(review): presumably the caller calls next_stage() after the last batch.
    __device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
    {
        PRAGMA_UNROLL
        for (int i = 0; i < batch_size; ++i) {
            if (batch_idx * batch_size + i < kIterCount) {
                prefetch(mask);
                ++(*this);
            }
        }
    }
    /// Advances one access along M. After kIterM accesses, wraps to the next k-tile row.
    __device__ IteratorA& operator++()
    {
        src_offset_ += src_step_m_;
        dst_offset_ += dst_step_m_;
        ++iter_m_;
        if (iter_m_ < kIterM) {
            return *this;
        }
        iter_m_ = 0;
        src_offset_ += src_step_k_;
        dst_offset_ += dst_step_k_;
        return *this;
    }
    /// Moves the source to the next CTA_K block and the destination to the next
    /// ring-buffer stage, wrapping back to stage 0 after the last one.
    __device__ void next_stage()
    {
        src_offset_ += src_step_s_;
        dst_offset_ += dst_step_s_;
        if (dst_offset_ >= kSmemByteSize) {
            dst_offset_ -= kSmemByteSize;
        }
    }
    /// Issues one 16-byte async copy for the current position (skipped when mask is false).
    __device__ void prefetch(bool mask)
    {
        cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
    }
};
  214. template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES, int GROUP_SIZE, typename T_Q>
  215. struct IteratorQ {
  216. static constexpr int SLICE_K = CTA_K / SLICES;
  217. using AccessType = uint;
  218. static constexpr int kAccessSize = sizeof(AccessType);
  219. static constexpr int kAccessM = kAccessSize / sizeof(T_Q);
  220. static constexpr int kAccessK = GROUP_SIZE;
  221. // warp thread arrangement
  222. static constexpr int kWarpThreadC = 32;
  223. static constexpr int kWarpThreadS = 1;
  224. // warp shape per access
  225. static constexpr int kWarpAccessM = kWarpThreadC * kAccessM; // 32
  226. static constexpr int kWarpAccessK = kWarpThreadS * kAccessK; // GROUP_SIZE
  227. // warp access iterations
  228. static constexpr int kWarpIterM = CTA_M / kWarpAccessM; // CTA_M / 32
  229. static constexpr int kWarpIterK = SLICE_K / kWarpAccessK; // SLICE_K / GROUP_SIZE, maybe 0
  230. // kWarpIterK == 0 => SLICE_K < kWarpAccessK => kIterK == 1
  231. // warp arrangement
  232. static constexpr int kWarpM = kWarpIterM >= WARPS ? WARPS : kWarpIterM;
  233. static constexpr int kWarpK = WARPS > kWarpIterM ? WARPS / kWarpM : 1;
  234. // iterations
  235. static constexpr int kIterM = kWarpIterM / kWarpM;
  236. static constexpr int kIterK = kWarpIterK >= kWarpK ? kWarpIterK / kWarpK : 1;
  237. static constexpr int kIterCount = kIterM * kIterK;
  238. // warp footprint
  239. static constexpr int kWarpFootprintM = kWarpAccessM * kIterM;
  240. static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
  241. static constexpr int kSizePerStage = std::max(SLICE_K / GROUP_SIZE, 1) * CTA_M;
  242. static constexpr int kSmemByteSize = sizeof(uint) * STAGES * kSizePerStage;
  243. const T_Q* const src_;
  244. T_Q* const smem_;
  245. uint32_t const smem_int_ptr_;
  246. const int m_;
  247. const int k_;
  248. bool is_out_of_bound_; // mask for out-of-bound warps
  249. int src_offset_k_;
  250. int src_offset_m_;
  251. int src_offset_;
  252. int src_step_m_;
  253. int src_step_k_;
  254. int dst_offset_;
  255. int dst_step_m_;
  256. int dst_step_k_;
  257. int tmp_src_offset_;
  258. int tmp_dst_offset_;
  259. int iter_m_{0};
  260. struct Storage {
  261. T_Q data[SLICES][STAGES * kSizePerStage];
  262. };
  263. IteratorQ() = default;
  264. __device__ IteratorQ(const T_Q* src, T_Q* smem, int m, int k, int cta_m, int cta_k, int warp_id, int lane_id):
  265. src_(src), smem_(smem), smem_int_ptr_(cast_smem_ptr_to_uint(smem)), m_(m), k_(k)
  266. {
  267. const int warp_offset_m = warp_id % kWarpM;
  268. const int warp_offset_k = warp_id / kWarpM;
  269. const int warp_thread_offset_m = lane_id % kWarpThreadC;
  270. const int warp_thread_offset_k = lane_id / kWarpThreadC;
  271. const int cta_thread_offset_m = kWarpFootprintM * warp_offset_m + warp_thread_offset_m * kAccessM;
  272. const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
  273. // mask out-of-bound warps
  274. is_out_of_bound_ = cta_thread_offset_k >= SLICE_K;
  275. src_offset_m_ = cta_thread_offset_m + cta_m;
  276. src_offset_k_ = cta_thread_offset_k + cta_k;
  277. src_offset_ = src_offset_k_ / GROUP_SIZE * m_ + src_offset_m_;
  278. src_step_m_ = kWarpAccessM;
  279. src_step_k_ = m_ - kIterM * kWarpAccessM; // valid only when SLICE_K >= GROUP_SIZE
  280. const int dst_offset_m = cta_thread_offset_m;
  281. const int dst_offset_k = cta_thread_offset_k;
  282. dst_offset_ = dst_offset_k / GROUP_SIZE * CTA_M + dst_offset_m;
  283. dst_step_m_ = kWarpAccessM;
  284. dst_step_k_ = CTA_M - kIterM * kWarpAccessM; // valid only when SLICE_K >= GROUP_SIZE
  285. dst_offset_ *= kAccessSize;
  286. dst_step_m_ *= kAccessSize;
  287. dst_step_k_ *= kAccessSize;
  288. tmp_src_offset_ = src_offset_;
  289. tmp_dst_offset_ = dst_offset_;
  290. }
  291. __device__ void prefetch_stage(bool mask)
  292. {
  293. if (is_out_of_bound_) {
  294. return;
  295. }
  296. PRAGMA_UNROLL
  297. for (int i = 0; i < kIterCount; ++i) {
  298. prefetch(mask);
  299. ++(*this);
  300. }
  301. next_stage();
  302. }
  303. __device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
  304. {
  305. if (is_out_of_bound_) {
  306. return;
  307. }
  308. PRAGMA_UNROLL
  309. for (int i = 0; i < batch_size; ++i) {
  310. if (batch_idx * batch_size + i < kIterCount) {
  311. prefetch(mask);
  312. ++(*this);
  313. }
  314. }
  315. }
  316. __device__ IteratorQ& operator++()
  317. {
  318. ++iter_m_;
  319. src_offset_ += src_step_m_;
  320. dst_offset_ += dst_step_m_;
  321. if (iter_m_ < kIterM) {
  322. return *this;
  323. }
  324. iter_m_ = 0;
  325. if constexpr (SLICE_K >= GROUP_SIZE) {
  326. src_offset_ += src_step_k_;
  327. dst_offset_ += dst_step_k_;
  328. }
  329. // else advnace offsets in `next_stage`
  330. return *this;
  331. }
  332. __device__ void next_stage()
  333. {
  334. if constexpr (SLICE_K >= GROUP_SIZE) {
  335. src_offset_ += (CTA_K / GROUP_SIZE - kIterK) * m_;
  336. dst_offset_ += kAccessSize * (SLICE_K / GROUP_SIZE - kIterK) * CTA_M;
  337. }
  338. else { // SLICE_K < GROUP_SIZE, recompute `src_offset_`
  339. src_offset_k_ += CTA_K;
  340. src_offset_ = (src_offset_k_ / GROUP_SIZE) * m_ + src_offset_m_;
  341. dst_offset_ += dst_step_k_;
  342. }
  343. if (dst_offset_ >= kSmemByteSize) {
  344. dst_offset_ -= kSmemByteSize;
  345. }
  346. }
  347. __device__ void prefetch(bool mask)
  348. {
  349. cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
  350. }
  351. };
/// Cooperative global -> shared memory loader for the B operand.
///
/// B is read as a row-major [n, k] tensor of T_BC. The source offset is
/// n * k_ + k in T_BC elements. Each stage is written to shared memory as
/// [CTA_N, SLICE_K + 8] T_BC; the 8-element row padding is there to avoid bank
/// conflicts (see below).
///
/// Thread layout and masking:
///   * Each thread copies kAccessK contiguous K elements (16 bytes) per access
///     with cp_async_cg_B.
///   * Lanes are arranged kWarpThreadC along K by kWarpThreadS along N.
///   * Rows with n >= n_ are predicated off. Their shared-memory contents are
///     left unmodified, not zero-filled.
///   * No masking is applied along K.
///
/// Stages form a STAGES-deep ring buffer. Destination offsets are in bytes;
/// source offsets are in T_BC elements.
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES, typename T_BC>
struct IteratorB {
    static constexpr int SLICE_K = CTA_K / SLICES;
    static constexpr int kElementSize = sizeof(T_BC);
    using AccessType = uint4;
    static constexpr int kAccessSize = sizeof(AccessType);
    static constexpr int kShapeK = SLICE_K;
    static constexpr int kShapeN = CTA_N;
    static constexpr int kAccessK = kAccessSize / sizeof(T_BC);
    // NOTE(review): this compares an element count (kShapeK) with a byte count
    // (kAccessSize). The per-thread access is kAccessK elements, so for T_BC
    // wider than 1 byte this is stricter than the copy itself needs. Confirm the
    // multiple-of-16 requirement is intended before relaxing it to kAccessK.
    static_assert(kShapeK % kAccessSize == 0);
    // warp thread arrangement
    static constexpr int kWarpThreadC = std::max(kShapeK / kAccessK, 1);
    static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
    // warp shape per access
    static constexpr int kWarpAccessK = kWarpThreadC * kAccessK;
    static constexpr int kWarpAccessN = kWarpThreadS;
    // warp access iterations
    static constexpr int kWarpIterK = kShapeK / kWarpAccessK;
    static constexpr int kWarpIterN = kShapeN / kWarpAccessN;
    // warp arrangement
    // Warps are laid out along K first. Leftover warps are stacked along N.
    static constexpr int kWarpK = kWarpIterK >= WARPS ? WARPS : kWarpIterK;
    static constexpr int kWarpN = WARPS > kWarpIterK ? WARPS / kWarpK : 1;
    // iterations
    static constexpr int kIterK = kWarpIterK / kWarpK;
    static constexpr int kIterN = kWarpIterN >= kWarpN ? kWarpIterN / kWarpN : 1;
    static constexpr int kIterCount = kIterK * kIterN;
    static_assert(kIterCount > 0);
    // warp footprint
    static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
    static constexpr int kWarpFootprintN = kWarpAccessN * kIterN;
    // Eliminate bank-conflicts for 8x4 half2 tiles, watch out for misalignment
    static constexpr int kSmemPadCtaK = SLICE_K + 8;
    static constexpr int kSizePerTile = CTA_N * kSmemPadCtaK;
    static constexpr int kSmemByteSize = kElementSize * STAGES * kSizePerTile;
    const T_BC* src_;
    AccessType* const smem_; // [CTA_N, SLICE_K + 8]
    const uint32_t smem_int_ptr_;  // shared-memory base address passed to cp.async
    const int k_;                  // source row pitch in T_BC elements
    const int n_;                  // number of valid source rows; rows >= n_ are masked
    const int cta_n_;
    const int warp_id_;
    const int lane_id_;
    const int c_;                  // lane coordinate along K; not read within this struct
    const int s_;                  // lane coordinate along N; not read within this struct
    int src_offset_n_;             // stage-start N row of this thread
    int src_offset_;               // stage-start source offset (T_BC elements)
    int dst_offset_;               // stage-start destination offset (bytes)
    int src_step_k_;
    int src_step_n_;
    int dst_step_k_;
    int dst_step_n_;
    bool is_valid_n_;              // tmp_src_offset_n_ < n_
    int tmp_src_offset_;           // running offsets within the current stage
    int tmp_dst_offset_;
    int tmp_src_offset_n_;
    int iter_k_{0};
    int iter_n_{0};                // written but not read within this struct
    // NOTE(review): because several members are const without initializers,
    // this defaulted constructor is implicitly deleted.
    IteratorB() = default;
    /// @param src      base of the row-major [n, k] B tensor in global memory
    /// @param smem     shared-memory buffer of at least kSmemByteSize bytes
    /// @param k        source row pitch (K extent) in T_BC elements
    /// @param n        number of valid rows; rows at or beyond n are not copied
    /// @param cta_n    N offset of this CTA
    /// @param cta_k    K offset of this CTA in T_BC elements
    /// @param warp_id  warp index within the cooperating group of WARPS warps
    /// @param lane_id  lane index within the warp
    __device__ IteratorB(const T_BC* src, void* smem, int k, int n, int cta_n, int cta_k, int warp_id, int lane_id):
        src_(src),
        smem_((AccessType*)smem),
        smem_int_ptr_(cast_smem_ptr_to_uint(smem)),
        k_(k),
        n_(n),
        cta_n_(cta_n),
        warp_id_(warp_id),
        lane_id_(lane_id),
        c_(lane_id_ % kWarpThreadC),
        s_(lane_id_ / kWarpThreadC)
    {
        const int warp_offset_k = warp_id_ % kWarpK;
        const int warp_offset_n = warp_id_ / kWarpK;
        const int warp_thread_offset_k = lane_id_ % kWarpThreadC;
        const int warp_thread_offset_n = lane_id_ / kWarpThreadC;
        const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
        const int cta_thread_offset_n = kWarpFootprintN * warp_offset_n + warp_thread_offset_n;
        const int src_offset_k = cta_thread_offset_k + cta_k;
        src_offset_n_ = cta_thread_offset_n + cta_n_;
        src_offset_ = src_offset_n_ * k_ + src_offset_k;
        const int dst_offset_k = cta_thread_offset_k;
        const int dst_offset_n = cta_thread_offset_n;
        dst_offset_ = dst_offset_n * kSmemPadCtaK + dst_offset_k;
        src_step_k_ = kWarpAccessK;
        // Undo the kIterK K-steps and move down kWarpAccessN rows.
        src_step_n_ = kWarpAccessN * k_ - kIterK * kWarpAccessK;
        dst_step_k_ = kWarpAccessK;
        dst_step_n_ = kWarpAccessN * kSmemPadCtaK - kIterK * kWarpAccessK;
        // Destination offsets are added to a byte address, so convert to bytes.
        dst_offset_ *= kElementSize;
        dst_step_k_ *= kElementSize;
        dst_step_n_ *= kElementSize;
        tmp_src_offset_ = src_offset_;
        tmp_dst_offset_ = dst_offset_;
        tmp_src_offset_n_ = src_offset_n_;
        is_valid_n_ = tmp_src_offset_n_ < n_;
    }
    /// Issues all kIterCount copies for one stage, then advances to the next stage.
    __device__ void prefetch_stage(bool mask)
    {
        PRAGMA_UNROLL
        for (int i = 0; i < kIterCount; ++i) {
            prefetch(mask);
            ++(*this);
        }
        next_stage();
    }
    /// Issues the batch_idx-th group of batch_size copies of the current stage,
    /// skipping indices >= kIterCount. It does not call next_stage().
    /// NOTE(review): presumably the caller calls next_stage() after the last batch.
    __device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
    {
        PRAGMA_UNROLL
        for (int i = 0; i < batch_size; ++i) {
            if (batch_idx * batch_size + i < kIterCount) {
                prefetch(mask);
                ++(*this);
            }
        }
    }
    /// Advances one access along K. After kIterK accesses, moves down
    /// kWarpAccessN rows and refreshes is_valid_n_.
    /// Once a row is invalid, every later row in the stage is also invalid,
    /// because N only increases. The early return is therefore final for the
    /// current stage.
    __device__ IteratorB& operator++()
    {
        if (!is_valid_n_) {
            return *this;
        }
        // move to next k
        tmp_src_offset_ += src_step_k_;
        tmp_dst_offset_ += dst_step_k_;
        ++iter_k_;
        if (iter_k_ < kIterK) {
            return *this;
        }
        // move to next n
        iter_k_ = 0;
        tmp_src_offset_n_ += kWarpAccessN;
        tmp_src_offset_ += src_step_n_;
        tmp_dst_offset_ += dst_step_n_;
        is_valid_n_ = tmp_src_offset_n_ < n_;
        ++iter_n_;
        return *this;
    }
    /// Advances the stage-start offsets to the next CTA_K block and the next
    /// ring-buffer stage (wrapping after the last one), then resets the running
    /// offsets and is_valid_n_ to the stage start.
    __device__ void next_stage()
    {
        iter_n_ = 0;
        src_offset_ += CTA_K;
        dst_offset_ += kElementSize * kSizePerTile;
        if (dst_offset_ >= kSmemByteSize) {
            dst_offset_ -= kSmemByteSize;
        }
        tmp_src_offset_ = src_offset_;
        tmp_dst_offset_ = dst_offset_;
        tmp_src_offset_n_ = src_offset_n_;
        is_valid_n_ = tmp_src_offset_n_ < n_;
    }
    /// Issues one 16-byte async copy for the current position. The copy is
    /// skipped when the row is out of range or mask is false.
    __device__ void prefetch(bool mask)
    {
        cp_async_cg_B(
            smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
    }
};
  505. } // namespace autoquant
  506. } // namespace aphrodite