# conftest.py — shared pytest fixtures for Aphrodite correctness tests.
  1. from typing import List, Optional, Tuple
  2. import pytest
  3. import torch
  4. from transformers import AutoModelForCausalLM
  5. from aphrodite import LLM, SamplingParams
  6. from aphrodite.transformers_utils.tokenizer import get_tokenizer
# Fixed prompts served by the `example_prompts` fixture below. They are long,
# open-ended requests so that generation produces non-trivial continuations.
_TEST_PROMPTS = [
    # pylint: disable=line-too-long
    "Develop a detailed method for integrating a blockchain-based distributed ledger system into a pre-existing finance management application. The focus should be on ensuring security, transparency, and real-time updates of transactions.",
    "Design an AI-powered predictive analytics engine capable of identifying trends and patterns from unstructured data sets. The engine should be adaptable to different industry requirements such as healthcare, finance, and marketing.",
    "Construct a comprehensive model for a multi-cloud architecture that can smoothly transition between different cloud platforms (AWS, Google Cloud, Azure) without any interruption in service or loss of data.",
    "Propose a strategy for integrating Quantum Computing capabilities into existing high-performance computing (HPC) systems. The approach should consider potential challenges and solutions of Quantum-HPC integration.",
    "Create a robust cybersecurity framework for an Internet of Things (IoT) ecosystem. The framework should be capable of detecting, preventing, and mitigating potential security breaches.",
    "Develop a scalable high-frequency trading algorithm that uses machine learning to predict and respond to microtrends in financial markets. The algorithm should be capable of processing real-time data and executing trades within milliseconds.",
    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]
  17. @pytest.fixture
  18. def example_prompts() -> List[str]:
  19. return _TEST_PROMPTS
# Maps the string dtype names accepted by the runner constructors to the
# corresponding torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}
  25. class HfRunner:
  26. def __init__(
  27. self,
  28. model_name: str,
  29. tokenizer_name: Optional[str] = None,
  30. dtype: str = "half",
  31. ) -> None:
  32. assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
  33. torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
  34. self.model = AutoModelForCausalLM.from_pretrained(
  35. model_name,
  36. torch_dtype=torch_dtype,
  37. trust_remote_code=True,
  38. ).cuda()
  39. if tokenizer_name is None:
  40. tokenizer_name = model_name
  41. self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
  42. def generate(
  43. self,
  44. prompts: List[str],
  45. **kwargs,
  46. ) -> List[Tuple[List[int], str]]:
  47. outputs: List[Tuple[List[int], str]] = []
  48. for prompt in prompts:
  49. input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
  50. output_ids = self.model.generate(
  51. input_ids.cuda(),
  52. use_cache=True,
  53. **kwargs,
  54. )
  55. output_str = self.tokenizer.batch_decode(
  56. output_ids,
  57. skip_special_tokens=True,
  58. clean_up_tokenization_spaces=False,
  59. )
  60. output_ids = output_ids.cpu().tolist()
  61. outputs.append((output_ids, output_str))
  62. return outputs
  63. def generate_greedy(
  64. self,
  65. prompts: List[str],
  66. max_tokens: int,
  67. ) -> List[Tuple[List[int], str]]:
  68. outputs = self.generate(prompts,
  69. do_sample=False,
  70. max_new_tokens=max_tokens)
  71. for i in range(len(outputs)):
  72. output_ids, output_str = outputs[i]
  73. outputs[i] = (output_ids[0], output_str[0])
  74. return outputs
  75. def generate_beam_search(
  76. self,
  77. prompts: List[str],
  78. beam_width: int,
  79. max_tokens: int,
  80. ) -> List[Tuple[List[int], str]]:
  81. outputs = self.generate(prompts,
  82. do_sample=False,
  83. max_new_tokens=max_tokens,
  84. num_beams=beam_width,
  85. num_return_sequences=beam_width)
  86. for i in range(len(outputs)):
  87. output_ids, output_str = outputs[i]
  88. for j in range(len(output_ids)):
  89. output_ids[j] = [
  90. x for x in output_ids[j]
  91. if x != self.tokenizer.pad_token_id
  92. ]
  93. outputs[i] = (output_ids, output_str)
  94. return outputs
  95. def generate_greedy_logprobs(
  96. self,
  97. prompts: List[str],
  98. max_tokens: int,
  99. ) -> List[List[torch.Tensor]]:
  100. all_logprobs = []
  101. for prompt in prompts:
  102. input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
  103. output = self.model.generate(
  104. input_ids.cuda(),
  105. use_cache=True,
  106. do_sample=False,
  107. max_new_tokens=max_tokens,
  108. output_hidden_states=True,
  109. return_dict_in_generate=True,
  110. )
  111. seq_logprobs = []
  112. for hidden_states in output.hidden_states:
  113. last_hidden_states = hidden_states[-1][0]
  114. logits = torch.matmul(
  115. last_hidden_states,
  116. self.model.get_output_embeddings().weight.t(),
  117. )
  118. if self.model.get_output_embeddings().bias is not None:
  119. logits += self.model.get_output_embeddings(
  120. ).bias.unsqueeze(0)
  121. logprobs = torch.nn.functional.log_softmax(logits,
  122. dim=-1,
  123. dtype=torch.float32)
  124. seq_logprobs.append(logprobs)
  125. all_logprobs.append(seq_logprobs)
  126. return all_logprobs
  127. @pytest.fixture
  128. def hf_runner():
  129. return HfRunner
  130. class AphroditeRunner:
  131. def __init__(
  132. self,
  133. model_name: str,
  134. tokenizer_name: Optional[str] = None,
  135. dtype: str = "half",
  136. ) -> None:
  137. self.model = LLM(
  138. model=model_name,
  139. tokenizer=tokenizer_name,
  140. trust_remote_code=True,
  141. dtype=dtype,
  142. swap_space=0,
  143. )
  144. def generate(
  145. self,
  146. prompts: List[str],
  147. sampling_params: SamplingParams,
  148. ) -> List[Tuple[List[int], str]]:
  149. req_outputs = self.model.generate(prompts,
  150. sampling_params=sampling_params)
  151. outputs = []
  152. for req_output in req_outputs:
  153. prompt_str = req_output.prompt
  154. prompt_ids = req_output.prompt_token_ids
  155. req_sample_output_ids = []
  156. req_sample_output_strs = []
  157. for sample in req_output.outputs:
  158. output_str = sample.text
  159. output_ids = sample.token_ids
  160. req_sample_output_ids.append(prompt_ids + output_ids)
  161. req_sample_output_strs.append(prompt_str + output_str)
  162. outputs.append((req_sample_output_ids, req_sample_output_strs))
  163. return outputs
  164. def generate_greedy(
  165. self,
  166. prompts: List[str],
  167. max_tokens: int,
  168. ) -> List[Tuple[List[int], str]]:
  169. greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
  170. outputs = self.generate(prompts, greedy_params)
  171. return [(output_ids[0], output_str[0])
  172. for output_ids, output_str in outputs]
  173. def generate_beam_search(
  174. self,
  175. prompts: List[str],
  176. beam_width: int,
  177. max_tokens: int,
  178. ) -> List[Tuple[List[int], str]]:
  179. beam_search_params = SamplingParams(n=beam_width,
  180. use_beam_search=True,
  181. temperature=0.0,
  182. max_tokens=max_tokens)
  183. outputs = self.generate(prompts, beam_search_params)
  184. return outputs
  185. @pytest.fixture
  186. def aphrodite_runner():
  187. return AphroditeRunner