Fix memory pinning conditional
@@ -257,7 +257,7 @@ class SamplingTensors:
# Note that the performance will be very bad without pinned memory.
# Pinned memory allows non-blocking transfers to device.
- pin_memory = not in_wsl() or is_neuron()
+ pin_memory = not in_wsl() and not is_neuron()
def _tensor(contents: list, dtype) -> torch.Tensor:
loc_t = torch.tensor(contents,