conftest.py

import os
from typing import List, Optional, Tuple

import pytest
import torch
from transformers import AutoModelForCausalLM

from aphrodite import LLM, SamplingParams
from aphrodite.transformers_utils.tokenizer import get_tokenizer

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]


def _read_prompts(filename: str) -> List[str]:
    # NOTE: only the first line of the file is read as a prompt.
    prompts = []
    with open(filename, "r") as f:
        prompt = f.readline()
        prompts.append(prompt)
    return prompts


@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}


class HfRunner:
    """Reference runner that generates with a HuggingFace model on GPU."""

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            outputs[i] = (output_ids[0], output_str[0])
        return outputs

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            seq_logprobs = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                # Recompute logits from the final hidden states and the
                # output embedding matrix, then convert to log-probabilities.
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs


@pytest.fixture
def hf_runner():
    return HfRunner


class AphroditeRunner:
    """Runner that generates with the Aphrodite LLM engine."""

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=0,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids = []
            req_sample_output_strs = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs


@pytest.fixture
def aphrodite_runner():
    return AphroditeRunner
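
# Illustrative usage sketch (an assumption, not part of the original file):
# a test module would typically request both runner fixtures together with
# `example_prompts`, instantiate each runner for the same model, and compare
# greedy completions. The model name and token budget below are hypothetical.
#
# def test_greedy_outputs_match(hf_runner, aphrodite_runner, example_prompts):
#     model_name = "facebook/opt-125m"  # hypothetical small test model
#     hf_model = hf_runner(model_name, dtype="half")
#     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens=128)
#     del hf_model
#
#     aphrodite_model = aphrodite_runner(model_name, dtype="half")
#     aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
#                                                         max_tokens=128)
#     del aphrodite_model
#
#     for (hf_ids, hf_str), (apr_ids, apr_str) in zip(hf_outputs,
#                                                     aphrodite_outputs):
#         assert hf_str == apr_str, (
#             f"HF: {hf_str!r}\nAphrodite: {apr_str!r}")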