import operator

import torch
import torch.fx as fx


def fix_functionalization(graph: fx.Graph):
    """
    Rewrite the graph in place to replace the pattern involving
    torch._higher_order_ops.auto_functionalize.auto_functionalized
    with a direct call to the in-place custom op.

    # TODO: check if PyTorch nightly has fixed this issue
    """
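
    # Background: auto_functionalized wraps an in-place custom op in a
    # functional form: instead of mutating its arguments, the wrapped node
    # returns a tuple, and operator.getitem nodes extract the "updated"
    # tensors. This pass calls the in-place op directly and reroutes each
    # getitem user back to the buffer the op mutates, eliminating the
    # copies functionalization would otherwise introduce.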

    # debug code, if we want to see the graph before the transformation
    # with open("before.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)

    nodes_to_remove = []

    for node in graph.nodes:
        # Identify the auto_functionalized node
        if (
            node.op == "call_function"
            and node.target
            == torch._higher_order_ops.auto_functionalize.auto_functionalized
        ):  # noqa
            if node.args[0] == torch.ops._C.rotary_embedding.default:
                # manual replace for rotary_embedding

                # Now, collect the arguments
                kwargs = node.kwargs

                query = kwargs["query"]
                # assumed graph shape: `query` is a slice of a matmul
                # output; walking back two args reaches that matmul node
                mm_node = query.args[0].args[0]

                # Create a new call to torch.ops._C.rotary_embedding.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rotary_embedding.default, kwargs=kwargs
                    )

                # Remove the auto_functionalized node
                # Since the node may have outputs, we need to handle its users
                # Replace uses of the outputs (getitem nodes) with mm_node
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # Remove the getitem node
                        for getitem_user in list(user.users):
                            if (
                                getitem_user.op == "call_function"
                                and getitem_user.target
                                == torch.ops.aten.slice_scatter.default
                            ):
                                # Replace the uses of the slice_scatter node
                                # with mm_node
                                getitem_user.replace_all_uses_with(mm_node)
                                nodes_to_remove.append(getitem_user)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
                # manual replace for fused_add_rms_norm
                # this is the most effective optimization for llama;
                # failing to do this will result in many unnecessary copies

                kwargs = node.kwargs

                input = kwargs["input"]
                residual = kwargs["residual"]

                # Create a new call to torch.ops._C.fused_add_rms_norm.default
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs
                    )

                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        # getitem 1 is the mutated input, getitem 2 the
                        # mutated residual; route users back to the originals
                        if user.args[1] == 1:
                            replace_node = input
                        elif user.args[1] == 2:
                            replace_node = residual
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.rms_norm.default:
                # manual replace for rms_norm

                kwargs = node.kwargs

                input = kwargs["input"]
                out = kwargs["out"]
                weight = kwargs["weight"]
                epsilon = kwargs["epsilon"]
                # Create a new call to torch.ops._C.rms_norm.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.rms_norm.default,
                        args=(out, input, weight, epsilon),
                    )

                replace_node = out
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

            elif node.args[0] == torch.ops._C.silu_and_mul.default:
                # manual replace for silu_and_mul

                kwargs = node.kwargs

                input = kwargs["input"]
                out = kwargs["out"]

                # Create a new call to torch.ops._C.silu_and_mul.default
                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                with graph.inserting_before(node):
                    # just insert the call to the custom op
                    # NOTE: don't run dead code elimination,
                    # otherwise this op will be removed
                    graph.call_function(
                        torch.ops._C.silu_and_mul.default,
                        args=(out, input),
                    )
                replace_node = out
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):  # noqa
                        user.replace_all_uses_with(replace_node)
                        nodes_to_remove.append(user)
                nodes_to_remove.append(node)

    # Remove the nodes all at once
    for node in nodes_to_remove:
        graph.erase_node(node)

    # debug code, if we want to see the graph after the transformation
    # with open("after.py", "w") as f:
    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)


def aphrodite_backend(graph, example_inputs):
    # deferred imports: only pull in inductor machinery when compiling
    from torch._inductor import config
    from torch._inductor.compile_fx import compile_fx

    current_config = config.shallow_copy_dict()
    # run fix_functionalization as a custom post-grad pass inside inductor
    current_config["post_grad_custom_post_pass"] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)
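

# Example usage (sketch): pass this backend to torch.compile. This assumes
# the torch.ops._C custom ops referenced above are registered (e.g. by
# importing the engine's C extension); `model` and its inputs are
# hypothetical placeholders.
#
#     compiled_model = torch.compile(model, backend=aphrodite_backend)
#     out = compiled_model(input_ids, positions)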