123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import torch
- import triton
- import triton.language as tl
- from .utils import get_lora_op_configs
- @triton.jit
- def _bgmv_embed_kernel(
- tokens, # pointer to tokens array
- embed_tokens_all, # pointer to embedded tokens - all
- embed_tokens_base, # pointer to embedded tokens - base
- token_indices, # pointer to token indices
- embeddings, # pointer to output embeddings
- num_tokens, # number of tokens
- HIDDEN_DIM: tl.constexpr, # hidden dimension
- VOCAB_SIZE: tl.constexpr, # vocabulary size
- BLOCK_N: tl.constexpr # block size (number of tokens per block)
- ):
- # Calculate the starting index for this block
- start_idx = tl.program_id(0) * BLOCK_N
- # Create an array of offsets for the tokens in this block
- offs_n = start_idx + tl.arange(0, BLOCK_N)
- # Create a mask to handle cases where we exceed num_tokens
- mask = offs_n < num_tokens
- # Load lora_index and tokens for the current block (masked)
- lora_index = tl.load(token_indices + offs_n, mask=mask, other=-1)
- cur_tokens = tl.load(tokens + offs_n, mask=mask, other=0)
- # Compute offsets into the embedding matrices
- hidden_range = tl.arange(0, HIDDEN_DIM)
- offsets_embed = cur_tokens[:, None] * HIDDEN_DIM + hidden_range[
- None, :] # Shape: (BLOCK_N, HIDDEN_DIM)
- # Load embeddings from embed_tokens_base
- embeddings_base = tl.load(embed_tokens_base + offsets_embed,
- mask=mask[:, None],
- other=0.0)
- # Initialize embeddings_block with embeddings_base
- embeddings_block = embeddings_base
- # Create a mask for tokens that require loading from embed_tokens_all
- mask_all = (lora_index != -1) & mask
- # For tokens with lora_index != -1, load from embed_tokens_all
- # Calculate base offsets for tokens with lora_index != -1
- # Use tl.where to avoid invalid memory accesses
- base_offsets_all = tl.where(mask_all, lora_index * HIDDEN_DIM * VOCAB_SIZE,
- 0)
- # Calculate full offsets into embed_tokens_all
- full_offsets_all = base_offsets_all[:, None] + offsets_embed
- # Load embeddings from embed_tokens_all
- embeddings_all = tl.load(embed_tokens_all + full_offsets_all,
- mask=mask_all[:, None],
- other=0.0)
- # Overwrite embeddings_block where lora_index != -1
- embeddings_block = tl.where(mask_all[:, None], embeddings_all,
- embeddings_block)
- # Calculate the offsets where embeddings should be stored
- output_offsets = offs_n[:, None] * HIDDEN_DIM + hidden_range[None, :]
- # Store embeddings_block to the output embeddings array
- tl.store(embeddings + output_offsets, embeddings_block, mask=mask[:, None])
- @torch.inference_mode()
- def _bgmv_embed(
- tokens: torch.Tensor,
- embed_tokens_all: torch.Tensor,
- embed_tokens_base: torch.Tensor,
- token_indices: torch.Tensor,
- ) -> torch.Tensor:
- """
- Args:
- tokens - [num_tokens] - input tokens
- embed_tokens_all - [num_loras, vocab_size, hidden_dim]
- modules_to_save embeddings
- embed_tokens_base - [vocab_size, hidden_dim] - base layer
- embeddings will be applied to tokens with index=-1
- token_indices - [num_tokens] LoRA indices from 0 to num_loras,
- -1 means no LoRA, embed_tokens_base will be used
- returns:
- embeddings: [num_tokens, hidden_dim]
- """
- assert embed_tokens_all.dtype == embed_tokens_base.dtype
- assert tokens.dtype == torch.int64
- assert token_indices.dtype == torch.int64
- assert embed_tokens_base.is_contiguous()
- assert embed_tokens_all.is_contiguous()
- vocab_size, hidden_dim = embed_tokens_all.shape[-2:]
- num_tokens = tokens.shape[0]
- embeddings = torch.zeros((num_tokens, hidden_dim),
- dtype=embed_tokens_all.dtype,
- device=embed_tokens_all.device)
- grid = lambda meta: (triton.cdiv(num_tokens, meta['BLOCK_N']), )
- config = get_lora_op_configs("embed", num_tokens, hidden_dim)
- _bgmv_embed_kernel[grid](
- tokens,
- embed_tokens_all,
- embed_tokens_base,
- token_indices,
- embeddings,
- num_tokens,
- HIDDEN_DIM=hidden_dim,
- VOCAB_SIZE=vocab_size,
- **config,
- )
- return embeddings
- try:
- bgmv_embed = torch.library.custom_op("lora::bgmv_embed",
- _bgmv_embed,
- mutates_args=[])
- except AttributeError:
- bgmv_embed = _bgmv_embed
|