backends.py

import operator

import torch
import torch.fx as fx


def fix_functionalization(graph: fx.Graph):
    """
    Rewrite the graph module to replace the pattern involving
    torch._higher_order_ops.auto_functionalize.auto_functionalized
    with a direct call to the inplace custom op.

    # TODO: check if PyTorch nightly has fixed this issue
    """

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
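
    # Illustrative sketch of the rewrite (variable names here are
    # hypothetical, not taken from a real graph). Before:
    #     tup = auto_functionalized(torch.ops._C.rms_norm.default,
    #                               out=buf, input=x, weight=w, epsilon=eps)
    #     y = operator.getitem(tup, 1)
    # After:
    #     torch.ops._C.rms_norm.default(buf, x, w, eps)  # mutates buf in place
    # with every use of `y` redirected to `buf`.
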
    nodes_to_remove = []

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if (
            node.op == "call_function"
            and node.target
            == torch._higher_order_ops.auto_functionalize.auto_functionalized
        ):  # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding

                # Now, collect the arguments
                kwargs = node.kwargs

                query = kwargs["query"]
                mm_node = query.args[0].args[0]

                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rotary_embedding.default, kwargs=kwargs
                    )

                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (
                                getitem_user.op == "call_function"
                                and getitem_user.target
                                == torch.ops.aten.slice_scatter.default
                            ):
                                # Replace the uses of slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                input = kwargs["input"]
                residual = kwargs["residual"]

                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs
                    )

                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # Remove the getitem node: output 1 corresponds to
                        # the mutated `input`, output 2 to the mutated
                        # `residual`
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm

                kwargs = node.kwargs

                input = kwargs["input"]
                out = kwargs["out"]
                weight = kwargs["weight"]
                epsilon = kwargs["epsilon"]

                # Create a new call to torch.ops._C.rms_norm.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm.default,
                        args=(out, input, weight, epsilon),
                    )

                replace_node = out
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul

                kwargs = node.kwargs

                input = kwargs["input"]
                out = kwargs["out"]

                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )

                replace_node = out
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)
    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
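

# `fix_functionalization` is wired in below as an Inductor post-grad pass:
# the `post_grad_custom_post_pass` config hook runs it on the FX graph after
# functionalization, so the in-place custom ops are restored right before
# Inductor lowers the graph.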
def aphrodite_backend(graph, example_inputs):
    from torch._inductor import config
    from torch._inductor.compile_fx import compile_fx

    current_config = config.shallow_copy_dict()
    current_config["post_grad_custom_post_pass"] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)
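
# Example usage (an illustrative sketch; `model` and `inputs` are
# placeholders, not names defined in this module):
#
#     compiled_model = torch.compile(model, backend=aphrodite_backend)
#     output = compiled_model(*inputs)
#
# torch.compile accepts any callable with the (fx.GraphModule, example_inputs)
# backend signature, which is exactly what aphrodite_backend provides.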