# hadamard_example.py: usage examples for the Hadamard helpers in
# aphrodite.quantization.quip_utils.

import math

import torch

from aphrodite.quantization.quip_utils import (
    get_hadK,
    hadamard_transform,
    matmul_hadU_cuda,
    matmul_hadUt_cuda,
)


  8. # Example 1: Basic Hadamard Transform
  9. def example_hadamard():
  10. # Create a random tensor
  11. x = torch.randn(4, 4) # Must be power of 2 dimensions
  12. # Apply Hadamard transform
  13. transformed = hadamard_transform(x, scale=1.0)
  14. # Inverse transform (using the same function with appropriate scale)
  15. inverse = hadamard_transform(transformed, scale=1.0)
  16. print("Original shape:", x.shape)
  17. print("Transformed shape:", transformed.shape)
  18. print("Reconstruction error:", torch.norm(x - inverse))
  19. # Example 2: Using Hadamard-based matrix multiplication
  20. def example_hadamard_matmul():
  21. # Create input tensor
  22. batch_size = 2
  23. n = 16 # dimension size (power of 2)
  24. x = torch.randn(batch_size, n)
  25. # Get Hadamard matrices and parameters
  26. hadK, K, padded_n = get_hadK(n, use_rand=True)
  27. # Forward transform
  28. transformed = matmul_hadU_cuda(x, hadK, K, padded_n)
  29. # Backward transform
  30. reconstructed = matmul_hadUt_cuda(transformed, hadK, K, padded_n)
  31. print("Input shape:", x.shape)
  32. print("Transformed shape:", transformed.shape)
  33. print("Reconstruction error:", torch.norm(x - reconstructed))
  34. # Example 3: Working with non-power-of-2 dimensions
  35. def example_non_power_2():
  36. # Create tensor with non-power-of-2 dimension
  37. x = torch.randn(3, 10)
  38. # Get appropriate Hadamard matrices and padding
  39. hadK, K, padded_n = get_hadK(x.shape[-1], use_rand=True)
  40. # Forward transform (will handle padding automatically)
  41. transformed = matmul_hadU_cuda(x, hadK, K, padded_n)
  42. # Backward transform
  43. reconstructed = matmul_hadUt_cuda(transformed, hadK, K, padded_n)
  44. # Remove padding from result
  45. reconstructed = reconstructed[..., :x.shape[-1]]
  46. print("Original shape:", x.shape)
  47. print("Padded transformed shape:", transformed.shape)
  48. print("Final reconstructed shape:", reconstructed.shape)
  49. print("Reconstruction error:", torch.norm(x - reconstructed))
  50. if __name__ == "__main__":
  51. print("Example 1: Basic Hadamard Transform")
  52. example_hadamard()
  53. print("\nExample 2: Hadamard Matrix Multiplication")
  54. example_hadamard_matmul()
  55. print("\nExample 3: Non-power-of-2 Dimensions")
  56. example_non_power_2()