hipify.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #!/usr/bin/env python3
  2. #
  3. # A command line tool for running pytorch's hipify preprocessor on CUDA
  4. # source files.
  5. #
  6. # See https://github.com/ROCm/hipify_torch
  7. # and <torch install dir>/utils/hipify/hipify_python.py
  8. #
  9. import argparse
  10. import os
  11. import shutil
  12. from torch.utils.hipify.hipify_python import hipify
  13. if __name__ == '__main__':
  14. parser = argparse.ArgumentParser()
  15. # Project directory where all the source + include files live.
  16. parser.add_argument(
  17. "-p",
  18. "--project_dir",
  19. help="The project directory.",
  20. )
  21. # Directory where hipified files are written.
  22. parser.add_argument(
  23. "-o",
  24. "--output_dir",
  25. help="The output directory.",
  26. )
  27. # Source files to convert.
  28. parser.add_argument("sources",
  29. help="Source files to hipify.",
  30. nargs="*",
  31. default=[])
  32. args = parser.parse_args()
  33. # Limit include scope to project_dir only
  34. includes = [os.path.join(args.project_dir, '*')]
  35. # Get absolute path for all source files.
  36. extra_files = [os.path.abspath(s) for s in args.sources]
  37. # Copy sources from project directory to output directory.
  38. # The directory might already exist to hold object files so we ignore that.
  39. shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
  40. hipify_result = hipify(project_directory=args.project_dir,
  41. output_directory=args.output_dir,
  42. header_include_dirs=[],
  43. includes=includes,
  44. extra_files=extra_files,
  45. show_detailed=True,
  46. is_pytorch_extension=True,
  47. hipify_extra_files_only=True)
  48. hipified_sources = []
  49. for source in args.sources:
  50. s_abs = os.path.abspath(source)
  51. hipified_s_abs = (hipify_result[s_abs].hipified_path if
  52. (s_abs in hipify_result
  53. and hipify_result[s_abs].hipified_path is not None)
  54. else s_abs)
  55. hipified_sources.append(hipified_s_abs)
  56. assert (len(hipified_sources) == len(args.sources))
  57. # Print hipified source files.
  58. print("\n".join(hipified_sources))