conftest.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from typing import List, Tuple
  2. import pytest
  3. import torch
  4. def create_kv_caches(
  5. num_blocks: int,
  6. block_size: int,
  7. num_layers: int,
  8. num_heads: int,
  9. head_size: int,
  10. dtype: torch.dtype,
  11. seed: int,
  12. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  13. torch.random.manual_seed(seed)
  14. torch.cuda.manual_seed(seed)
  15. scale = head_size**-0.5
  16. x = 16 // torch.tensor([], dtype=dtype).element_size()
  17. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  18. key_caches = []
  19. for _ in range(num_layers):
  20. key_cache = torch.empty(size=key_cache_shape,
  21. dtype=dtype,
  22. device='cuda')
  23. key_cache.uniform_(-scale, scale)
  24. key_caches.append(key_cache)
  25. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  26. values_caches = []
  27. for _ in range(num_layers):
  28. values_cache = torch.empty(size=value_cache_shape,
  29. dtype=dtype,
  30. device='cuda')
  31. values_cache.uniform_(-scale, scale)
  32. values_caches.append(values_cache)
  33. return key_caches, values_caches
  34. @pytest.fixture()
  35. def kv_cache_factory():
  36. return create_kv_caches