- """
- This script is mainly used to test whether trtion kernels can run normally
- under different conditions, including various batches, numbers of LoRA , and
- maximum ranks.
- """
import random
from unittest.mock import patch

import pytest
import torch

from aphrodite.lora.ops.bgmv_expand import bgmv_expand
from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
from aphrodite.lora.ops.sgmv_expand import sgmv_expand
from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
from aphrodite.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
                    ref_torch_groupgemm)
HIDDEN_SIZES = [4097]
BATCHES = [1, 4, 16, 32]
NUM_LORA = [1, 8, 32, 128]
DTYPES = [torch.float16, torch.bfloat16]
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
SCALES = [0.5]
SEED = [0]
CUDA_DEVICES = ["cuda:0"]
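
# Note: "shrink" applies the LoRA A matrix (hidden_size -> rank, with
# scaling), while "expand" applies the LoRA B matrix (rank -> hidden_size).
# The odd hidden size 4097 is presumably chosen so the kernels are also
# exercised on non-power-of-two, unaligned dimensions.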


def assert_close(a, b):
    # Half-precision kernels need loose tolerances; float32 can be tighter.
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


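# SGMV ("segmented gather matrix-vector" in Punica's terminology) is the
# prefill-style path: every request contributes a whole sequence, so this
# test uses seq_length = 128 per batch entry.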
- @pytest.mark.parametrize("batches", BATCHES)
- @pytest.mark.parametrize("num_loras", NUM_LORA)
- @pytest.mark.parametrize("rank", MAX_RANKS)
- @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
- @pytest.mark.parametrize("scaling", SCALES)
- @pytest.mark.parametrize("dtype", DTYPES)
- @pytest.mark.parametrize("op_type", ["shrink", "expand"])
- @pytest.mark.parametrize("seed", SEED)
- @pytest.mark.parametrize("device", CUDA_DEVICES)
- def test_punica_sgmv(
- batches: int,
- num_loras: int,
- rank: int,
- hidden_size: int,
- scaling: float,
- dtype: torch.dtype,
- op_type: str,
- seed: int,
- device: str,
- ):
- random.seed(seed)
- torch.set_default_device(device)
- torch.random.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed(seed)
- seq_length = 128
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )
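    # Expected shapes (an assumption about the generate_data helper, not
    # verified here): for "shrink", inputs are (total_tokens, hidden_size),
    # weights (num_loras, rank, hidden_size), and outputs (total_tokens,
    # rank); for "expand" the input/output dimensions are swapped.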
    # Tensor.max() with no dim returns a 0-dim tensor; the tuple branch is a
    # defensive guard for dim-reduced (values, indices) return forms.
    max_seq_length = seq_len_tensor.max()
    if isinstance(max_seq_length, tuple):
        max_seq_length = max_seq_length[0].item()
    else:
        max_seq_length = max_seq_length.item()
- if op_type == "shrink":
- sgmv_shrink(
- inputs_tensor,
- lora_weights,
- our_out_tensor,
- b_seq_start_loc,
- seq_len_tensor,
- lora_indices_tensor,
- batches,
- max_seq_length,
- scaling,
- )
- else:
- sgmv_expand(
- inputs_tensor,
- lora_weights,
- our_out_tensor,
- b_seq_start_loc,
- seq_len_tensor,
- lora_indices_tensor,
- batches,
- max_seq_length,
- add_inputs=True,
- )
- ref_torch_groupgemm(
- ref_out_tensor,
- inputs_tensor,
- lora_weights,
- lora_indices_tensor,
- seq_len_tensor,
- batches,
- scaling if op_type == "shrink" else 1.0,
- op_type,
- )
- if op_type == "shrink":
- ref_out_tensor = ref_out_tensor.to(torch.float32)
- assert_close(our_out_tensor, ref_out_tensor)
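# BGMV ("batched gather matrix-vector") is the decode-style path: each
# request contributes exactly one token, so seq_length is fixed to 1.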
- @pytest.mark.parametrize("batches", BATCHES)
- @pytest.mark.parametrize("num_loras", NUM_LORA)
- @pytest.mark.parametrize("rank", MAX_RANKS)
- @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
- @pytest.mark.parametrize("scaling", SCALES)
- @pytest.mark.parametrize("dtype", DTYPES)
- @pytest.mark.parametrize("op_type", ["shrink", "expand"])
- @pytest.mark.parametrize("seed", SEED)
- @pytest.mark.parametrize("device", CUDA_DEVICES)
- def test_punica_bgmv(
- batches: int,
- num_loras: int,
- rank: int,
- hidden_size: int,
- scaling: float,
- dtype: torch.dtype,
- op_type: str,
- seed: int,
- device: str,
- ):
- from aphrodite.lora.ops.bgmv_expand import _bgmv_expand_kernel
- from aphrodite.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
- random.seed(seed)
- torch.set_default_device(device)
- torch.random.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed(seed)
- seq_length = 1
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )
- if op_type == "shrink":
- # The current _bgmv_shrink_kernel does not require the libentry
- # decoration. The purpose of adding this patch is to test the
- # correctness of libentry.
- with patch(
- "aphrodite.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
- LibEntry(_bgmv_shrink_kernel),
- ):
- bgmv_shrink(
- inputs_tensor,
- lora_weights,
- our_out_tensor,
- indices,
- scaling,
- )
- else:
- # ditto
- with patch(
- "aphrodite.lora.ops.bgmv_expand._bgmv_expand_kernel",
- LibEntry(_bgmv_expand_kernel),
- ):
- bgmv_expand(
- inputs_tensor,
- lora_weights,
- our_out_tensor,
- indices,
- add_inputs=True,
- )
    ref_torch_groupgemm(
        ref_out_tensor,
        inputs_tensor,
        lora_weights,
        lora_indices_tensor,
        seq_len_tensor,
        batches,
        scaling if op_type == "shrink" else 1.0,
        op_type,
    )
    if op_type == "shrink":
        ref_out_tensor = ref_out_tensor.to(torch.float32)
    assert_close(our_out_tensor, ref_out_tensor)


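# The nslices variants cover fused projections, where several LoRA slices
# write disjoint column blocks of a single output tensor; nslices values of
# 2 and 3 most likely mirror fused gate_up_proj and qkv_proj layers.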
- @pytest.mark.parametrize("batches", BATCHES)
- @pytest.mark.parametrize("num_loras", NUM_LORA)
- @pytest.mark.parametrize("rank", MAX_RANKS)
- @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
- @pytest.mark.parametrize("nslices", [2, 3])
- @pytest.mark.parametrize("dtype", DTYPES)
- @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
- @pytest.mark.parametrize("seed", SEED)
- @pytest.mark.parametrize("device", CUDA_DEVICES)
- def test_punica_expand_nslices(
- batches: int,
- num_loras: int,
- rank: int,
- hidden_size: int,
- nslices: int,
- dtype: torch.dtype,
- op_type: str,
- seed: int,
- device: str,
- ):
- from aphrodite.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
- random.seed(seed)
- torch.set_default_device(device)
- torch.random.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed(seed)
- seq_length = 128 if op_type == "sgmv" else 1
    (
        inputs_tensor,
        lora_weights_lst,
        our_outputs,
        ref_outputs,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        indices,
    ) = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        nslices,
        device,
    )
    max_seq_length = seq_len_tensor.max()
    if isinstance(max_seq_length, tuple):
        max_seq_length = max_seq_length[0].item()
    else:
        max_seq_length = max_seq_length.item()
    # Each slice writes the column block
    # [slice_offset, slice_offset + hidden_size) of the shared output, and
    # the reference result is checked block by block.
    slice_offset = 0
    for index in range(nslices):
        lora_weights = lora_weights_lst[index]
        if op_type == "sgmv":
            sgmv_expand_slice(
                inputs_tensor,
                lora_weights,
                our_outputs,
                b_seq_start_loc,
                seq_len_tensor,
                lora_indices_tensor,
                batches,
                max_seq_length,
                slice_offset,
                hidden_size,
                add_inputs=True,
            )
        else:
            # The current _bgmv_expand_slice_kernel does not require the
            # LibEntry decorator; it is patched in here purely to test that
            # LibEntry itself behaves correctly.
            with patch(
                "aphrodite.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
                LibEntry(_bgmv_expand_slice_kernel),
            ):
                bgmv_expand_slice(
                    inputs_tensor,
                    lora_weights,
                    our_outputs,
                    indices,
                    slice_offset,
                    slice_size=hidden_size,
                    add_inputs=True,
                )
        ref_torch_groupgemm(
            ref_outputs[:, slice_offset:slice_offset + hidden_size],
            inputs_tensor,
            lora_weights,
            lora_indices_tensor,
            seq_len_tensor,
            batches,
            1.0,
            op_type="expand",
        )
        slice_offset += hidden_size
    assert_close(our_outputs, ref_outputs)
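

# A minimal direct-run sketch for debugging a single configuration without
# pytest's full parameter sweep. Assumes a CUDA device is available, and must
# be invoked in module form (python -m ...) because of the relative import
# above; the parametrized tests remain the canonical entry point.
if __name__ == "__main__":
    test_punica_sgmv(
        batches=4,
        num_loras=8,
        rank=16,
        hidden_size=4097,
        scaling=0.5,
        dtype=torch.float16,
        op_type="shrink",
        seed=0,
        device="cuda:0",
    )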