test_wrapper.py

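"""Test TorchCompileWrapperWithCustomDispacther.

The wrapper compiles ``forward`` twice (once without a cache tensor, once
with one) and, after both code objects are captured, dispatches directly to
the appropriate compiled code instead of going through Dynamo again.
Separate wrapper instances must keep independent compiled codes.
"""
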
from typing import Optional

import torch

from aphrodite.compilation.wrapper import (
    TorchCompileWrapperWithCustomDispacther)


class MyMod(torch.nn.Module):

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        if cache is not None:
            return x + cache
        return x * 2


class MyWrapper(TorchCompileWrapperWithCustomDispacther):

    def __init__(self, model):
        self.model = model
        compiled_callable = torch.compile(self.forward, backend="eager")
        super().__init__(compiled_callable)

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # This is the function to be compiled.
        return self.model(x, cache)

    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        # Let torch.compile compile twice: once for the call without a cache
        # and once for the call with a cache. After both code objects have
        # been captured, dispatch to the matching one directly.
        if len(self.compiled_codes) == 2:
            dispatch_id = 0 if cache is None else 1
            with self.dispatch_to_code(dispatch_id):
                return self.forward(x, cache)
        else:
            return self.compiled_callable(x, cache)


def test_torch_compile_wrapper():
    mod = MyMod()
    wrappers = []
    for _ in range(3):
        torch._dynamo.reset()
        wrapper = MyWrapper(mod)
        wrappers.append(wrapper)
        x = torch.tensor([1])
        wrapper(x, None)  # profile run, compile
        # create a cache tensor
        cache = torch.tensor([2])
        wrapper(x, cache)  # warm up with cache, recompile
        # for new input, dispatch to the compiled code directly
        new_x = torch.tensor([3])
        assert (wrapper(new_x, None).item() == 6
                )  # dispatch to the first compiled code
        assert (wrapper(new_x, cache).item() == 5
                )  # dispatch to the second compiled code
    for wrapper in wrappers:
        # make sure they have independent compiled codes
        assert len(wrapper.compiled_codes) == 2
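

if __name__ == "__main__":
    # Convenience entry point so the test can also be run directly
    # (e.g. `python test_wrapper.py`) instead of only through pytest.
    test_torch_compile_wrapper()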