test_lazy_outlines.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import sys
  2. from aphrodite import LLM, SamplingParams
  def test_lazy_outlines(sample_regex):
      """Check that generating text never pulls ``outlines`` into ``sys.modules``.

      Two scenarios are run against the ``facebook/opt-125m`` model: plain
      sampling without any guided decoding, and regex-guided decoding with
      the ``lm-format-enforcer`` backend selected. After each scenario the
      test asserts that ``outlines`` has not been imported.

      Args:
          sample_regex: Regex pattern that is embedded in the guided prompt
              and passed as ``guided_regex`` (presumably provided by a pytest
              fixture; it is not defined in this file).
      """

      def _report(results):
          # Echo every prompt next to the first completion returned for it.
          for result in results:
              shown_prompt, shown_text = result.prompt, result.outputs[0].text
              print(f"Prompt: {shown_prompt!r}, Generated text: {shown_text!r}")

      # Engine settings shared by both scenarios.
      engine_kwargs = {
          "model": "facebook/opt-125m",
          "enforce_eager": True,
          "gpu_memory_utilization": 0.3,
      }

      # Scenario 1: ordinary sampling, no guided decoding requested.
      plain_prompts = [
          "Hello, my name is",
          "The president of the United States is",
          "The capital of France is",
          "The future of AI is",
      ]
      params = SamplingParams(temperature=0.8, top_p=0.95)
      engine = LLM(**engine_kwargs)
      results = engine.generate(plain_prompts, params)
      _report(results)
      # outlines must still be absent from the loaded modules
      assert "outlines" not in sys.modules

      # Scenario 2: regex-guided decoding through lm-format-enforcer.
      # The engine name is rebound on purpose so the first engine is no
      # longer referenced once the second one has been constructed.
      engine = LLM(guided_decoding_backend="lm-format-enforcer",
                   **engine_kwargs)
      params = SamplingParams(temperature=0.8, top_p=0.95)
      regex_prompt = (
          f"Give an example IPv4 address with this regex: {sample_regex}")
      results = engine.generate(
          prompts=[regex_prompt, regex_prompt],
          sampling_params=params,
          use_tqdm=True,
          guided_options_request={"guided_regex": sample_regex},
      )
      _report(results)
      # outlines must still be absent from the loaded modules
      assert "outlines" not in sys.modules