import operator

import torch
import torch.fx as fx


def fix_functionalization(graph: fx.Graph):
    """
    Rewrite the graph module to replace the pattern involving
    torch._higher_order_ops.auto_functionalize.auto_functionalized
    with a direct call to the inplace custom op.

    # TODO: check if PyTorch nightly has fixed this issue
    """

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)

    nodes_to_remove = []

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if (
            node.op == "call_function"
            and node.target
            == torch._higher_order_ops.auto_functionalize.auto_functionalized
        ):  # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding

                # Now, collect the arguments
                kwargs = node.kwargs

                query = kwargs["query"]
                mm_node = query.args[0].args[0]

                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rotary_embedding.default, kwargs=kwargs
                    )

                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (
                                getitem_user.op == "call_function"
                                and getitem_user.target
                                == torch.ops.aten.slice_scatter.default
                            ):
                                # Replace the uses of slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                input = kwargs["input"]
                residual = kwargs["residual"]

                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs
                    )

                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # Remove the getitem node
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm

                kwargs = node.kwargs

                input = kwargs["input"]
                out = kwargs["out"]
                weight = kwargs["weight"]
                epsilon = kwargs["epsilon"]
                # Create a new call to torch.ops._C.rms_norm.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm.default,
                        args=(out, input, weight, epsilon),
                    )

                replace_node = out

                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul

                kwargs = node.kwargs

                input = kwargs["input"]
                out = kwargs["out"]

                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )
                replace_node = out

                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)


def aphrodite_backend(graph, example_inputs):
    from torch._inductor import config

    current_config = config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx

    current_config["post_grad_custom_post_pass"] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)
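

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the pass itself): torch.compile
# accepts any callable with the (graph, example_inputs) signature as a custom
# backend, so the backend above would typically be wired in as shown below.
# `MyModel` and its inputs are hypothetical placeholders; the real model must
# already emit the torch.ops._C custom ops for the rewrites to apply.
#
#     model = MyModel().cuda()
#     compiled_model = torch.compile(model, backend=aphrodite_backend)
#     output = compiled_model(input_ids, positions)
# ---------------------------------------------------------------------------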