Prechádzať zdrojové kódy

fix: sampler test with new transformers version

AlpinDale 4 mesiacov pred
rodič
commit
22429e4a10
1 zmenil súbory, kde vykonal 15 pridanie a 5 odobranie
  1. 15 5
      tests/samplers/test_sampler.py

+ 15 - 5
tests/samplers/test_sampler.py

@@ -1,6 +1,7 @@
 import itertools
 import random
 from array import array
+from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 
@@ -600,8 +601,17 @@ def test_sampler_top_k_top_p(seed: int, device: str):
     generation_config = GenerationConfig(top_k=top_k,
                                          top_p=top_p,
                                          do_sample=True)
-    warpers = generation_model._get_logits_warper(generation_config, device)
-    assert len(warpers) == 2  # top_p and top_k
+    @dataclass
+    class MockConfig:
+        is_encoder_decoder: bool = False
+    generation_model.config = MockConfig()  # needed by the following method
+    generation_model._prepare_special_tokens(generation_config, device=device)
+    processors = generation_model._get_logits_processor(generation_config,
+                                                        None,
+                                                        None,
+                                                        None, [],
+                                                        device=device)
+    assert len(processors) == 2  # top_p and top_k
 
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
     seq_lens: List[int] = []
@@ -638,12 +648,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
         return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                  for prob in probs], None)
 
-    with patch("aphrodite.model_executor.layers.sampler._sample", mock_sample):
+    with patch("aphrodite.modeling.layers.sampler._sample", mock_sample):
         sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
 
     assert sample_probs is not None
 
-    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
+    hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
     hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
     torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
     assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
@@ -735,7 +745,7 @@ def test_sampler_include_gpu_probs_tensor(device: str):
 
     mock_inplace = Mock()
     with patch(
-            "aphrodite.model_executor.layers.sampler._modify_greedy_probs_inplace",
+            "aphrodite.modeling.layers.sampler._modify_greedy_probs_inplace",
             mock_inplace):
 
         sampler_output = _do_sample(batch_size, fake_logits, sampler,