1
0

conftest.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from typing import List, Tuple
  2. import pytest
  3. import torch
  4. def create_kv_caches(
  5. num_blocks: int,
  6. block_size: int,
  7. num_layers: int,
  8. num_heads: int,
  9. head_size: int,
  10. dtype: torch.dtype,
  11. seed: int,
  12. device: str,
  13. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  14. torch.random.manual_seed(seed)
  15. torch.cuda.manual_seed(seed)
  16. scale = head_size**-0.5
  17. x = 16 // torch.tensor([], dtype=dtype).element_size()
  18. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  19. key_caches = []
  20. for _ in range(num_layers):
  21. key_cache = torch.empty(size=key_cache_shape,
  22. dtype=dtype,
  23. device=device)
  24. key_cache.uniform_(-scale, scale)
  25. key_caches.append(key_cache)
  26. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  27. value_caches = []
  28. for _ in range(num_layers):
  29. value_cache = torch.empty(size=value_cache_shape,
  30. dtype=dtype,
  31. device=device)
  32. value_cache.uniform_(-scale, scale)
  33. value_caches.append(value_cache)
  34. return key_caches, value_caches
  35. @pytest.fixture()
  36. def kv_cache_factory():
  37. return create_kv_caches