
fix: avoid copying prompt/output tokens if penalties aren't used

AlpinDale, 7 months ago
Commit b9a5a0ae79

1 changed file with 46 additions and 28 deletions:

    aphrodite/modeling/sampling_metadata.py (+46 −28)
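For context, the check that decides whether penalties are "used" is not part of this diff. A minimal sketch of what that condition typically looks like for a vLLM-style SamplingParams object; the attribute names here are assumptions, not taken from this commit:

    # Sketch (not from this commit): penalties count as "unused" when every
    # request keeps the neutral defaults, so prompt/output token ids are never
    # needed on the GPU for sampling.
    def penalties_unused(all_sampling_params):
        return all(
            params.presence_penalty == 0.0
            and params.frequency_penalty == 0.0
            and params.repetition_penalty == 1.0
            for params in all_sampling_params
        )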

@@ -433,8 +433,9 @@ class SamplingTensors:
                 typical_ps += [1] * prefill_len
                 smoothing_factors += [smoothing_factor] * prefill_len
                 smoothing_curves += [smoothing_curve] * prefill_len
-                prompt_tokens.extend([] for _ in range(prefill_len))
-                output_tokens.extend([] for _ in range(prefill_len))
+                if do_penalties:
+                    prompt_tokens.extend([] for _ in range(prefill_len))
+                    output_tokens.extend([] for _ in range(prefill_len))

             if seq_group.do_sample:
                 sample_lens = len(seq_group.sample_indices)
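The effect of the guard above is that, for penalty-free batches, prompt_tokens and output_tokens stay empty, which is exactly what the truthiness check introduced in the next hunk relies on. A minimal illustration in plain Python, independent of the Aphrodite classes:

    prompt_tokens, output_tokens = [], []
    do_penalties = False
    prefill_len = 512

    if do_penalties:
        # Only materialize one (empty) placeholder list per prefill position.
        prompt_tokens.extend([] for _ in range(prefill_len))
        output_tokens.extend([] for _ in range(prefill_len))

    # With the guard, both lists stay empty, so `prompt_tokens or output_tokens`
    # is falsy and all padding/copy work further down can be skipped.
    assert not (prompt_tokens or output_tokens)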
@@ -502,18 +503,21 @@ class SamplingTensors:
         # Note that the performance will be very bad without
         # pinned memory.
         pin_memory = is_pin_memory_available()
-        prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
-                             default=0)
-        prompt_padded_tokens = [
-            tokens + [vocab_size] * (prompt_max_len - len(tokens))
-            for tokens in prompt_tokens
-        ]
-        output_max_len = max([len(tokens) for tokens in output_tokens],
-                             default=0)
-        output_padded_tokens = [
-            tokens + [vocab_size] * (output_max_len - len(tokens))
-            for tokens in output_tokens
-        ]
+        do_penalties = prompt_tokens or output_tokens
+
+        if do_penalties:
+            prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
+                                 default=0)
+            prompt_padded_tokens = [
+                tokens + [vocab_size] * (prompt_max_len - len(tokens))
+                for tokens in prompt_tokens
+            ]
+            output_max_len = max([len(tokens) for tokens in output_tokens],
+                                 default=0)
+            output_padded_tokens = [
+                tokens + [vocab_size] * (output_max_len - len(tokens))
+                for tokens in output_tokens
+            ]

         temperatures_t = torch.tensor(
             temperatures,
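The guarded block pads every per-sequence token list to a common length so the lists can form a rectangular tensor, using vocab_size as the padding id. A standalone sketch of that padding step with toy values:

    vocab_size = 32000
    prompt_tokens = [[11, 42, 7], [5], []]

    prompt_max_len = max([len(tokens) for tokens in prompt_tokens], default=0)
    prompt_padded_tokens = [
        tokens + [vocab_size] * (prompt_max_len - len(tokens))
        for tokens in prompt_tokens
    ]
    # -> [[11, 42, 7], [5, 32000, 32000], [32000, 32000, 32000]]
    # vocab_size is an out-of-range id, so padding never collides with a real token.
    print(prompt_padded_tokens)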
@@ -591,18 +595,22 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         )
-        prompt_tensor = torch.tensor(
-            prompt_padded_tokens,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        )
-        output_tensor = torch.tensor(
-            output_padded_tokens,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        )
+        if do_penalties:
+            prompt_tensor = torch.tensor(
+                prompt_padded_tokens,
+                device="cpu",
+                dtype=torch.long,
+                pin_memory=pin_memory,
+            )
+            output_tensor = torch.tensor(
+                output_padded_tokens,
+                device="cpu",
+                dtype=torch.long,
+                pin_memory=pin_memory,
+            )
+        else:
+            prompt_tensor = None
+            output_tensor = None
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
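Both token tensors are built on the CPU with pin_memory so the later .to(device, non_blocking=True) copy can be truly asynchronous; skipping them when do_penalties is false also skips that pinned-memory allocation. A minimal sketch of the same pattern, assuming PyTorch and guarding the pinned/non-blocking path on CUDA availability:

    import torch

    padded = [[11, 42, 7], [5, 32000, 32000]]
    use_cuda = torch.cuda.is_available()

    # Pinned (page-locked) host memory is what makes non_blocking=True an
    # actually-asynchronous host-to-device copy.
    cpu_tensor = torch.tensor(
        padded,
        device="cpu",
        dtype=torch.long,
        pin_memory=use_cuda,
    )
    if use_cuda:
        gpu_tensor = cpu_tensor.to(device="cuda", non_blocking=True)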
@@ -625,6 +633,16 @@ class SamplingTensors:
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

+        if do_penalties:
+            prompt_tokens_gpu = prompt_tensor.to(device=device,
+                                                 non_blocking=True)
+            output_tokens_gpu = output_tensor.to(device=device,
+                                                 non_blocking=True)
+        else:
+            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
+            prompt_tokens_gpu = empty_tensor
+            output_tokens_gpu = empty_tensor
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
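Rather than propagating None into the returned SamplingTensors, the penalty-free path substitutes an empty LongTensor on the target device, so downstream code can distinguish the two cases by size alone. A hedged sketch of how a consumer might branch; the numel() guard is an assumption about downstream code, not part of this commit:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    prompt_tokens_gpu = torch.empty(0, device=device, dtype=torch.long)

    # Hypothetical downstream guard: an empty tensor signals "no penalties",
    # so the penalty application over prompt/output tokens can be skipped.
    if prompt_tokens_gpu.numel() > 0:
        pass  # apply repetition/frequency/presence penalties here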
@@ -646,8 +664,8 @@ class SamplingTensors:
             smoothing_curves=smoothing_curves_t.to(device=device,
                                                    non_blocking=True),
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
-            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
-            output_tokens=output_tensor.to(device=device, non_blocking=True),
+            prompt_tokens=prompt_tokens_gpu,
+            output_tokens=output_tokens_gpu,
             sampling_seeds=sampling_seeds_gpu,
             sample_indices=sample_indices_t.to(device=device,
                                                non_blocking=True),