wrapper.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import sys
  3. from abc import abstractmethod
  4. from contextlib import contextmanager
  5. from types import CodeType
  6. from typing import Callable, List
  7. import torch
  8. import aphrodite.common.envs as envs
  9. class TorchCompileWrapperWithCustomDispacther:
  10. """
  11. A wrapper class for torch.compile, with a custom dispatch logic.
  12. Subclasses should:
  13. 1. Implement the forward method
  14. 2. Implement the dispatch logic in the __call__ method
  15. It can use `self.compiled_codes` to access the compiled bytecode,
  16. and `with self.dispatch_to_code(index):` to dispatch to
  17. the compiled code.
  18. 3. Implement the `__init__` method to determine how to call
  19. `torch.compile` over the forward method.
  20. """
  21. def __init__(self, compiled_callable: Callable):
  22. self.compiled_callable = compiled_callable
  23. self.original_code_object = self.__class__.forward.__code__
  24. self.compiled_codes: List[CodeType] = []
  25. torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
  26. # read the env var to determine whether to use the custom dispatcher
  27. # subclasses can use this to switch between the custom dispatcher
  28. # and the default Dynamo guard mechanism.
  29. self.use_custom_dispatcher: bool = (
  30. envs.APHRODITE_DYNAMO_USE_CUSTOM_DISPATCHER
  31. )
  32. def __call__(self, *args, **kwargs):
  33. """Implement the dispatch logic here, beyond the torch.compile level.
  34. NOTE: this function can have additional arguments beyond the forward
  35. method, for directly dispatching to the compiled code.
  36. """
  37. return self.compiled_callable(*args, **kwargs)
  38. @abstractmethod
  39. def forward(self, *args, **kwargs):
  40. ...
  41. def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
  42. """Hook to save the compiled bytecode for direct execution."""
  43. if old_code is not self.original_code_object:
  44. return
  45. # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
  46. frame = sys._getframe()
  47. while True:
  48. frame = frame.f_back
  49. code_name = frame.f_code.co_name
  50. file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
  51. if code_name == "_compile" and file_name == "convert_frame.py":
  52. break
  53. frame = frame.f_locals["frame"]
  54. assert frame.f_code == old_code
  55. if frame.f_locals["self"] is not self:
  56. return
  57. self.compiled_codes.append(new_code)
  58. @contextmanager
  59. def dispatch_to_code(self, index: int):
  60. """Context manager to dispatch to the compiled code.
  61. Why does this work? Because Dynamo guarantees that the compiled
  62. bytecode has exactly the same arguments, cell variables, and free
  63. variables as the original code. Therefore we can directly switch
  64. the code object in the function and call it.
  65. See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
  66. """ # noqa
  67. self.__class__.forward.__code__ = self.compiled_codes[index]
  68. yield
  69. self.__class__.forward.__code__ = self.original_code_object