generator.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import math
  2. import re
  3. from pathlib import Path
  4. import numpy as np
  5. # From https://en.wikipedia.org/wiki/Paley_construction (construction II for q = 5)
  6. had_12_paley = """
  7. +-++++++++++
  8. --+-+-+-+-+-
  9. +++-++----++
  10. +---+--+-++-
  11. +++++-++----
  12. +-+---+--+-+
  13. ++--+++-++--
  14. +--++---+--+
  15. ++----+++-++
  16. +--+-++---+-
  17. ++++----+++-
  18. +-+--+-++---
  19. """
  20. # From http://neilsloane.com/hadamard/
  21. had_20_will = """
  22. +----+----++--++-++-
  23. -+----+---+++---+-++
  24. --+----+---+++-+-+-+
  25. ---+----+---+++++-+-
  26. ----+----++--++-++-+
  27. -+++++-----+--+++--+
  28. +-+++-+---+-+--+++--
  29. ++-++--+---+-+--+++-
  30. +++-+---+---+-+--+++
  31. ++++-----++--+-+--++
  32. --++-+-++-+-----++++
  33. ---++-+-++-+---+-+++
  34. +---++-+-+--+--++-++
  35. ++---++-+----+-+++-+
  36. -++---++-+----+++++-
  37. -+--+--++-+----+----
  38. +-+-----++-+----+---
  39. -+-+-+---+--+----+--
  40. --+-+++------+----+-
  41. +--+--++------+----+
  42. """
  43. had_28_will = """
  44. +------++----++-+--+-+--++--
  45. -+-----+++-----+-+--+-+--++-
  46. --+-----+++---+-+-+----+--++
  47. ---+-----+++---+-+-+-+--+--+
  48. ----+-----+++---+-+-+++--+--
  49. -----+-----++++--+-+--++--+-
  50. ------++----++-+--+-+--++--+
  51. --++++-+-------++--+++-+--+-
  52. ---++++-+-----+-++--+-+-+--+
  53. +---+++--+----++-++--+-+-+--
  54. ++---++---+----++-++--+-+-+-
  55. +++---+----+----++-++--+-+-+
  56. ++++--------+-+--++-++--+-+-
  57. -++++--------+++--++--+--+-+
  58. -+-++-++--++--+--------++++-
  59. +-+-++--+--++--+--------++++
  60. -+-+-++--+--++--+----+---+++
  61. +-+-+-++--+--+---+---++---++
  62. ++-+-+-++--+------+--+++---+
  63. -++-+-+-++--+------+-++++---
  64. +-++-+---++--+------+-++++--
  65. -++--++-+-++-+++----++------
  66. +-++--++-+-++-+++-----+-----
  67. ++-++---+-+-++-+++-----+----
  68. -++-++-+-+-+-+--+++-----+---
  69. --++-++++-+-+----+++-----+--
  70. +--++-+-++-+-+----+++-----+-
  71. ++--++-+-++-+-+----++------+
  72. """
  73. header = """
  74. /******************************************************************************
  75. * Copyright (c) 2023, Tri Dao.
  76. ******************************************************************************/
  77. // This file is auto-generated. See "generator.py"\n
  78. #pragma once
  79. """
  80. template = """
  81. __device__ __forceinline__ void hadamard_mult_thread_{N}(float x[{N}]) {
  82. float out[{N}];
  83. {code}
  84. #pragma unroll
  85. for (int i = 0; i < {N}; i++) { x[i] = out[i]; }
  86. }
  87. """
  88. def string_to_array(string):
  89. # Convert strings of + and - to bool arrays
  90. string = string.strip().replace('+', '1').replace('-', '-1').split()
  91. return np.stack([np.fromstring(" ".join(string[i]), dtype=np.int32, sep=' ') for i in range(len(string))])
  92. def array_code_gen(arr):
  93. N = arr.shape[0]
  94. assert arr.shape[0] == arr.shape[1]
  95. out = []
  96. for i in range(N):
  97. out.append(f"out[{i}] = " + " ".join([f"{'+' if arr[i, j] == 1 else '-'} x[{j}]" for j in range(N)]) + ";")
  98. return template.replace("{N}", str(N)).replace("{code}", '\n '.join(out))
  99. def main():
  100. output_dir = Path(__file__).parent / "fast_hadamard_transform_special.h"
  101. output_dir.write_text(header + array_code_gen(string_to_array(had_12_paley)) + array_code_gen(string_to_array(had_20_will)) + array_code_gen(string_to_array(had_28_will)))
  102. if __name__ == '__main__':
  103. main()