123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- import os
- import sys
- from abc import abstractmethod
- from contextlib import contextmanager
- from types import CodeType
- from typing import Callable, List
- import torch
- import aphrodite.common.envs as envs
class TorchCompileWrapperWithCustomDispacther:
    """Wrapper around a ``torch.compile``-d callable with custom dispatch.

    Subclasses should:

    1. Implement the ``forward`` method.
    2. Implement the dispatch logic in ``__call__``. It can read
       ``self.compiled_codes`` to access the compiled bytecode and use
       ``with self.dispatch_to_code(index):`` to dispatch to it.
    3. Implement ``__init__`` to decide how ``torch.compile`` is applied
       to ``forward``, and pass the result to this class's ``__init__``.
    """

    def __init__(self, compiled_callable: Callable):
        """Store the compiled callable and start recording its bytecode.

        Args:
            compiled_callable: the callable invoked by the default
                ``__call__`` implementation.
        """
        self.compiled_callable = compiled_callable
        # The code object is read from the class-level function, which is
        # also the object `dispatch_to_code` patches and restores.
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = (
            envs.APHRODITE_DYNAMO_USE_CUSTOM_DISPATCHER
        )

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.

        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.

        The default implementation forwards everything to the compiled
        callable and returns its result.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Computation to be compiled; must be provided by subclasses."""
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        ``new_code`` is appended to ``self.compiled_codes`` only when
        ``old_code`` is this class's original ``forward`` code object and
        the frame being compiled has this very instance as its ``self``.

        Args:
            old_code: the code object that was compiled.
            new_code: the code object produced for ``old_code``.
        """
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the call stack to the `_compile` function defined in
        # `convert_frame.py`; its local variable `frame` is the frame
        # being compiled.
        frame = sys._getframe()
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # Several instances share the same `forward` code object; only
        # record bytecode compiled for a call made on this instance.
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        The original code object is restored when the ``with`` block exits,
        including when the block raises, so a failure inside the block does
        not leave the class permanently patched.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.

        Args:
            index: position in ``self.compiled_codes`` of the code object
                to install.

        Raises:
            IndexError: if ``index`` is out of range; nothing is patched.
        """ # noqa
        # NOTE: the swap is done on the class-level function, so it is
        # visible to every instance of this class while the block runs.
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            self.__class__.forward.__code__ = self.original_code_object
|