generator.py 977 B

12345678910111213141516171819202122232425262728
  1. DTYPES = ["fp16", "bf16", "fp32"]
  2. DTYPE_MAP = {
  3. "fp16": "nv_half",
  4. "bf16": "nv_bfloat16",
  5. "fp32": "float",
  6. }
  7. TEMPLATE = """
  8. #include "bgmv_config.h"
  9. #include "bgmv_impl.cuh"
  10. FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
  11. FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
  12. """.lstrip() # noqa: E501
  13. for input_dtype in DTYPES:
  14. for output_dtype in DTYPES:
  15. for weight_dtype in DTYPES:
  16. if weight_dtype == "fp32":
  17. # FP32 weights are not supported.
  18. continue
  19. kernel_definition = TEMPLATE.format(
  20. input_dtype=DTYPE_MAP[input_dtype],
  21. output_dtype=DTYPE_MAP[output_dtype],
  22. weight_dtype=DTYPE_MAP[weight_dtype])
  23. filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
  24. with open(filename, "w") as f:
  25. f.write(kernel_definition)