encoder_decoder_inference.py
  1. """Prompting encoder-decoder models, specifically the BART model."""
  2. from aphrodite import LLM, SamplingParams
  3. from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
  4. TokensPrompt, zip_enc_dec_prompts)
  5. dtype = "float"
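# Note: "float" is shorthand for float32 here (assuming the usual
# aphrodite/vLLM dtype aliases).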

# Create a BART encoder/decoder model instance
llm = LLM(
    model="facebook/bart-large-cnn",
    dtype=dtype,
)

# Get the BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()

# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
    prompt="The capital of France is"))

# - Pass a single prompt to the encoder/decoder model
#   (implicitly the encoder input prompt);
#   the decoder input prompt is assumed to be None
single_text_prompt_raw = text_prompt_raw  # Pass a string directly
single_text_prompt = text_prompt  # Pass a TextPrompt
single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt
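#   (When the decoder prompt is omitted, the engine supplies a default
#   decoder input for the model; for BART this is typically just the
#   decoder start token. Treat the exact fallback as an implementation
#   detail of the engine.)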

# - Pass explicit encoder and decoder input prompts within one data structure.
#   Encoder and decoder prompts can both independently be text or tokens,
#   with no requirement that they be the same prompt type. Some example
#   prompt-type combinations are shown below; note that these are not
#   exhaustive.
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
    # Pass the encoder prompt string directly, and
    # pass decoder prompt tokens
    encoder_prompt=single_text_prompt_raw,
    decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
    # Pass a TextPrompt to the encoder, and
    # pass the decoder prompt string directly
    encoder_prompt=single_text_prompt,
    decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
    # Pass encoder prompt tokens directly, and
    # pass a TextPrompt to the decoder
    encoder_prompt=single_tokens_prompt,
    decoder_prompt=single_text_prompt,
)
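# The remaining prompt-type combinations are equally valid; for instance,
# tokens can be used for both the encoder and the decoder (shown here only
# for illustration and intentionally left out of the prompts list below):
enc_dec_prompt4 = ExplicitEncoderDecoderPrompt(
    encoder_prompt=single_tokens_prompt,
    decoder_prompt=single_tokens_prompt,
)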

# - Finally, here's a useful helper function for zipping encoder and
#   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
#   instances
zipped_prompt_list = zip_enc_dec_prompts(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])
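# For reference, the helper behaves roughly like zipping the two lists into
# ExplicitEncoderDecoderPrompt instances by hand (a sketch inferred from the
# comment above, not necessarily the library's exact implementation):
#
#   zipped_prompt_list = [
#       ExplicitEncoderDecoderPrompt(encoder_prompt=enc, decoder_prompt=dec)
#       for enc, dec in zip(encoder_prompts, decoder_prompts)
#   ]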

# - Let's put all of the above example prompts together into one list
#   which we will pass to the encoder/decoder LLM.
prompts = [
    single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
    enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
] + zipped_prompt_list

print(prompts)

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
)
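# Note: temperature=0 requests greedy decoding (following the usual
# aphrodite/vLLM sampling semantics), so the generations below should be
# deterministic for a given model.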

# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")