# bgmv_embed.py

import torch
import triton
import triton.language as tl

from .utils import get_lora_op_configs


@triton.jit
def _bgmv_embed_kernel(
        tokens,  # pointer to tokens array
        embed_tokens_all,  # pointer to embedded tokens - all
        embed_tokens_base,  # pointer to embedded tokens - base
        token_indices,  # pointer to token indices
        embeddings,  # pointer to output embeddings
        num_tokens,  # number of tokens
        HIDDEN_DIM: tl.constexpr,  # hidden dimension
        VOCAB_SIZE: tl.constexpr,  # vocabulary size
        BLOCK_N: tl.constexpr,  # block size (number of tokens per block)
):
    # Calculate the starting index for this block
    start_idx = tl.program_id(0) * BLOCK_N
    # Create an array of offsets for the tokens in this block
    offs_n = start_idx + tl.arange(0, BLOCK_N)
    # Create a mask to handle cases where we exceed num_tokens
    mask = offs_n < num_tokens

    # Load lora_index and tokens for the current block (masked)
    lora_index = tl.load(token_indices + offs_n, mask=mask, other=-1)
    cur_tokens = tl.load(tokens + offs_n, mask=mask, other=0)

    # Compute offsets into the embedding matrices
    hidden_range = tl.arange(0, HIDDEN_DIM)
    offsets_embed = cur_tokens[:, None] * HIDDEN_DIM + hidden_range[
        None, :]  # Shape: (BLOCK_N, HIDDEN_DIM)

    # Load embeddings from embed_tokens_base
    embeddings_base = tl.load(embed_tokens_base + offsets_embed,
                              mask=mask[:, None],
                              other=0.0)

    # Initialize embeddings_block with embeddings_base
    embeddings_block = embeddings_base

    # Create a mask for tokens that require loading from embed_tokens_all
    mask_all = (lora_index != -1) & mask

    # For tokens with lora_index != -1, load from embed_tokens_all.
    # Calculate base offsets for those tokens; use tl.where so that the
    # remaining lanes never produce invalid memory accesses.
    base_offsets_all = tl.where(mask_all, lora_index * HIDDEN_DIM * VOCAB_SIZE,
                                0)

    # Calculate full offsets into embed_tokens_all
    full_offsets_all = base_offsets_all[:, None] + offsets_embed

    # Load embeddings from embed_tokens_all
    embeddings_all = tl.load(embed_tokens_all + full_offsets_all,
                             mask=mask_all[:, None],
                             other=0.0)

    # Overwrite embeddings_block where lora_index != -1
    embeddings_block = tl.where(mask_all[:, None], embeddings_all,
                                embeddings_block)

    # Calculate the offsets where embeddings should be stored
    output_offsets = offs_n[:, None] * HIDDEN_DIM + hidden_range[None, :]

    # Store embeddings_block to the output embeddings array
    tl.store(embeddings + output_offsets, embeddings_block, mask=mask[:, None])

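
# Illustration only (not part of the original module): an eager-mode PyTorch
# sketch of the per-token selection the kernel above performs, assuming the
# shapes documented in _bgmv_embed below. The name _bgmv_embed_reference is
# hypothetical and exists only for comparison or testing.
def _bgmv_embed_reference(tokens: torch.Tensor,
                          embed_tokens_all: torch.Tensor,
                          embed_tokens_base: torch.Tensor,
                          token_indices: torch.Tensor) -> torch.Tensor:
    # Base-layer lookup for every token: [num_tokens, hidden_dim]
    out = embed_tokens_base[tokens]
    # Tokens with a LoRA index take their row from that adapter's
    # modules_to_save embedding table instead.
    has_lora = token_indices != -1
    if has_lora.any():
        out[has_lora] = embed_tokens_all[token_indices[has_lora],
                                         tokens[has_lora]]
    return out

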
@torch.inference_mode()
def _bgmv_embed(
        tokens: torch.Tensor,
        embed_tokens_all: torch.Tensor,
        embed_tokens_base: torch.Tensor,
        token_indices: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
        tokens - [num_tokens] - input token ids
        embed_tokens_all - [num_loras, vocab_size, hidden_dim] -
            modules_to_save embeddings, one table per LoRA
        embed_tokens_base - [vocab_size, hidden_dim] - base-layer embedding
            table, used for tokens whose index is -1
        token_indices - [num_tokens] - per-token LoRA index in
            [0, num_loras); -1 means no LoRA, so embed_tokens_base is used

    Returns:
        embeddings: [num_tokens, hidden_dim]
    """
    assert embed_tokens_all.dtype == embed_tokens_base.dtype
    assert tokens.dtype == torch.int64
    assert token_indices.dtype == torch.int64

    assert embed_tokens_base.is_contiguous()
    assert embed_tokens_all.is_contiguous()

    vocab_size, hidden_dim = embed_tokens_all.shape[-2:]
    num_tokens = tokens.shape[0]
    embeddings = torch.zeros((num_tokens, hidden_dim),
                             dtype=embed_tokens_all.dtype,
                             device=embed_tokens_all.device)

    grid = lambda meta: (triton.cdiv(num_tokens, meta['BLOCK_N']), )

    config = get_lora_op_configs("embed", num_tokens, hidden_dim)
    _bgmv_embed_kernel[grid](
        tokens,
        embed_tokens_all,
        embed_tokens_base,
        token_indices,
        embeddings,
        num_tokens,
        HIDDEN_DIM=hidden_dim,
        VOCAB_SIZE=vocab_size,
        **config,
    )
    return embeddings


# Register as a PyTorch custom op when torch.library.custom_op is available;
# fall back to the plain Triton wrapper on older torch versions.
try:
    bgmv_embed = torch.library.custom_op("lora::bgmv_embed",
                                         _bgmv_embed,
                                         mutates_args=[])
except AttributeError:
    bgmv_embed = _bgmv_embed
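

# Illustration only (not part of the original module): a minimal usage sketch,
# assuming a CUDA device and that the package's get_lora_op_configs returns a
# config for the "embed" op. Shapes follow the _bgmv_embed docstring; the
# concrete sizes below are arbitrary (HIDDEN_DIM must be a power of two for
# tl.arange). Run as a module (python -m <package>.bgmv_embed) because of the
# relative import above.
if __name__ == "__main__":
    num_loras, vocab_size, hidden_dim = 2, 1024, 128
    embed_tokens_all = torch.randn(num_loras, vocab_size, hidden_dim,
                                   dtype=torch.float16, device="cuda")
    embed_tokens_base = torch.randn(vocab_size, hidden_dim,
                                    dtype=torch.float16, device="cuda")
    tokens = torch.randint(0, vocab_size, (8, ), device="cuda")
    # Tokens 0, 1 and 6 use the base table; the rest use LoRA 0 or 1.
    token_indices = torch.tensor([-1, -1, 0, 0, 1, 1, -1, 0], device="cuda")
    out = bgmv_embed(tokens, embed_tokens_all, embed_tokens_base,
                     token_indices)
    print(out.shape)  # torch.Size([8, 128])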