autogptq_cuda_kernel_64.cu

#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// atomicAdd for double-precision floating-point numbers on hardware with
// compute capability < 6.0 from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// __device__ double atomicAdd(
//     double* address,
//     double val
// ) {
//   unsigned long long int* address_as_ull = (unsigned long long int*)address;
//   unsigned long long int old = *address_as_ull, assumed;
//
//   do {
//     assumed = old;
//     old = atomicCAS(
//       address_as_ull,
//       assumed,
//       __double_as_longlong(val + __longlong_as_double(assumed))
//     );
//
//     // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
//   } while (assumed != old);
//
//   return __longlong_as_double(old);
// }
// #endif
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum += val;
    old = reinterpret_cast<size_t>(address) & 2
            ? (old & 0xffff) | (hsum << 16)
            : (old & 0xffff0000) | hsum;
    old = atomicCAS(address_as_ui, assumed, old);

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}

__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) {
  unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    __half_raw hsum;
    hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
    half tmpres = __hadd(hsum, val);
    hsum = __half_raw(tmpres);
    old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}
#endif
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
);

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel_old(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel_old(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel_old(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel_old(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

__global__ void VecQuant2MatMulKernelFaster_old(
  const half2* __restrict__ vec,
  const int* __restrict__ mat,
  float* __restrict__ mul,
  const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

__global__ void VecQuant3MatMulKernelFaster_old(
  const half2* __restrict__ vec,
  const int* __restrict__ mat,
  float* __restrict__ mul,
  const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

__global__ void VecQuant4MatMulKernelFaster_old(
  const half2* __restrict__ vec,
  const int* __restrict__ mat,
  float* __restrict__ mul,
  const float* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);
// Each thread block covers BLOCKWIDTH output columns (one per thread).
// BLOCKHEIGHTn is the number of packed 32-bit rows that hold BLOCKWIDTH
// n-bit weights, i.e. BLOCKWIDTH * n / 32.
const int BLOCKWIDTH   = 64;
const int BLOCKHEIGHT2 =  4;
const int BLOCKHEIGHT3 =  6;
const int BLOCKHEIGHT4 =  8;
const int BLOCKHEIGHT8 = 16;

// Reinterpret the bits of a signed int without value conversion.
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

__device__ inline int as_int(int i) {
  return *reinterpret_cast<int*>(&i);
}
void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 16;          // each packed 32-bit row holds 16 2-bit weights
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 16;          // word holding the 2-bit zero point for column w
  int z_mod = (w % 16) * 2;  // bit offset of that zero point within the word

  float weight[BLOCKWIDTH];

  // Dequantize this thread's column slice once, then reuse it for every batch row.
  for (k = 0; k < BLOCKWIDTH; ++k) {
    int k_w = (k / 16);
    int k_bit = (k % 16) * 2;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}
void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = (h / 3) * 32;  // every 3 packed 32-bit rows hold 32 3-bit weights
  int k;
  unsigned int g;
  scalar_t w_tmp;

  // Locate the 3-bit zero point for column w. Within each group of 32 values
  // packed into 3 words, values 10 and 21 straddle a word boundary and are
  // handled as special cases below.
  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;
  unsigned int z_tmp;
  if (z_mod != 10) {
    if (z_mod != 21) {
      z_bit = z_mod;
      if (z_bit > 21) {
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10) {
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  float weight[BLOCKWIDTH];

  for (k = 0; k < BLOCKWIDTH; ++k) {
    // Same packing layout for the weight word and bit offset.
    int k_w = (k / 32) * 3;
    int k_mod = k % 32;
    int k_bit;

    if (k_mod != 10) {
      if (k_mod != 21) {
        k_bit = k_mod;
        if (k_bit > 21) {
          k_bit -= 22;
          k_bit *= 3;
          k_bit += 2;
          k_w += 2;
        } else if (k_bit > 10) {
          k_bit -= 11;
          k_bit *= 3;
          k_bit += 1;
          k_w += 1;
        } else {
          k_bit *= 3;
        }
      } else {
        k_w += 1;
      }
    }

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero;
    if (z_mod == 10) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = scalar_t((z_tmp) + 1);
    } else if (z_mod == 21) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = scalar_t((z_tmp) + 1);
    } else {
      zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    if (k_mod == 10) {
      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1) * width)]) << 2) & 0x4);
    } else if (k_mod == 21) {
      w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1) * width)]) << 1) & 0x6);
    } else {
      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7);
    }
    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}
void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 8;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  float weight[BLOCKWIDTH];

  for (k = 0; k < BLOCKWIDTH; ++k) {
    int k_w = (k / 8);
    int k_bit = (k % 8) * 4;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}
void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 4;
  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  float weight[BLOCKWIDTH];

  for (k = 0; k < BLOCKWIDTH; ++k) {
    int k_w = (k / 4);
    int k_bit = (k % 4) * 8;

    g = as_int(g_idx[g_h + k]);
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
    weight[k] = scale * (w_tmp - zero);
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    __syncthreads();
    for (k = 0; k < BLOCKWIDTH; ++k) {
      res += weight[k] * blockvec[k];
    }
    atomicAdd(&mul[b * width + w], res);
    __syncthreads();
  }
}