conftest.py

import os
from typing import List, Optional, Tuple

import pytest
import torch
from transformers import AutoModelForCausalLM

from aphrodite import LLM, SamplingParams
from aphrodite.transformers_utils.tokenizer import get_tokenizer

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]


def _read_prompts(filename: str) -> List[str]:
    with open(filename, "r") as f:
        prompts = f.readlines()
    return prompts


@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}
class HfRunner:
    """Reference runner that generates with HuggingFace Transformers."""

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens)
        # Greedy decoding returns a single sequence per prompt, so unwrap
        # the batch dimension.
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            outputs[i] = (output_ids[0], output_str[0])
        return outputs

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            # HF pads beam candidates to a common length; drop the padding
            # tokens before returning the token IDs.
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # Recompute per-step log-probabilities by projecting the last
            # hidden state through the output embedding (LM head).
            seq_logprobs = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs
@pytest.fixture
def hf_runner():
    return HfRunner
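

# Note: the fixture returns the HfRunner class itself rather than an instance,
# so each test constructs the model and can free GPU memory when done, e.g.
# `hf_model = hf_runner("facebook/opt-125m")` (the model name here is purely
# illustrative, not part of this file).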
class AphroditeRunner:
    """Runner that generates with the Aphrodite engine under test."""

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=0,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids = []
            req_sample_output_strs = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                # Prepend the prompt so results are directly comparable with
                # HfRunner, whose outputs include the prompt tokens.
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
                outputs.append((output_ids, output_str, output_logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params)
        # n defaults to 1, so keep only the single sample per prompt.
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs
@pytest.fixture
def aphrodite_runner():
    return AphroditeRunner
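

# Illustrative sketch of how a test module might combine the two fixtures to
# check Aphrodite against the HuggingFace reference. The model name, token
# budget, and test name below are assumptions for illustration only:
#
#     @pytest.mark.parametrize("model", ["facebook/opt-125m"])
#     def test_greedy_matches_hf(hf_runner, aphrodite_runner, example_prompts,
#                                model):
#         hf_model = hf_runner(model, dtype="half")
#         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens=32)
#         del hf_model
#
#         aphrodite_model = aphrodite_runner(model, dtype="half")
#         aphrodite_outputs = aphrodite_model.generate_greedy(example_prompts,
#                                                             max_tokens=32)
#         del aphrodite_model
#
#         for (hf_ids, hf_str), (ap_ids, ap_str) in zip(hf_outputs,
#                                                       aphrodite_outputs):
#             assert hf_str == ap_str, f"HF: {hf_str!r} vs Aphrodite: {ap_str!r}"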