|
@@ -4,9 +4,35 @@
|
|
|
|
|
|
from functools import lru_cache
|
|
|
|
|
|
+import numpy as np
|
|
|
import torch
|
|
|
import triton
|
|
|
-from scipy import sparse
|
|
|
+
|
|
|
+
|
|
|
+class csr_matrix:
|
|
|
+ """Simple implementation of CSR matrix conversion without scipy.
|
|
|
+ This replaced scipy.sparse.csr_matrix() previously used."""
|
|
|
+
|
|
|
+ def __init__(self, input_array):
|
|
|
+ if not isinstance(input_array, np.ndarray):
|
|
|
+ raise ValueError("Input must be a NumPy array")
|
|
|
+
|
|
|
+ self.shape = input_array.shape
|
|
|
+ rows, cols = self.shape
|
|
|
+ data = []
|
|
|
+ indices = []
|
|
|
+ indptr = [0]
|
|
|
+
|
|
|
+ for i in range(rows):
|
|
|
+ for j in range(cols):
|
|
|
+ if input_array[i, j]:
|
|
|
+ data.append(input_array[i, j])
|
|
|
+ indices.append(j)
|
|
|
+ indptr.append(len(indices))
|
|
|
+
|
|
|
+ self.data = np.array(data)
|
|
|
+ self.indices = np.array(indices)
|
|
|
+ self.indptr = np.array(indptr)
|
|
|
|
|
|
|
|
|
def dense_to_crow_col(x: torch.Tensor):
|
|
@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
|
|
|
assert x.dim() in (2, 3)
|
|
|
if x.dim() == 2:
|
|
|
x = x[None]
|
|
|
- x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
|
|
+ x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
|
|
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
|
|
|
cols = [torch.from_numpy(xi.indices) for xi in x]
|
|
|
max_cols = max(len(xi) for xi in cols)
|