@@ -0,0 +1,479 @@
+# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
+# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from torch.autograd.function import FunctionCtx
+from torch.cuda.amp import custom_fwd
+from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad
+
+
+# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
+
+
+def init_to_zero(name):
+    return lambda nargs: nargs[name].zero_()
+
+
+def get_configs_io_bound():
+    configs = []
+    for num_stages in [2, 3, 4, 5, 6]:
+        for block_m in [16, 32]:
+            for block_k in [32, 64]:
+                for block_n in [32, 64, 128, 256]:
+                    num_warps = 2 if block_n <= 64 else 4
+                    configs.append(
+                        triton.Config(
+                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
+                            num_stages=num_stages,
+                            num_warps=num_warps,
+                        )
+                    )
+                    # split_k not used
+                    # for split_k in [2, 4, 8, 16]:
+                    #     configs.append(triton.Config(
+                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
+                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
+    return configs
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+        # good for int8
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+    ]
+    + get_configs_io_bound(),
+    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
+    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+)
+@triton.heuristics(
+    {
+        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+    }
+)
+@triton.jit
+def kernel_fwd(
+    C,  # Pointers to matrices
+    ACT_INPUT,
+    A,
+    B,
+    bias,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    CACHE_KEY_M,
+    CACHE_KEY_N,
+    CACHE_KEY_K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
+    # by to get the element one row down (A has M rows)
+    stride_cm,
+    # stride_cn,  # Assume that stride_cn == 1
+    stride_am,
+    stride_ak,
+    stride_bn,
+    stride_bk,
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # split k not used, not performant with activation, kept because early_config_prune is expecting it
+    SPLIT_K: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    A_ROWMAJOR: tl.constexpr,
+    B_COLMAJOR: tl.constexpr,
+    BIAS: tl.constexpr,
+    SAVE_ACT_INPUT: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+
+    """
+    Kernel for computing Out = activation(A x W + bias)
+    - Input has shape (M, K)
+    - Weight has shape (K, N)
+    - Bias has shape (N,)
+    - Output has shape (M, N)
+    - ActInputs (optional) has shape (M, N)
+    'ActInputs' optionally saves the A x W + bias intermediate for backward computations
+    This kernel accumulates over K
+    """
+
+    pid = tl.program_id(axis=0)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+
+    # now compute the block that each program will go through
+    # rm (resp. rn) denotes a range of indices
+    # for rows (resp. col) of C
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    # trick to avoid masking on M and N axis
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
+
+    if A_ROWMAJOR:
+        A = A + (ram[:, None] * stride_am + rk[None, :])
+    else:
+        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    if B_COLMAJOR:
+        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
+    else:
+        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(K, 0, -BLOCK_K):
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
+            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
+        acc += tl.dot(a, b)
+
+        if A_ROWMAJOR:
+            A += BLOCK_K
+        else:
+            A += BLOCK_K * stride_ak
+        if B_COLMAJOR:
+            B += BLOCK_K
+        else:
+            B += BLOCK_K * stride_bk
+
+    # Putting bias after the matmul (instead of before) is faster, idk why
+    if BIAS:
+        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
+        acc += bias[None, :]
+
+    # optional: save the activation inputs
+    if SAVE_ACT_INPUT:
+        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
+        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
+        tl.store(act_in_ptrs, acc)
+
+    # optional: fused activation (while the data is in shared memory)
+    if ACTIVATION == "gelu":
+        acc = gelu(acc)
+    elif ACTIVATION == "gelu_approx":
+        acc = gelu_approx(acc)
+    elif ACTIVATION == "squared_relu":
+        acc = squared_relu(acc)
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # write back result
+    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
+    C = C + rm[:, None] * stride_cm + rn[None, :]
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(C, acc, mask=mask)
+
+
+def triton_linear_act(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    activation: str = 'id',
+    save_act_input: bool = False,
+) -> torch.Tensor:
+    """
+    Compute e = activation(x @ weight.T + bias).
+    This wrapper kicks off the `kernel_fwd` Triton kernel.
+    :param x: input tensor
+    :param weight: weight matrix
+    :param bias: an optional bias tensor
+    :param activation: activation name; one of 'id', 'gelu', 'gelu_approx', 'squared_relu'
+    :param save_act_input: if True, also return the pre-activation (x @ weight.T + bias) for the backward pass
+    :return: result tensor, or (result, act_input) if save_act_input is True
+    """
+    # if torch.is_autocast_enabled():
+    #     dtype = torch.get_autocast_gpu_dtype()
+    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
+
+    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
+
+    batch_shape, n = x.shape[:-1], x.shape[-1]
+    batch_dim = batch_shape.numel()
+    x_reshaped = x.reshape(batch_dim, n)
+
+    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
+        x_reshaped = x_reshaped.contiguous()
+    if weight.stride(0) > 1 and weight.stride(1) > 1:
+        weight = weight.contiguous()
+    bias = bias.contiguous() if bias is not None else None
+
+    assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
+    if bias is not None:
+        assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
+    assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
+
+    assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions between weight and bias"
+
+    M, K = x_reshaped.shape
+    N, K = weight.shape
+
+    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
+    act_input = torch.empty_like(output) if save_act_input else None
+
+    # 1D launch kernel where each block gets its own program.
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa
+
+    kernel_fwd[grid](
+        output,
+        act_input,
+        x_reshaped,
+        weight,  # data ptrs
+        bias if bias is not None else x,  # auto skip bias if not present
+        M,  # shapes
+        N,
+        K,
+        M // 32,  # key for triton cache (limit number of compilations)
+        N // 32,
+        K // 32,
+        stride_cm=output.stride(0),  # strides
+        # stride_cn=output.stride(1),
+        stride_am=x_reshaped.stride(0),
+        stride_ak=x_reshaped.stride(1),
+        stride_bk=weight.stride(1),
+        stride_bn=weight.stride(0),
+        BIAS=bias is not None,  # optional fused bias
+        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
+        ACTIVATION=activation,  # optional fused activation
+        A_ROWMAJOR=x_reshaped.stride(1) == 1,
+        B_COLMAJOR=weight.stride(1) == 1,
+        GROUP_M=8,  # speed optimization: group the programs
+    )
+
+    if not save_act_input:
+        return output.reshape(*batch_shape, output.shape[-1])
+    else:
+        return (output.reshape(*batch_shape, output.shape[-1]),
+                act_input.reshape(*batch_shape, act_input.shape[-1]))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+        # good for int8
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
+    ]
+    + get_configs_io_bound(),
+    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
+    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
+)
+@triton.heuristics(
+    {
+        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+    }
+)
+@triton.jit
+def kernel_bwd(
+    C,  # Pointers to matrices
+    ACT_INPUT,
+    A,
+    B,
+    # Matrix dimensions
+    M,
+    N,
+    K,
+    CACHE_KEY_M,
+    CACHE_KEY_N,
+    CACHE_KEY_K,
+    # The stride variables represent how much to increase the ptr by when moving by 1
+    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
+    # by to get the element one row down (A has M rows)
+    stride_cm,
+    # stride_cn,  # Assume that stride_cn == 1
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    # split k not used, not performant with activation, kept because early_config_prune is expecting it
+    SPLIT_K: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+):
+
+    """
+    Kernel for computing Out = (A x W) * activation_grad(ActInputs)
+    - Input has shape (M, K)
+    - Weight has shape (K, N)
+    - Output has shape (M, N)
+    - ActInputs has shape (M, N) and holds the pre-activation values saved by the forward pass
+    'ActInputs' is only read when ACTIVATION != 'id'
+    This kernel accumulates over K
+    """
+
+    pid = tl.program_id(axis=0)
+
+    grid_m = (M + BLOCK_M - 1) // BLOCK_M
+    grid_n = (N + BLOCK_N - 1) // BLOCK_N
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+
+    # now compute the block that each program will go through
+    # rm (resp. rn) denotes a range of indices
+    # for rows (resp. col) of C
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    # trick to avoid masking on M and N axis
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = tl.arange(0, BLOCK_K)
+
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+    for k in range(K, 0, -BLOCK_K):
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
+            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
+        acc += tl.dot(a, b)
+
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
+
+    # optional: fused activation gradient (while the data is in shared memory)
+    if ACTIVATION != 'id':
+        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
+        act_input = tl.load(act_in_ptrs).to(acc.dtype)
+    if ACTIVATION == "gelu":
+        acc *= gelu_grad(act_input)
+    elif ACTIVATION == "gelu_approx":
+        acc *= gelu_approx_grad(act_input)
+    elif ACTIVATION == "squared_relu":
+        acc *= squared_relu_grad(act_input)
+
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # write back result
+    C = C + rm[:, None] * stride_cm + rn[None, :]
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(C, acc, mask=mask)
+
+
+def triton_dgrad_act(
+    grad_output: torch.Tensor,
+    weight: torch.Tensor,
+    activation: str = 'id',
+    act_input: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Compute grad_input = (grad_output @ weight) * activation_grad(act_input).
+    This wrapper kicks off the `kernel_bwd` Triton kernel.
+    :param grad_output: gradient of the output
+    :param weight: weight matrix
+    :param activation: activation name; one of 'id', 'gelu', 'gelu_approx', 'squared_relu'
+    :param act_input: pre-activation values saved during the forward pass (required when activation != 'id')
+    :return: grad_input tensor
+    """
+    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
+
+    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
+    batch_dim = batch_shape.numel()
+    grad_output_reshaped = grad_output.reshape(batch_dim, n)
+
+    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
+        grad_output_reshaped = grad_output_reshaped.contiguous()
+    if weight.stride(0) > 1 and weight.stride(1) > 1:
+        weight = weight.contiguous()
+
+    assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
+    assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
+    if activation != 'id':
+        assert act_input is not None, f'act_input is required for activation {activation}'
+
+    # M, N, K in bwd are different from M, N, K in fwd
+    M, K = grad_output_reshaped.shape
+    K, N = weight.shape
+
+    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)
+
+    # 1D launch kernel where each block gets its own program.
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa
+
+    kernel_bwd[grid](
+        grad_input,
+        act_input,
+        grad_output_reshaped,
+        weight,  # data ptrs
+        M,  # shapes
+        N,
+        K,
+        M // 32,  # key for triton cache (limit number of compilations)
+        N // 32,
+        K // 32,
+        stride_cm=grad_input.stride(0),  # strides
+        # stride_cn=grad_input.stride(1),
+        stride_am=grad_output_reshaped.stride(0),
+        stride_ak=grad_output_reshaped.stride(1),
+        stride_bk=weight.stride(0),
+        stride_bn=weight.stride(1),
+        ACTIVATION=activation,  # optional fused activation
+        GROUP_M=8,  # speed optimization: group the programs
+    )
+
+    return grad_input.reshape(*batch_shape, grad_input.shape[-1])
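
Below is a minimal usage sketch of the two wrappers added in this patch, for reference only. The import path, shapes, dtypes, and the fc1/fc2 naming are illustrative assumptions (the hunk does not show where the file is added); the backward call follows the two-layer MLP pattern the kernels are written for, where act_input is the pre-activation of the layer that feeds the matmul.

    import torch
    # Hypothetical import path; adjust to wherever this file is added in the package.
    from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act

    # Illustrative MLP shapes: fc1 projects 1024 -> 4096, fc2 projects 4096 -> 1024.
    x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)   # (batch, seqlen, in_features)
    w1 = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)    # fc1 weight: (hidden, in_features)
    b1 = torch.randn(4096, device="cuda", dtype=torch.float16)
    w2 = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)    # fc2 weight: (out_features, hidden)

    # Forward: h = gelu_approx(x @ w1.T + b1); save_act_input=True also returns the
    # pre-activation (x @ w1.T + b1) so the backward pass can apply the activation gradient.
    h, pre_act = triton_linear_act(x, w1, b1, activation="gelu_approx", save_act_input=True)
    out = triton_linear_act(h, w2)  # plain matmul, activation defaults to 'id'

    # Backward through fc2 and the fused activation:
    # grad_pre_act = (grad_out @ w2) * gelu_approx'(pre_act)
    grad_out = torch.randn_like(out)
    grad_pre_act = triton_dgrad_act(grad_out, w2, activation="gelu_approx", act_input=pre_act)

Note that act_input must have the same shape as the gradient being returned (M, N), not the shape of grad_output, which is why the sketch saves the pre-activation from fc1 and passes fc2's weight to the backward call.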