import torch
import math
from .utils import DEBUG


def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
    if DEBUG:
        print()
        print("attention_forward_core_ref_impl")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("sm_scale:", sm_scale)
        print("causal:", causal)
        print("use_exp2:", use_exp2)

    # Compute attention scores
    attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
    if DEBUG:
        print("attention_scores:", attention_scores, attention_scores.shape)

    # Scale scores
    attention_scaled_scores = sm_scale * attention_scores
    if DEBUG:
        print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape)

    # Apply causal mask if necessary
    if causal:
        L_q, L_k = q.shape[1], k.shape[1]
        row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
        col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
        col_offset = L_q - L_k
        causal_mask = row_idx >= (col_offset + col_idx)
        if DEBUG:
            print("causal_mask:", causal_mask)
        # Set positions where the causal mask is False to -inf
        attention_scaled_scores = attention_scaled_scores.masked_fill(
            torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
        )
        if DEBUG:
            print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape)

    # Compute max for numerical stability
    max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0]
    if DEBUG:
        print("max_scores:", max_scores, max_scores.shape)
    if causal:
        # Replace -inf in max_scores with zeros to avoid NaN in subtraction
        max_scores = torch.where(
            torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores
        )
        if DEBUG:
            print("max_scores if causal:", max_scores, max_scores.shape)

    # Shift scores
    attention_shifted_scaled_scores = attention_scaled_scores - max_scores
    if DEBUG:
        print("attention_shifted_scaled_scores:", attention_shifted_scaled_scores, attention_shifted_scaled_scores.shape)

    # Exponentiate
    if use_exp2:
        RCP_LN = 1 / math.log(2)  # 1 / ln(2), so exp(x) == exp2(x / ln(2))
        exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores)
    else:
        exp_scores = torch.exp(attention_shifted_scaled_scores)
    if DEBUG:
        print("exp_scores:", exp_scores, exp_scores.shape)

    # Sum of exponentials
    sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True)
    if DEBUG:
        print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape)
    if causal:
        # If the sum of exp scores is 0.0, every score in that row was -inf and
        # softmax / softmax_lse cannot be computed. Setting the sum to 1 handles
        # the fully masked row case cleanly.
        sum_exp_scores = torch.where(
            sum_exp_scores == 0,
            torch.ones_like(sum_exp_scores),
            sum_exp_scores
        )
        if DEBUG:
            print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape)

    # Compute softmax probabilities
    softmax = exp_scores / sum_exp_scores
    if DEBUG:
        print("softmax:", softmax, softmax.shape)

    # Compute log-sum-exp
    if use_exp2:
        LN2 = math.log(2)
        RCP_LN = 1 / math.log(2)
        max_scores_base2 = max_scores * RCP_LN
        softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores)
        softmax_lse = softmax_lse_base2 * LN2
        softmax_lse.squeeze_(-1)
    else:
        softmax_lse = max_scores + torch.log(sum_exp_scores)
        softmax_lse = softmax_lse.squeeze(-1)
    if DEBUG:
        print("softmax_lse:", softmax_lse, softmax_lse.shape)

    # Compute output
    o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16)
    if DEBUG:
        print("o:", o, o.shape)

    return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
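
# Illustrative check, not part of the original module: for equal query/key
# lengths, the core reference above should match PyTorch's built-in
# torch.nn.functional.scaled_dot_product_attention up to low-precision rounding.
# The helper name is hypothetical; it assumes PyTorch >= 2.1 (for the `scale`
# keyword) and fp16 inputs of shape [batch * num_heads, seq_len, head_dim].
def _check_core_against_sdpa(q, k, v, sm_scale, causal):
    import torch.nn.functional as F
    o_ref, softmax_lse, *_ = attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2=False)
    o_sdpa = F.scaled_dot_product_attention(
        q.to(torch.float32), k.to(torch.float32), v.to(torch.float32),
        is_causal=causal, scale=sm_scale,
    ).to(torch.float16)
    return torch.allclose(o_ref, o_sdpa, atol=2e-2, rtol=2e-2)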

def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2):
    """Compute the reference output and softmax_lse with a pure PyTorch implementation."""
    # Ensure the layout is 'bhsd'
    if layout == "bshd":
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()
    elif layout != "bhsd":
        raise ValueError(f"Unknown layout {layout}")

    # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
    batch_size, num_heads, seq_len_q, head_dim = q.shape
    seq_len_k = k.shape[2]

    # Merge batch and heads dimensions
    q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
    k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
    v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)

    # Call the core attention function
    o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl(
        q, k, v, sm_scale, causal, use_exp2
    )

    # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
    o = o.reshape(batch_size, num_heads, seq_len_q, head_dim)
    softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q)
    exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
    attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)

    # Restore original layout if necessary
    if layout == "bshd":
        o = o.transpose(1, 2)

    return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
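
# Illustrative check, not part of the original module: a "bshd" input is just a
# "bhsd" input with the sequence and head dimensions swapped, so both layouts
# should yield the same output once transposed back. The helper name is
# hypothetical; inputs are assumed to be fp16 tensors in "bhsd" layout.
def _check_layouts_agree(q_bhsd, k_bhsd, v_bhsd, sm_scale, causal, use_exp2=False):
    out_bhsd = attention_vanilla_forward_pytorch_ref_impl(
        q_bhsd, k_bhsd, v_bhsd, sm_scale, causal, "bhsd", use_exp2
    )[0]
    out_bshd = attention_vanilla_forward_pytorch_ref_impl(
        q_bhsd.transpose(1, 2).contiguous(),
        k_bhsd.transpose(1, 2).contiguous(),
        v_bhsd.transpose(1, 2).contiguous(),
        sm_scale, causal, "bshd", use_exp2,
    )[0]
    return torch.allclose(out_bshd.transpose(1, 2), out_bhsd)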

def attention_varlen_forward_pytorch_ref_impl(
    q,
    k,
    v,
    sm_scale,
    causal,
    layout,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    use_exp2
):
    # Ensure the layout is 'thd'
    if layout != 'thd':
        raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")

    batch_size = cu_seqlens_q.shape[0] - 1
    num_heads = q.shape[1]
    head_dim = q.shape[2]

    # Pre-allocate outputs
    total_L_q = q.shape[0]
    total_L_k = k.shape[0]

    o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device)
    softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device)

    for i in range(batch_size):
        # Get the start and end indices for the current sequence
        start_q = cu_seqlens_q[i].item()
        end_q = cu_seqlens_q[i + 1].item()
        start_k = cu_seqlens_k[i].item()
        end_k = cu_seqlens_k[i + 1].item()

        # Extract q_i, k_i, v_i
        q_i = q[start_q:end_q, :, :]  # [L_q_i, num_heads, head_dim]
        k_i = k[start_k:end_k, :, :]  # [L_k_i, num_heads, head_dim]
        v_i = v[start_k:end_k, :, :]  # [L_k_i, num_heads, head_dim]

        # Permute to [num_heads, L_q_i, head_dim]
        q_i = q_i.permute(1, 0, 2)
        k_i = k_i.permute(1, 0, 2)
        v_i = v_i.permute(1, 0, 2)

        # Call the core attention function for this sequence
        (
            o_i,
            softmax_lse_i,
            exp_scores_i,
            softmax_i,
            attention_shifted_scaled_scores_i,
            attention_scaled_scores_i,
            attention_scores_i,
        ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2)

        # Convert back to 'thd' layout and float16
        o_i = o_i.permute(1, 0, 2).to(torch.float16)  # [L_q_i, num_heads, head_dim]

        # Place outputs in pre-allocated tensors
        o[start_q:end_q, :, :] = o_i
        softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1)  # Transpose to [L_q_i, num_heads]

        # The per-sequence score tensors are ragged across sequences (each has its
        # own L_q_i and L_k_i), so they are not aggregated; they are only permuted
        # to [L_q_i, num_heads, L_k_i] for inspection and debugging.
        exp_scores_i = exp_scores_i.permute(1, 0, 2)
        softmax_i = softmax_i.permute(1, 0, 2)
        attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2)
        attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2)
        attention_scores_i = attention_scores_i.permute(1, 0, 2)

    # Only o and softmax_lse are returned for the varlen path; the ragged
    # intermediate score tensors are returned as None.
    return (
        o,
        softmax_lse,
        None,
        None,
        None,
        None,
        None,
    )
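
# Illustrative sketch, not part of the original module: how "thd"-packed inputs
# and cu_seqlens are typically constructed for the varlen path above. The helper
# name and the per-sequence lengths are hypothetical example values.
def _build_varlen_example(num_heads=4, head_dim=64, device="cpu"):
    seqlens = [5, 9, 3]  # example per-sequence lengths
    offsets = [0]
    for s in seqlens:
        offsets.append(offsets[-1] + s)
    cu_seqlens = torch.tensor(offsets, dtype=torch.int32, device=device)
    total = offsets[-1]
    # Pack all sequences along the first ("t") dimension: [total_tokens, num_heads, head_dim]
    q = torch.randn(total, num_heads, head_dim, dtype=torch.float16, device=device)
    k = torch.randn(total, num_heads, head_dim, dtype=torch.float16, device=device)
    v = torch.randn(total, num_heads, head_dim, dtype=torch.float16, device=device)
    return q, k, v, cu_seqlens, cu_seqlens, max(seqlens), max(seqlens)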

def attention_forward_pytorch_ref_impl(
    q,
    k,
    v,
    sm_scale,
    causal,
    layout,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    use_exp2
):
    if DEBUG:
        print()
        print("attention_forward_pytorch_ref_impl")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("sm_scale:", sm_scale)
        print("causal:", causal)
        print("cu_seqlens_q:", cu_seqlens_q)
        print("cu_seqlens_k:", cu_seqlens_k)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("use_exp2:", use_exp2)

    # Compute reference
    if layout == "thd":
        (
            o_ref,
            softmax_lse_ref,
            exp_scores_ref,
            softmax_ref,
            attention_shifted_scaled_scores_ref,
            attention_scaled_scores_ref,
            attention_scores_ref,
        ) = attention_varlen_forward_pytorch_ref_impl(
            q.clone(),
            k.clone(),
            v.clone(),
            sm_scale,
            causal,
            layout,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            use_exp2,
        )
    else:
        (
            o_ref,
            softmax_lse_ref,
            exp_scores_ref,
            softmax_ref,
            attention_shifted_scaled_scores_ref,
            attention_scaled_scores_ref,
            attention_scores_ref,
        ) = attention_vanilla_forward_pytorch_ref_impl(
            q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2
        )

    if DEBUG:
        print()
        print("attention_forward_pytorch_ref_impl outputs")
        print("o_ref:", o_ref, o_ref.shape)
        print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
        print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape if exp_scores_ref is not None else None)

    return (
        o_ref,
        softmax_lse_ref,
        exp_scores_ref,
        softmax_ref,
        attention_shifted_scaled_scores_ref,
        attention_scaled_scores_ref,
        attention_scores_ref,
    )

def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
    q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1)  # (N_CTX_Q, 1)
    k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0)  # (1, N_CTX_K)
    relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx)  # (N_CTX_Q, N_CTX_K)
    return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos  # (Z, H, N_CTX_Q, N_CTX_K)
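
# Minimal usage sketch, not part of the original module. The helper name is
# hypothetical; it assumes fp16 inputs in the "bhsd" layout and runs on CPU or
# CUDA, whichever is available.
def _example_usage():
    batch, heads, seqlen, head_dim = 2, 4, 128, 64
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(batch, heads, seqlen, head_dim, dtype=torch.float16, device=device)
    k = torch.randn(batch, heads, seqlen, head_dim, dtype=torch.float16, device=device)
    v = torch.randn(batch, heads, seqlen, head_dim, dtype=torch.float16, device=device)
    sm_scale = 1.0 / math.sqrt(head_dim)
    # cu_seqlens/max_seqlen are only used by the "thd" (varlen) path, so nominal
    # values are fine here.
    o, softmax_lse, *_ = attention_forward_pytorch_ref_impl(
        q, k, v, sm_scale, causal=True, layout="bhsd",
        cu_seqlens_q=None, cu_seqlens_k=None,
        max_seqlen_q=seqlen, max_seqlen_k=seqlen,
        use_exp2=False,
    )
    print("o:", o.shape, "softmax_lse:", softmax_lse.shape)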