# test_multi_adapter_inference.py
  1. from aphrodite import AphroditeEngine, EngineArgs, SamplingParams
  2. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  3. MODEL_PATH = "bigscience/bloomz-560m"
  4. pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
  5. pa_path2 = 'swapnilbp/angry_tweet_ptune'
  6. def do_sample(engine):
  7. prompts = [
  8. ("Tweet text: I have complaints! Label: ",
  9. SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
  10. PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
  11. ("Tweet text: I have no problems Label: ",
  12. SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
  13. PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
  14. ("Tweet text: I have complaints! Label: ",
  15. SamplingParams(temperature=0.0, max_tokens=3), None),
  16. ("Tweet text: I have no problems Label: ",
  17. SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
  18. PromptAdapterRequest("complain", 3, pa_path, 8)),
  19. ]
  20. request_id = 0
  21. results = set()
  22. while prompts or engine.has_unfinished_requests():
  23. if prompts:
  24. prompt, sampling_params, pa_request = prompts.pop(0)
  25. engine.add_request(str(request_id),
  26. prompt,
  27. sampling_params,
  28. prompt_adapter_request=pa_request)
  29. request_id += 1
  30. request_outputs = engine.step()
  31. for request_output in request_outputs:
  32. if request_output.finished:
  33. results.add(request_output.outputs[0].text)
  34. return results
  35. def test_multi_prompt_adapters():
  36. engine_args = EngineArgs(model=MODEL_PATH,
  37. max_prompt_adapters=3,
  38. enable_prompt_adapter=True,
  39. max_prompt_adapter_token=8)
  40. engine = AphroditeEngine.from_engine_args(engine_args)
  41. expected_output = {
  42. ' quot;I', 'hate speech', 'no complaint', 'not hate speech'
  43. }
  44. assert do_sample(engine) == expected_output