generator.py 871 B

123456789101112131415161718192021222324252627
  1. DTYPES = ["fp16", "bf16", "fp32"]
  2. DTYPE_MAP = {
  3. "fp16": "nv_half",
  4. "bf16": "nv_bfloat16",
  5. "fp32": "float",
  6. }
  7. TEMPLATE = """
  8. #include "bgmv_config.h"
  9. #include "bgmv_impl.cuh"
  10. FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
  11. """.lstrip()
  12. for input_dtype in DTYPES:
  13. for output_dtype in DTYPES:
  14. for weight_dtype in DTYPES:
  15. if weight_dtype == "fp32":
  16. # FP32 weights are not supported.
  17. continue
  18. kernel_definition = TEMPLATE.format(
  19. input_dtype=DTYPE_MAP[input_dtype],
  20. output_dtype=DTYPE_MAP[output_dtype],
  21. weight_dtype=DTYPE_MAP[weight_dtype])
  22. filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
  23. with open(filename, "w") as f:
  24. f.write(kernel_definition)