bgmv_sample.py

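"""Triton-based BGMV sampling kernel for LoRA.

For each token, multiplies the token's hidden state with either the
LoRA-specific LM head selected by `sampling_indices_tensor` or, when the
index is -1, the shared base LM head, producing per-token logits over the
vocabulary.
"""
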
import torch
import triton
import triton.language as tl

from .utils import get_lora_op_configs


@triton.jit
def _bgmv_sample_kernel(hidden_state_ptr, lm_heads_all_ptr, lm_head_base_ptr,
                        logits_ptr, sampling_indices_tensor_ptr,
                        HIDDEN_DIM: tl.constexpr, VOCAB_SIZE: tl.constexpr,
                        BLOCK_N: tl.constexpr):
    # One program per (token, vocab block): axis 0 picks the token, axis 1
    # picks a BLOCK_N-wide slice of the vocabulary.
    cur_token = tl.program_id(axis=0)
    logits_start_idx = tl.program_id(axis=1) * BLOCK_N

    # Which LoRA LM head this token uses (-1 means the base LM head).
    lora_index = tl.load(sampling_indices_tensor_ptr + cur_token)

    # Load the token's hidden state and broadcast it to [1, HIDDEN_DIM].
    hidden_state = tl.load(hidden_state_ptr + HIDDEN_DIM * cur_token +
                           tl.arange(0, HIDDEN_DIM))
    hidden_state = hidden_state.expand_dims(0)

    offsets_embed = tl.arange(0, HIDDEN_DIM)
    offsets_logits = logits_start_idx + tl.arange(0, BLOCK_N)

    # [BLOCK_N, HIDDEN_DIM] offsets into an LM head stored row-major as
    # [vocab_size, hidden_dim]; the LoRA variant adds a per-adapter stride.
    offset_base_layer = offsets_embed[
        None, :] + offsets_logits[:, None] * HIDDEN_DIM
    offset_lora = lora_index * (VOCAB_SIZE * HIDDEN_DIM) + offset_base_layer

    if lora_index == -1:
        weights = tl.load(lm_head_base_ptr + offset_base_layer)
    else:
        weights = tl.load(lm_heads_all_ptr + offset_lora)

    # Dot product of the hidden state with each vocab row in this block.
    logits = tl.sum(weights * hidden_state, axis=1)

    tl.store(logits_ptr + cur_token * VOCAB_SIZE + offsets_logits, logits)


@torch.inference_mode()
def _bgmv_sample(
    hidden_state: torch.Tensor,
    lm_heads_all: torch.Tensor,
    lm_head_base: torch.Tensor,
    sampling_indices_tensor: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
        hidden_state - [num_tokens, hidden_dim]
        lm_heads_all - [num_loras, vocab_size, hidden_dim]
        lm_head_base - [vocab_size, hidden_dim]
        sampling_indices_tensor - [num_tokens] - indices from 0 to
            num_loras-1; -1 selects lm_head_base
    Returns:
        logits - [num_tokens, vocab_size]
    """
    assert hidden_state.dtype == lm_heads_all.dtype
    assert hidden_state.size(-1) == lm_heads_all.size(-1)
    assert hidden_state.is_contiguous()
    assert lm_heads_all.is_contiguous()

    vocab_size = lm_heads_all.shape[-2]
    logits = torch.zeros((hidden_state.size(0), vocab_size),
                         dtype=hidden_state.dtype,
                         device=hidden_state.device)

    num_tokens = sampling_indices_tensor.shape[0]
    hidden_dim = hidden_state.shape[-1]

    # One program per token along axis 0 and per BLOCK_N vocab slice along
    # axis 1; BLOCK_N is supplied by the op config returned below.
    grid = lambda meta: (num_tokens, triton.cdiv(vocab_size, meta['BLOCK_N']))
    config = get_lora_op_configs("sample", num_tokens, hidden_dim)

    _bgmv_sample_kernel[grid](
        hidden_state,
        lm_heads_all,
        lm_head_base,
        logits,
        sampling_indices_tensor,
        HIDDEN_DIM=hidden_dim,
        VOCAB_SIZE=vocab_size,
        **config,
    )
    return logits


try:
    # Register as a PyTorch custom op when torch.library.custom_op is
    # available; fall back to the plain function on older PyTorch versions.
    bgmv_sample = torch.library.custom_op("lora::bgmv_sample",
                                          _bgmv_sample,
                                          mutates_args=[])
except AttributeError:
    bgmv_sample = _bgmv_sample
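

# Illustrative usage sketch (not part of the original module); the shapes,
# dtypes, and CUDA device below are assumptions based on the docstring above:
#
#   num_tokens, hidden_dim, vocab_size, num_loras = 4, 4096, 32000, 8
#   hs = torch.randn(num_tokens, hidden_dim,
#                    dtype=torch.float16, device="cuda")
#   lm_heads_all = torch.randn(num_loras, vocab_size, hidden_dim,
#                              dtype=torch.float16, device="cuda")
#   lm_head_base = torch.randn(vocab_size, hidden_dim,
#                              dtype=torch.float16, device="cuda")
#   # 0..num_loras-1 picks a LoRA LM head; -1 falls back to lm_head_base
#   indices = torch.tensor([0, 3, -1, 1], device="cuda")
#   logits = bgmv_sample(hs, lm_heads_all, lm_head_base, indices)
#   assert logits.shape == (num_tokens, vocab_size)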