@@ -139,22 +139,37 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         sin_k=None,
         interleaved=False,
         seqlen_offsets: Union[int, torch.Tensor] = 0,
+        num_heads_q: Optional[int] = None,
     ):
-        batch, seqlen, three, nheads, headdim = qkv.shape
-        assert three == 3
         if cos_k is None and sin_k is None and qkv.is_contiguous():
             # Call 1 kernel instead of 2 kernels
             # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
             # dimensions, we get the same tensor
-            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
-            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+            if qkv.dim() == 5:
+                batch, seqlen, three, nheads, headdim = qkv.shape
+                assert three == 3
+                # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+                qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+            else:
+                assert qkv.dim() == 4
+                assert num_heads_q is not None
+                num_heads_k = (qkv.shape[2] - num_heads_q) // 2
+                assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
+                qk = qkv[:, :, :num_heads_q + num_heads_k]
             apply_rotary(
                 qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
             )
         else:
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
-            q, k = qkv[:, :, 0], qkv[:, :, 1]
+            if qkv.dim() == 5:
+                q, k = qkv[:, :, 0], qkv[:, :, 1]
+            else:
+                assert qkv.dim() == 4
+                assert num_heads_q is not None
+                num_heads_k = (qkv.shape[2] - num_heads_q) // 2
+                assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
+                q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k]
             apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
             apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
         ctx.save_for_backward(cos, sin, cos_k, sin_k)
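For readers reviewing the packed MQA/GQA path added above, here is a minimal, self-contained sketch (plain PyTorch, no Triton kernel, made-up head counts) of the slicing convention the new 4-D branch relies on:

```python
import torch

# Hypothetical GQA configuration: 8 query heads sharing 2 KV heads.
batch, seqlen, headdim = 2, 16, 64
num_heads_q, num_heads_k = 8, 2

# Packed layout handled by the 4-D branch:
# (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
qkv = torch.randn(batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)

# Same slicing as in forward(): rotary is applied in place to q and k only.
qk = qkv[:, :, : num_heads_q + num_heads_k]              # contiguous q + k block
q = qkv[:, :, :num_heads_q]                              # (b, s, num_heads_q, d)
k = qkv[:, :, num_heads_q : num_heads_q + num_heads_k]   # (b, s, num_heads_k, d)
v = qkv[:, :, num_heads_q + num_heads_k :]               # untouched by rotary

assert qk.shape[2] == num_heads_q + num_heads_k
assert v.shape[2] == num_heads_k
```

These slices are views into `qkv`, which is what lets `apply_rotary` update the packed tensor in place.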
@@ -165,6 +180,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
             ctx.seqlen_offsets = None
         ctx.interleaved = interleaved
+        ctx.num_heads_q = num_heads_q
         return qkv

     @staticmethod
@@ -178,7 +194,14 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             # Call 1 kernel instead of 2 kernels
             # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
             # dimensions, we get the same tensor
-            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+            if dqkv.dim() == 5:
+                dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+            else:
+                assert dqkv.dim() == 4
+                assert ctx.num_heads_q is not None
+                num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
+                assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
+                dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k]
             apply_rotary(
                 dqk,
                 cos,
@@ -191,9 +214,23 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         else:
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
-            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+            if dqkv.dim() == 5:
+                dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+            else:
+                assert dqkv.dim() == 4
+                assert ctx.num_heads_q is not None
+                num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
+                assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
+                dq = dqkv[:, :, : ctx.num_heads_q]
+                dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k]
             apply_rotary(
-                dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
+                dq,
+                cos,
+                sin,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
             )
             apply_rotary(
                 dk,
@@ -204,7 +241,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
                 inplace=True,
                 conjugate=True,
             )
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None


 def apply_rotary_emb_qkv_(
@@ -215,10 +252,13 @@ def apply_rotary_emb_qkv_(
     sin_k=None,
     interleaved=False,
     seqlen_offsets: Union[int, torch.Tensor] = 0,
+    num_heads_q: Optional[int] = None,
 ):
     """
     Arguments:
-        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim).
+            If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
+            then num_heads_q must be provided.
         cos, sin: (seqlen, rotary_dim / 2)
         cos_k, sin_k: (seqlen, rotary_dim / 2), optional
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
@@ -226,11 +266,13 @@ def apply_rotary_emb_qkv_(
         seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
             Most commonly used in inference when we have KV cache.
     Return:
-        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim)
     rotary_dim must be <= headdim
     Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
     """
-    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
+    return ApplyRotaryEmbQKV_.apply(
+        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q
+    )


 class ApplyRotaryEmbKV_(torch.autograd.Function):
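As a usage sketch for the updated `apply_rotary_emb_qkv_` (not part of the patch; it assumes a CUDA device, a `flash_attn.layers.rotary` import path, and illustrative head counts), the packed GQA layout is passed together with `num_heads_q`:

```python
import torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_  # assumed import path

batch, seqlen, headdim = 2, 128, 64
num_heads_q, num_heads_k = 8, 2   # illustrative GQA head counts
rotary_dim = headdim              # rotary_dim must be <= headdim

# Packed 4-D layout: (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
qkv = torch.randn(
    batch, seqlen, num_heads_q + 2 * num_heads_k, headdim,
    device="cuda", dtype=torch.float16,
)

# cos, sin: (seqlen, rotary_dim / 2), as required by the docstring above.
t = torch.arange(seqlen, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device="cuda").float() / rotary_dim))
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos().to(torch.float16), freqs.sin().to(torch.float16)

# Applied in place to the first num_heads_q + num_heads_k heads (q and k); v is untouched.
qkv = apply_rotary_emb_qkv_(qkv, cos, sin, num_heads_q=num_heads_q)
```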
@@ -417,10 +459,13 @@ class RotaryEmbedding(torch.nn.Module):
         kv: Optional[torch.Tensor] = None,
         seqlen_offset: Union[int, torch.Tensor] = 0,
         max_seqlen: Optional[int] = None,
+        num_heads_q: Optional[int] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
-        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
-            else it's just q of shape (batch, seqlen, nheads, headdim)
+        qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
+            if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim).
+            If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
+            then num_heads_q must be provided.
         kv: (batch, seqlen, 2, nheads, headdim)
         seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
             Most commonly used in inference when we have KV cache.
@@ -441,6 +486,7 @@ class RotaryEmbedding(torch.nn.Module):
                     self._sin_cached,
                     interleaved=self.interleaved,
                     seqlen_offsets=seqlen_offset,
+                    num_heads_q=num_heads_q,
                 )
             else:
                 return apply_rotary_emb_qkv_(
@@ -451,6 +497,7 @@ class RotaryEmbedding(torch.nn.Module):
                     self._sin_k_cached,
                     interleaved=self.interleaved,
                     seqlen_offsets=seqlen_offset,
+                    num_heads_q=num_heads_q,
                 )
         else:
             q = qkv
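Finally, a hedged sketch of the module-level call path (again not part of the patch; it assumes a CUDA device, a `flash_attn.layers.rotary.RotaryEmbedding` import path, and illustrative head counts):

```python
import torch
from flash_attn.layers.rotary import RotaryEmbedding  # assumed import path

batch, seqlen, headdim = 2, 128, 64
num_heads_q, num_heads_k = 8, 2   # illustrative GQA head counts

rotary = RotaryEmbedding(dim=headdim, interleaved=False, device="cuda")
qkv = torch.randn(
    batch, seqlen, num_heads_q + 2 * num_heads_k, headdim,
    device="cuda", dtype=torch.float16,
)

# With the packed 4-D MQA/GQA layout, num_heads_q must be passed explicitly;
# the 5-D (batch, seqlen, 3, nheads, headdim) layout keeps working without it.
qkv = rotary(qkv, seqlen_offset=0, num_heads_q=num_heads_q)
```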