Kaynağa Gözat

fix: remove scipy and re-implement CSR matrix

AlpinDale 6 ay önce
ebeveyn
işleme
a2d476183f

+ 28 - 2
aphrodite/attention/ops/blocksparse_attention/utils.py

@@ -4,9 +4,35 @@
 
 from functools import lru_cache
 
+import numpy as np
 import torch
 import triton
-from scipy import sparse
+
+
+class csr_matrix:
+    """Simple implementation of CSR matrix conversion without scipy.
+    This replaced scipy.sparse.csr_matrix() previously used."""
+
+    def __init__(self, input_array):
+        if not isinstance(input_array, np.ndarray):
+            raise ValueError("Input must be a NumPy array")
+
+        self.shape = input_array.shape
+        rows, cols = self.shape
+        data = []
+        indices = []
+        indptr = [0]
+
+        for i in range(rows):
+            for j in range(cols):
+                if input_array[i, j]:
+                    data.append(input_array[i, j])
+                    indices.append(j)
+            indptr.append(len(indices))
+
+        self.data = np.array(data)
+        self.indices = np.array(indices)
+        self.indptr = np.array(indptr)
 
 
 def dense_to_crow_col(x: torch.Tensor):
@@ -19,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
     assert x.dim() in (2, 3)
     if x.dim() == 2:
         x = x[None]
-    x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
+    x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
     crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
     cols = [torch.from_numpy(xi.indices) for xi in x]
     max_cols = max(len(xi) for xi in cols)