perplexity.py

import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from aphrodite import LLM, SamplingParams

# Load the wikitext2 dataset.
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')
# Get the first 2000 elements from the 'train' split.
prompts = dataset['train']['text'][:2000]
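# NOTE: the raw wikitext train split contains many empty strings and
# section headers; this example passes them to the model unchanged.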

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
# Create a tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Tokenize the prompts, truncating any prompt longer than 4096 tokens.
tokenized_prompts = [tokenizer.encode(prompt, truncation=True, max_length=4096)
                     for prompt in prompts]
# Detokenize the prompts.
detokenized_prompts = [tokenizer.decode(tokens) for tokens in tokenized_prompts]
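# The token IDs are decoded back into strings so that llm.generate() below
# can be called with plain text prompts; the encode/decode round trip here
# serves to cap each prompt at max_length tokens.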

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0.0,
    ignore_eos=True,
    max_tokens=10,
    skip_special_tokens=False,
    spaces_between_special_tokens=False,
    logprobs=1,
    prompt_logprobs=1,
)
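# prompt_logprobs=1 asks the engine to return log-probabilities for the
# prompt tokens, which is what the perplexity computation below consumes;
# the generated completions themselves (max_tokens=10) are not used.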

# Create an LLM.
llm = LLM(model=model_id)
# Generate texts from the detokenized prompts.
outputs = llm.generate(detokenized_prompts, sampling_params)

# Calculate the perplexity.
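# Perplexity is exp(-(1/N) * sum_i log p(x_i | x_<i)): the exponential of
# the mean negative log-likelihood over all scored prompt tokens.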
all_logprobs = []
for output in outputs:
    # The first prompt token has no logprob entry (it has no preceding
    # context), so it is skipped.
    all_logprobs.extend([next(iter(lp.values())) for lp in output.prompt_logprobs[1:]])
all_logprobs = np.array([lp.logprob for lp in all_logprobs])

# NOTE: we need to divide by 2 to match the perplexity results for the same
# model on llama.cpp. I'm unsure whether this approach to perplexity
# measurement is correct.
perplexity = np.exp(-all_logprobs.mean()) / 2
print(f"Perplexity: {perplexity}")