# rotary.py
# Copyright (c) 2023, Tri Dao.
  from typing import Optional, Union

  import torch
  import triton
  import triton.language as tl
  @triton.jit
  def rotary_kernel(
      OUT,  # Pointers to matrices
      X,
      COS,
      SIN,
      CU_SEQLENS,
      SEQLEN_OFFSETS,  # this could be int or a pointer
      # Matrix dimensions
      seqlen,
      rotary_dim,
      seqlen_ro,
      # strides
      stride_out_batch,
      stride_out_seqlen,
      stride_out_nheads,
      stride_out_headdim,
      stride_x_batch,
      stride_x_seqlen,
      stride_x_nheads,
      stride_x_headdim,
      # Meta-parameters
      BLOCK_K: tl.constexpr,
      IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
      IS_VARLEN: tl.constexpr,
      INTERLEAVED: tl.constexpr,
      CONJUGATE: tl.constexpr,
      BLOCK_M: tl.constexpr,
  ):
      """Rotate the first ``rotary_dim`` features of X and store them to OUT.

      Each program handles one tile of BLOCK_M sequence positions (grid axis 0)
      for one batch entry (grid axis 1) and one head (grid axis 2).

      Arguments:
          OUT, X: output / input pointers, addressed through the given strides.
          COS, SIN: addressed as ``row * (rotary_dim // 2) + col``, i.e. the
              kernel assumes they are contiguous with rotary_dim // 2 columns.
          CU_SEQLENS: only read when IS_VARLEN; entries ``pid_batch`` and
              ``pid_batch + 1`` give the start row and the length of the sequence.
          SEQLEN_OFFSETS: value added to the row index when looking up COS/SIN.
              A scalar, or a pointer indexed by ``pid_batch`` when
              IS_SEQLEN_OFFSETS_TENSOR is set.
          seqlen: number of rows per sequence (overwritten from CU_SEQLENS when
              IS_VARLEN).
          rotary_dim: number of leading head features that get rotated.
          seqlen_ro: number of rows in COS/SIN; rows at or beyond it load
              cos=1 and sin=0, which leaves the input unchanged.
          BLOCK_K: width of the feature tile. There is no loop over features,
              so columns >= BLOCK_K are never touched.
          INTERLEAVED: if set, rotate pairs (2i, 2i+1); otherwise rotate pairs
              (i, i + rotary_dim // 2).
          CONJUGATE: if set, negate sin before applying the rotation.

      The arithmetic is carried out in float32.
      """
      pid_m = tl.program_id(axis=0)
      pid_batch = tl.program_id(axis=1)
      pid_head = tl.program_id(axis=2)
      rotary_dim_half = rotary_dim // 2

      # Move X / OUT to the first row of this (batch, head) pair.
      if not IS_VARLEN:
          X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
          OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
      else:
          # Sequences are packed along the seqlen axis; look up where this one
          # starts and how long it is.
          start_idx = tl.load(CU_SEQLENS + pid_batch)
          seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
          X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
          OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads

      # Nothing to do for tiles that start past the end of the sequence.
      if pid_m * BLOCK_M >= seqlen:
          return

      rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
      # rm_cs is the row used to index COS/SIN (row in x shifted by the offset).
      if not IS_SEQLEN_OFFSETS_TENSOR:
          rm_cs = rm + SEQLEN_OFFSETS
      else:
          rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
      rk = tl.arange(0, BLOCK_K)
      rk_half = tl.arange(0, BLOCK_K // 2)

      if not INTERLEAVED:
          # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
          X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
          COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
          SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
          # Out-of-range cos/sin entries default to the identity rotation.
          cos = tl.load(
              COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
          ).to(tl.float32)
          sin = tl.load(
              SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
          ).to(tl.float32)
          x0 = tl.load(
              X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
          ).to(tl.float32)
          x1 = tl.load(
              X + rotary_dim_half * stride_x_headdim,
              mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
              other=0.0,
          ).to(tl.float32)
          if CONJUGATE:
              sin = -sin
          o0 = x0 * cos - x1 * sin
          o1 = x0 * sin + x1 * cos
          # write back result
          OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
          tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
          tl.store(
              OUT + rotary_dim_half * stride_out_headdim,
              o1,
              mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
          )
      else:
          # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
          # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
          # Loading x0 will be fast but x1 will be slow.
          # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
          # Then we do the calculation and use tl.where to pick out the right outputs for the even
          # and for the odd indices.
          rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
          rk_repeat = tl.arange(0, BLOCK_K) // 2
          X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
          X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
          COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
          SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
          # Out-of-range cos/sin entries default to the identity rotation.
          cos = tl.load(
              COS,
              mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
              other=1.0,
          ).to(tl.float32)
          sin = tl.load(
              SIN,
              mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
              other=0.0,
          ).to(tl.float32)
          x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
              tl.float32
          )
          x1 = tl.load(
              X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
          ).to(tl.float32)
          if CONJUGATE:
              sin = -sin
          x0_cos = x0 * cos
          x1_sin = x1 * sin
          # Even columns: x_even * cos - x_odd * sin; odd columns: x_odd * cos + x_even * sin.
          out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
          OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
          tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
  def apply_rotary(
      x: torch.Tensor,
      cos: torch.Tensor,
      sin: torch.Tensor,
      seqlen_offsets: Union[int, torch.Tensor] = 0,
      cu_seqlens: Optional[torch.Tensor] = None,
      max_seqlen: Optional[int] = None,
      interleaved=False,
      inplace=False,
      conjugate=False,
  ) -> torch.Tensor:
      """
      Validate the inputs and launch rotary_kernel on x.

      Arguments:
          x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
              else (total_seqlen, nheads, headdim).
          cos: (seqlen_ro, rotary_dim / 2)
          sin: (seqlen_ro, rotary_dim / 2)
          seqlen_offsets: integer or integer tensor of size (batch,)
          cu_seqlens: (batch + 1,) or None
          max_seqlen: int, required when cu_seqlens is given; it is used as the
              sequence length when sizing the launch grid.
          interleaved: if True, rotate pairs of adjacent features (2i, 2i+1)
              instead of pairs (i, i + rotary_dim / 2).
          inplace: if True, write the result into x instead of a new tensor.
          conjugate: if True, the kernel negates sin before rotating.
      Returns:
          y: same shape as x, i.e. (batch, seqlen, nheads, headdim), or
              (total_seqlen, nheads, headdim) if cu_seqlens is given.
              This is x itself when inplace=True.
      """
      is_varlen = cu_seqlens is not None
      if not is_varlen:
          batch, seqlen, nheads, headdim = x.shape
      else:
          assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
          _, nheads, headdim = x.shape
          batch = cu_seqlens.shape[0] - 1
          seqlen = max_seqlen
          # The kernel indexes cu_seqlens with unit stride, so a strided view
          # would be read incorrectly. This is a no-op for contiguous input.
          cu_seqlens = cu_seqlens.contiguous()

      seqlen_ro, rotary_dim = cos.shape
      assert sin.shape == cos.shape
      rotary_dim *= 2  # cos/sin hold one entry per rotated pair
      assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
      assert headdim <= 256, "Only support headdim <= 256"
      assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

      assert (
          cos.dtype == sin.dtype
      ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
      assert (
          x.dtype == cos.dtype
      ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

      # The kernel addresses cos/sin (and a tensor seqlen_offsets) without strides.
      cos, sin = cos.contiguous(), sin.contiguous()
      if isinstance(seqlen_offsets, torch.Tensor):
          assert seqlen_offsets.shape == (batch,)
          assert seqlen_offsets.dtype in [torch.int32, torch.int64]
          seqlen_offsets = seqlen_offsets.contiguous()
      else:
          assert seqlen_offsets + seqlen <= seqlen_ro

      output = torch.empty_like(x) if not inplace else x
      if rotary_dim < headdim and not inplace:
          # The kernel only writes the first rotary_dim features; carry the rest over.
          output[..., rotary_dim:].copy_(x[..., rotary_dim:])

      # Feature tile width: 32, 64, 128 or 256, whichever first covers rotary_dim.
      if rotary_dim <= 32:
          BLOCK_K = 32
      elif rotary_dim <= 64:
          BLOCK_K = 64
      elif rotary_dim <= 128:
          BLOCK_K = 128
      else:
          BLOCK_K = 256
      BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
      grid = (triton.cdiv(seqlen, BLOCK_M), batch, nheads)

      # Need this, otherwise Triton tries to launch from cuda:0 and we get
      # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
      with torch.cuda.device(x.device.index):
          rotary_kernel[grid](
              output,  # data ptrs
              x,
              cos,
              sin,
              cu_seqlens,
              seqlen_offsets,
              seqlen,  # shapes
              rotary_dim,
              seqlen_ro,
              output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
              output.stride(-3),  # seqlen_stride or total_seqlen_stride
              output.stride(-2),  # nheads_stride
              output.stride(-1),  # headdim_stride
              x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
              x.stride(-3),  # seqlen stride or total_seqlen_stride
              x.stride(-2),  # nheads stride
              x.stride(-1),  # headdim stride
              BLOCK_K,
              isinstance(seqlen_offsets, torch.Tensor),
              is_varlen,
              interleaved,
              conjugate,
              BLOCK_M,
          )
      return output