1
0

linear.py 20 KB


  1. # Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
  2. # and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
  3. from typing import Optional
  4. import torch
  5. import triton
  6. import triton.language as tl
  7. from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
  8. from flash_attn.ops.triton.k_activations import (
  9. gelu,
  10. gelu_approx,
  11. gelu_approx_grad,
  12. gelu_grad,
  13. squared_relu,
  14. squared_relu_grad,
  15. )
  16. # CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
  def init_to_zero(name):
      """Build a hook that clears one kernel argument in place.

      The returned callable takes the kernel argument mapping ``nargs``, calls
      ``zero_()`` on ``nargs[name]`` and returns whatever that call returns
      (for a torch tensor: the same, now zeroed, tensor). It is meant to be
      used as a ``triton.Config`` ``pre_hook``.
      """

      def _zero_hook(nargs):
          target = nargs[name]
          return target.zero_()

      return _zero_hook
  def get_configs_io_bound():
      """Return the small-tile ``triton.Config`` candidates aimed at IO-bound shapes.

      Every combination is generated, with the loops nested in this order
      (outermost first): ``num_stages`` in 2..6, ``BLOCK_M`` in {16, 32},
      ``BLOCK_K`` in {32, 64}, ``BLOCK_N`` in {32, 64, 128, 256}.
      ``SPLIT_K`` is always 1. Tiles with ``BLOCK_N <= 64`` use 2 warps and
      wider tiles use 4.
      """
      # Split-K variants (SPLIT_K in {2, 4, 8, 16}, which would need an
      # init_to_zero("C") pre-hook) are deliberately not generated.
      return [
          triton.Config(
              {
                  "BLOCK_M": block_m,
                  "BLOCK_N": block_n,
                  "BLOCK_K": block_k,
                  "SPLIT_K": 1,
              },
              num_stages=stages,
              num_warps=2 if block_n <= 64 else 4,
          )
          for stages in (2, 3, 4, 5, 6)
          for block_m in (16, 32)
          for block_k in (32, 64)
          for block_n in (32, 64, 128, 256)
      ]
  @triton.autotune(
      configs=[
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
          # good for int8
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
      ]
      + get_configs_io_bound(),
      key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
      prune_configs_by={
          "early_config_prune": early_config_prune,
          "perf_model": estimate_matmul_time,
          "top_k": 10,
      },
  )
  @triton.heuristics(
      {
          "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
      }
  )
  @triton.jit
  def kernel_fwd(
      C,  # Pointers to matrices
      ACT_INPUT,
      A,
      B,
      bias,
      # Matrix dimensions
      M,
      N,
      K,
      CACHE_KEY_M,
      CACHE_KEY_N,
      CACHE_KEY_K,
      # The stride variables represent how much to increase the ptr by when moving by 1
      # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
      # by to get the element one row down (A has M rows)
      stride_cm,
      # stride_cn,  # Assume that stride_cn == 1
      stride_am,
      stride_ak,
      stride_bn,
      stride_bk,
      # Meta-parameters
      BLOCK_M: tl.constexpr,
      GROUP_M: tl.constexpr,
      BLOCK_N: tl.constexpr,
      BLOCK_K: tl.constexpr,
      # split k not used, not performant with activation, kept because early_config_prune is expecting it
      SPLIT_K: tl.constexpr,
      EVEN_K: tl.constexpr,
      A_ROWMAJOR: tl.constexpr,
      B_COLMAJOR: tl.constexpr,
      BIAS: tl.constexpr,
      SAVE_ACT_INPUT: tl.constexpr,
      ACTIVATION: tl.constexpr,
  ):
      """
      Kernel for computing Out = activation(A @ B + bias).
      - A (input) has shape (M, K)
      - B (weight, as addressed through stride_bk / stride_bn) has shape (K, N)
      - bias (read only when BIAS) has shape (N,)
      - Output C has shape (M, N) and is assumed to have unit column stride
      - ACT_INPUT (written only when SAVE_ACT_INPUT) has shape (M, N); it receives the
        pre-activation values A @ B + bias and is addressed with C's row stride
      Each program computes one (BLOCK_M, BLOCK_N) tile of C, reducing over all of K.
      ACTIVATION is one of "id", "gelu", "gelu_approx", "squared_relu".
      """
      pid = tl.program_id(axis=0)
      grid_m = (M + BLOCK_M - 1) // BLOCK_M
      grid_n = (N + BLOCK_N - 1) // BLOCK_N
      # re-order program ID for better L2 performance: programs are grouped so that
      # GROUP_M consecutive row-blocks are processed before moving to the next column-block
      width = GROUP_M * grid_n
      group_id = pid // width
      group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
      pid_m = group_id * GROUP_M + (pid % group_size)
      pid_n = (pid % width) // (group_size)
      # now compute the block that each program will go through
      # rm (resp. rn) denotes a range of indices
      # for rows (resp. col) of C
      rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
      rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
      # trick to avoid masking on M and N axis for the loads: out-of-range indices are
      # wrapped back into range, so those lanes compute values for the wrapped row/column.
      # Such lanes must never be stored (see the masked stores below).
      ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
      rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
      rk = tl.arange(0, BLOCK_K)
      if A_ROWMAJOR:
          # stride_ak == 1 (checked by the caller), so drop the multiply
          A = A + (ram[:, None] * stride_am + rk[None, :])
      else:
          A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
      if B_COLMAJOR:
          # stride_bk == 1 (checked by the caller), so drop the multiply
          B = B + (rk[:, None] + rbn[None, :] * stride_bn)
      else:
          B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

      acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
      # k is the number of K elements still to be consumed
      for k in range(K, 0, -BLOCK_K):
          if EVEN_K:
              a = tl.load(A)
              b = tl.load(B)
          else:
              a = tl.load(A, mask=rk[None, :] < k, other=0.0)
              b = tl.load(B, mask=rk[:, None] < k, other=0.0)
          acc += tl.dot(a, b)

          if A_ROWMAJOR:
              A += BLOCK_K
          else:
              A += BLOCK_K * stride_ak
          if B_COLMAJOR:
              B += BLOCK_K
          else:
              B += BLOCK_K * stride_bk

      # Putting bias after the matmul (instead of before) is faster, idk why
      if BIAS:
          # Columns rn >= N get 0 here, so their lanes differ from the true value of the
          # wrapped column rbn; this is another reason those lanes must not be stored.
          bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
          acc += bias[None, :]

      # optional: save the activation inputs
      if SAVE_ACT_INPUT:
          # act_in_ptrs = ACT_INPUT + rm[:, None] * stride_cm + rn[None, :] * stride_cn
          act_in_ptrs = ACT_INPUT + rm[:, None] * stride_cm + rn[None, :]
          # Store only the in-bounds part of the tile. Writing through the wrapped
          # indices (ram/rbn) would let out-of-range lanes overwrite valid entries
          # with values that lack the bias.
          act_in_mask = (rm < M)[:, None] & (rn < N)[None, :]
          tl.store(act_in_ptrs, acc, mask=act_in_mask)

      # optional: fused activation applied to the accumulator before the write-back
      if ACTIVATION == "gelu":
          acc = gelu(acc)
      elif ACTIVATION == "gelu_approx":
          acc = gelu_approx(acc)
      elif ACTIVATION == "squared_relu":
          acc = squared_relu(acc)

      # rematerialize rm and rn to save registers
      rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
      rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

      # write back result, masked so that partial edge tiles never write past C
      # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
      C = C + rm[:, None] * stride_cm + rn[None, :]
      mask = (rm < M)[:, None] & (rn < N)[None, :]
      tl.store(C, acc, mask=mask)
  def triton_linear_act(
      x: torch.Tensor,
      weight: torch.Tensor,
      bias: Optional[torch.Tensor] = None,
      activation: str = "id",
      save_act_input: bool = False,
  ) -> torch.Tensor:
      """
      Compute activation(x @ weight.T + bias) with the fused `kernel_fwd` Triton kernel.

      :param x: input tensor of shape (*batch, K); flattened to (batch numel, K) for the kernel.
      :param weight: weight matrix of shape (N, K); it is used transposed.
      :param bias: optional bias with bias.shape[0] == N; must have the same dtype as x.
      :param activation: activation name, one of "id", "gelu", "gelu_approx", "squared_relu".
      :param save_act_input: if True, also return the pre-activation values
          x @ weight.T + bias (e.g. for the backward pass).
      :return: output of shape (*batch, N) with x's dtype. When save_act_input is True,
          a tuple (output, act_input) where act_input also has shape (*batch, N).
      :raises AssertionError: on an unknown activation, dtype mismatch or incompatible shapes.
      """
      # NOTE: autocast is not handled here; x, weight and bias must already share one dtype.
      assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
      batch_shape, n = x.shape[:-1], x.shape[-1]
      batch_dim = batch_shape.numel()
      x_reshaped = x.reshape(batch_dim, n)

      # The kernel indexes both operands through their strides; a copy is made only when
      # neither dimension has unit stride (presumably for memory-access efficiency).
      if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
          x_reshaped = x_reshaped.contiguous()
      if weight.stride(0) > 1 and weight.stride(1) > 1:
          weight = weight.contiguous()
      bias = bias.contiguous() if bias is not None else None

      assert (
          x.dtype == weight.dtype
      ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
      if bias is not None:
          assert (
              x.dtype == bias.dtype
          ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
      assert (
          x_reshaped.shape[1] == weight.shape[1]
      ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

      assert (
          bias is None or bias.shape[0] == weight.shape[0]
      ), "Incompatible dimensions in between weight and bias"

      M, K = x_reshaped.shape
      N, K = weight.shape

      output = torch.empty((M, N), device=x.device, dtype=x.dtype)
      act_input = torch.empty_like(output) if save_act_input else None

      # 1D launch kernel where each block gets its own program.
      grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

      kernel_fwd[grid](
          output,
          act_input,
          x_reshaped,
          weight,  # data ptrs
          bias if bias is not None else x,  # auto skip bias if not present
          M,  # shapes
          N,
          K,
          M // 32,  # key for triton cache (limit number of compilations)
          N // 32,
          K // 32,
          stride_cm=output.stride(0),  # strides
          # stride_cn=output.stride(1),
          stride_am=x_reshaped.stride(0),
          stride_ak=x_reshaped.stride(1),
          # weight is (N, K) and the kernel reads it as B = weight.T of shape (K, N)
          stride_bk=weight.stride(1),
          stride_bn=weight.stride(0),
          BIAS=bias is not None,  # optional fused bias
          SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
          ACTIVATION=activation,  # optional fused activation
          A_ROWMAJOR=x_reshaped.stride(1) == 1,
          B_COLMAJOR=weight.stride(1) == 1,
          GROUP_M=8,  # speed optimization: group the programs
      )

      if not save_act_input:
          return output.reshape(*batch_shape, output.shape[-1])
      else:
          return (
              output.reshape(*batch_shape, output.shape[-1]),
              act_input.reshape(*batch_shape, act_input.shape[-1]),
          )
  @triton.autotune(
      configs=[
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
          # good for int8
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
          triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
          triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
      ]
      + get_configs_io_bound(),
      key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
      prune_configs_by={
          "early_config_prune": early_config_prune,
          "perf_model": estimate_matmul_time,
          "top_k": 10,
      },
  )
  @triton.heuristics(
      {
          "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
      }
  )
  @triton.jit
  def kernel_bwd(
      C,  # Pointers to matrices
      ACT_INPUT,
      A,
      B,
      # Matrix dimensions
      M,
      N,
      K,
      CACHE_KEY_M,
      CACHE_KEY_N,
      CACHE_KEY_K,
      # The stride variables represent how much to increase the ptr by when moving by 1
      # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
      # by to get the element one row down (A has M rows)
      stride_cm,
      # stride_cn,  # Assume that stride_cn == 1
      stride_am,
      stride_ak,
      stride_bk,
      stride_bn,
      # Meta-parameters
      BLOCK_M: tl.constexpr,
      GROUP_M: tl.constexpr,
      BLOCK_N: tl.constexpr,
      BLOCK_K: tl.constexpr,
      # split k not used, not performant with activation, kept because early_config_prune is expecting it
      SPLIT_K: tl.constexpr,
      EVEN_K: tl.constexpr,
      ACTIVATION: tl.constexpr,
  ):
      """
      Kernel for computing Out = (A @ B) * activation_grad(ACT_INPUT).
      When ACTIVATION == "id" no gradient factor is applied and Out = A @ B.
      - A (grad_output) has shape (M, K)
      - B (weight) has shape (K, N)
      - ACT_INPUT (read only when ACTIVATION != "id") has shape (M, N); it is addressed
        with C's row stride (stride_cm) and unit column stride
      - Output C has shape (M, N) and is assumed to have unit column stride
      Each program computes one (BLOCK_M, BLOCK_N) tile of C, reducing over all of K.
      """
      pid = tl.program_id(axis=0)
      grid_m = (M + BLOCK_M - 1) // BLOCK_M
      grid_n = (N + BLOCK_N - 1) // BLOCK_N
      # re-order program ID for better L2 performance: programs are grouped so that
      # GROUP_M consecutive row-blocks are processed before moving to the next column-block
      width = GROUP_M * grid_n
      group_id = pid // width
      group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
      pid_m = group_id * GROUP_M + (pid % group_size)
      pid_n = (pid % width) // (group_size)

      # now compute the block that each program will go through
      # rm (resp. rn) denotes a range of indices
      # for rows (resp. col) of C
      rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
      rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

      # trick to avoid masking on M and N axis for the loads: out-of-range indices are
      # wrapped back into range; the final store is masked so those lanes are discarded.
      ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
      rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
      rk = tl.arange(0, BLOCK_K)
      A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
      B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

      acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
      # k is the number of K elements still to be consumed
      for k in range(K, 0, -BLOCK_K):
          if EVEN_K:
              a = tl.load(A)
              b = tl.load(B)
          else:
              a = tl.load(A, mask=rk[None, :] < k, other=0.0)
              b = tl.load(B, mask=rk[:, None] < k, other=0.0)
          acc += tl.dot(a, b)

          A += BLOCK_K * stride_ak
          B += BLOCK_K * stride_bk

      # optional: multiply by the activation derivative evaluated at the saved inputs
      if ACTIVATION != "id":
          # ACT_INPUT is read with C's row stride, i.e. as a dense (M, N) row-major matrix
          act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
          act_input = tl.load(act_in_ptrs).to(acc.dtype)
          if ACTIVATION == "gelu":
              acc *= gelu_grad(act_input)
          elif ACTIVATION == "gelu_approx":
              acc *= gelu_approx_grad(act_input)
          elif ACTIVATION == "squared_relu":
              acc *= squared_relu_grad(act_input)

      # rematerialize rm and rn to save registers
      rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
      rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

      # write back result, masked so that partial edge tiles never write past C
      C = C + rm[:, None] * stride_cm + rn[None, :]
      mask = (rm < M)[:, None] & (rn < N)[None, :]
      tl.store(C, acc, mask=mask)
  def triton_dgrad_act(
      grad_output: torch.Tensor,
      weight: torch.Tensor,
      activation: str = "id",
      act_input: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """
      Compute (grad_output @ weight) * activation_grad(act_input) with the `kernel_bwd` Triton kernel.

      For activation == "id" this is simply grad_output @ weight.

      :param grad_output: tensor of shape (*batch, K), where K == weight.shape[0].
      :param weight: weight matrix of shape (K, N).
      :param activation: activation name, one of "id", "gelu", "gelu_approx", "squared_relu".
      :param act_input: the saved pre-activation values, with M * N elements (M = batch
          numel), e.g. of shape (*batch, N). Required unless activation == "id".
      :return: tensor of shape (*batch, N) with grad_output's dtype.
      :raises AssertionError: on an unknown activation, dtype mismatch, incompatible shapes,
          or a missing / wrongly sized act_input.
      """
      assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

      batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
      batch_dim = batch_shape.numel()
      grad_output_reshaped = grad_output.reshape(batch_dim, n)

      # The kernel indexes both operands through their strides; a copy is made only when
      # neither dimension has unit stride (presumably for memory-access efficiency).
      if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
          grad_output_reshaped = grad_output_reshaped.contiguous()
      if weight.stride(0) > 1 and weight.stride(1) > 1:
          weight = weight.contiguous()

      assert (
          grad_output.dtype == weight.dtype
      ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
      assert (
          grad_output_reshaped.shape[1] == weight.shape[0]
      ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
      if activation != "id":
          assert act_input is not None, f"act_input is required for activation {activation}"

      # M, N, K in bwd are different from M, N, K in fwd
      M, K = grad_output_reshaped.shape
      K, N = weight.shape

      if activation != "id":
          # kernel_bwd reads ACT_INPUT as a dense (M, N) row-major matrix using
          # grad_input's row stride, so validate its size and normalize its layout.
          assert act_input.numel() == M * N, (
              f"act_input must have {M * N} elements, got shape {tuple(act_input.shape)}"
          )
          act_input = act_input.reshape(M, N).contiguous()

      grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

      # 1D launch kernel where each block gets its own program.
      grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

      kernel_bwd[grid](
          grad_input,
          act_input,
          grad_output_reshaped,
          weight,  # data ptrs
          M,  # shapes
          N,
          K,
          M // 32,  # key for triton cache (limit number of compilations)
          N // 32,
          K // 32,
          stride_cm=grad_input.stride(0),  # strides
          # stride_cn=grad_input.stride(1),
          stride_am=grad_output_reshaped.stride(0),
          stride_ak=grad_output_reshaped.stride(1),
          stride_bk=weight.stride(0),
          stride_bn=weight.stride(1),
          ACTIVATION=activation,  # optional fused activation
          GROUP_M=8,  # speed optimization: group the programs
      )

      return grad_input.reshape(*batch_shape, grad_input.shape[-1])