test_compilation.py

import glob
import os
import runpy
import tempfile

import depyf
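
# depyf decompiles and dumps the source code that torch.compile/Dynamo
# generates, so this test can inspect the compilation artifacts on disk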

# disable the custom dispatcher and let Dynamo take over
# all of the control flow
os.environ['APHRODITE_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
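
# prepare_debug(dir) makes depyf write what Dynamo produces (e.g. the
# __transformed_code_*.py and __compiled_fn_* files checked below) into
# `dir` while the block is active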
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
    cur_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(cur_dir)
    root_dir = os.path.dirname(parent_dir)
    example_file = os.path.join(root_dir, "examples",
                                "offline_inference",
                                "tpu_inference.py")
    runpy.run_path(example_file)

compiled_code = sorted(
    glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
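# each __transformed_code_*.py file written by depyf corresponds to one
# Dynamo compilation, so counting the files counts (re)compilations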

# we should only trigger Dynamo compilation three times:
# once for the profiling phase, without KV cache
# once for the prefill phase, with symbolic shapes
# once for the decode phase, with symbolic shapes
# later calls should not trigger Dynamo compilation again
# (NOTE: they might still trigger XLA compilation)

# check that we have exactly three pieces of compiled code;
# this is the assumption we rely on when using the custom dispatcher
assert len(compiled_code) == 3

# check that all the compilations are as expected
compiled_fn = sorted(
    glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
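# the __compiled_fn_* "Captured" files hold the graphs Dynamo captured
# for each compilation; grepping their source shows which ops were traced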

# the first compilation is the profiling phase;
# it should not have any kv cache
with open(compiled_fn[0]) as f:
    content = f.read()
    assert "kv_caches" not in content

# the second compilation is the prefill phase;
# it should have kv cache and the flash_attention op
with open(compiled_fn[1]) as f:
    content = f.read()
    assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content

# the third compilation is the decode phase;
# it should have kv cache and the paged_attention op
with open(compiled_fn[2]) as f:
    content = f.read()
    assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content
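
# not part of the original test: the three checks above share a pattern,
# so a small helper could express them declaratively. this is only a
# sketch; `check_phase_dump` and its parameter names are made up.
def check_phase_dump(path, expected=(), forbidden=()):
    """Assert a dumped source file contains/lacks the given substrings."""
    with open(path) as f:
        content = f.read()
    for needle in expected:
        assert needle in content, f"expected {needle!r} in {path}"
    for needle in forbidden:
        assert needle not in content, f"did not expect {needle!r} in {path}"
# e.g. the prefill check above could then become:
# check_phase_dump(compiled_fn[1],
#                  expected=("kv_caches", "torch.ops.xla.flash_attention"))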