  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
  8. import torch
  9. from aphrodite.triton_utils import HAS_TRITON
  10. if HAS_TRITON:
  11. from aphrodite.lora.ops.bgmv_expand import bgmv_expand
  12. from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
  13. from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
  14. from aphrodite.lora.ops.sgmv_expand import sgmv_expand
  15. from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
  16. from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
  17. if TYPE_CHECKING:
  18. # avoid circuit import
  19. from aphrodite.lora.layers import LoRAMapping
  20. from aphrodite.lora.models import LongContextLoRAContext
  21. def compute_meta(
  22. token_lora_tensor: torch.Tensor
  23. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
  24. """
  25. Get the information required for the sgmv kernel. With the features:
  26. 1. If consecutive requests in the batch use the same LoRA, this function
  27. will combine them into a single request, improving sgmv kernel inference
  28. performance.
  29. 2. At the beginning of each prefill stage inference, recalculations are
  30. needed based on the input, but only once.
  31. """
  32. lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
  33. token_lora_tensor, return_counts=True)
  34. cum_result = torch.cumsum(seq_length_tensor, dim=0)
  35. b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
  36. b_seq_start_tensor[1:].copy_(cum_result[:-1])
  37. max_length = seq_length_tensor.max().item()
  38. token_nums = seq_length_tensor.sum().item()
  39. batch_size = lora_indices_tensor.size(0)
  40. no_lora = False
  41. # -1 means no lora should be applied. Use `no_lora` to determine whether
  42. # the current step requires LoRA. If LoRA is not needed, the prefill stage
  43. # does not need to launch the triton kernel, which can improve performance
  44. if batch_size == 1 and lora_indices_tensor == -1:
  45. no_lora = True
  46. return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
  47. batch_size, max_length, token_nums, no_lora)
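

# Illustrative sketch (not used by the library code): shows how `compute_meta`
# groups consecutive tokens that share a LoRA index. The input values below
# are hypothetical.
def _example_compute_meta() -> None:
    # Six tokens: the first three use LoRA slot 0, the next two use slot 1,
    # and the last one (-1) has no LoRA.
    token_lora = torch.tensor([0, 0, 0, 1, 1, -1])
    (b_seq_start, seq_lens, lora_ids, batch_size, max_len, token_nums,
     no_lora) = compute_meta(token_lora)
    assert b_seq_start.tolist() == [0, 3, 5]  # start offset of each segment
    assert seq_lens.tolist() == [3, 2, 1]  # length of each segment
    assert lora_ids.tolist() == [0, 1, -1]  # LoRA index per segment
    assert (batch_size, max_len, token_nums, no_lora) == (3, 3, 6, False)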


# TODO see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long context LoRAs in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for the sampler. For generation, this will be
                the same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for the sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context LoRAs don't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices).
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device="cuda",
                                        dtype=torch.long)
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset

    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device="cuda",
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
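    # Encode (row, padded LoRA slot) as a single flat index:
    # row i maps to i + slot_i * num_rows.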
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contains the length of each indices tensor. Used to index into each
    # tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    else:
        # If long_lora doesn't exist, append None.
        indices_len.append(None)

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
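

# Illustrative sketch (requires a CUDA device, since `convert_mapping`
# allocates its tensors on "cuda"). The function only reads `index_mapping`
# and `prompt_mapping` from the mapping object, so a simple stand-in is used
# here instead of a real LoRAMapping; all values are hypothetical.
def _example_convert_mapping() -> None:
    from types import SimpleNamespace

    mapping = SimpleNamespace(index_mapping=[1, 1, 2, 0],
                              prompt_mapping=[1, 2, 0])
    lora_index_to_id = [1, 2, None, None]  # slot -> active LoRA id
    (base_indices, sampler_indices, sampler_indices_padded,
     embeddings_indices, long_lora_indices, indices_len) = convert_mapping(
         mapping,
         lora_index_to_id,
         max_loras=4,
         vocab_size=32000,
         extra_vocab_size=256)
    assert base_indices.tolist() == [0, 0, 1, -1]
    assert sampler_indices.tolist() == [0, 1, -1]
    # -1 was padded to max_loras - 1 = 3, then flattened: i + slot_i * 3.
    assert sampler_indices_padded.tolist() == [0, 4, 11]
    assert embeddings_indices.shape == (2, 4)
    assert long_lora_indices is None
    assert indices_len == [4, 3, 3, 4, None]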


class PunicaWrapper:
    """
    PunicaWrapper is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica kernel.
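
    Example (illustrative sketch; the names and values below are
    hypothetical):

        punica_wrapper = PunicaWrapper(max_num_batched_tokens=8192,
                                       max_batches=256,
                                       device="cuda")
        # Refresh the index/metadata state once per scheduler step.
        punica_wrapper.update_metadata(lora_mapping, lora_index_to_id,
                                       max_loras, vocab_size,
                                       extra_vocab_size)
        # Apply LoRA to a layer output: y += (x @ lora_a @ lora_b) * scale
        punica_wrapper.add_lora(y, x, lora_a_stacked, lora_b_stacked, scale)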
  167. """
  168. def __init__(self, max_num_batched_tokens: int, max_batches: int,
  169. device: str):
  170. self._token_lora_indices = torch.empty(max_num_batched_tokens,
  171. dtype=torch.long,
  172. device=device)
  173. self._sampler_indices = torch.empty(max_num_batched_tokens,
  174. dtype=torch.long,
  175. device=device)
  176. self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
  177. dtype=torch.long,
  178. device=device)
  179. self._embeddings_indices = torch.empty(2,
  180. max_num_batched_tokens,
  181. dtype=torch.long,
  182. device=device)
  183. self._long_lora_indices = torch.empty(max_num_batched_tokens,
  184. dtype=torch.long,
  185. device=device)
  186. # 5 is the number of indicies tensors.
  187. # base_indices, sampler_indices, sampler_indices_padded,
  188. # embeddings_indices,long_lora_indices
  189. self.indices_len: List[Optional[int]] = [None] * 5
  190. # these attributes are the information required for sgmv kernel
  191. self._seq_start_locs = torch.empty(max_batches,
  192. dtype=torch.long,
  193. device=device)
  194. self._seq_lengths = torch.empty(max_batches,
  195. dtype=torch.long,
  196. device=device)
  197. self._lora_indices_per_batch = torch.empty(max_batches,
  198. dtype=torch.long,
  199. device=device)
  200. self.max_length: int = 0
  201. self.token_nums: int = 0
  202. self.batch_size: int = -1
  203. self.is_prefill = False
  204. self.no_lora = False
  205. def update_metadata(
  206. self,
  207. mapping: "LoRAMapping",
  208. lora_index_to_id: List[Optional[int]],
  209. max_loras: int,
  210. vocab_size: int,
  211. extra_vocab_size: int,
  212. long_lora_context: Optional["LongContextLoRAContext"] = None,
  213. ):
  214. self._update_base_metadata(mapping, lora_index_to_id, max_loras,
  215. vocab_size, extra_vocab_size,
  216. long_lora_context)
  217. if mapping.is_prefill:
  218. # Update metadata required for prefill-related operators.
  219. self._update_prefill_metada(self.token_lora_indices)
  220. self.is_prefill = True
  221. else:
  222. self.is_prefill = False
  223. def _update_base_metadata(
  224. self,
  225. mapping: "LoRAMapping",
  226. lora_index_to_id: List[Optional[int]],
  227. max_loras: int,
  228. vocab_size: int,
  229. extra_vocab_size: int,
  230. long_lora_context: Optional["LongContextLoRAContext"] = None,
  231. ):
  232. (
  233. base_indices,
  234. sampler_indices,
  235. sampler_indices_padded,
  236. embeddings_indices,
  237. long_lora_offsets_tensor,
  238. indices_len,
  239. ) = convert_mapping(
  240. mapping,
  241. lora_index_to_id,
  242. max_loras,
  243. vocab_size,
  244. extra_vocab_size,
  245. long_lora_context,
  246. )
  247. self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
  248. self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
  249. self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
  250. sampler_indices_padded)
  251. self._embeddings_indices[:embeddings_indices.
  252. shape[0], :embeddings_indices.shape[1]].copy_(
  253. embeddings_indices)
  254. if long_lora_offsets_tensor is not None:
  255. self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
  256. long_lora_offsets_tensor)
  257. else:
  258. self._long_lora_indices.zero_()
  259. self.indices_len[:] = indices_len
  260. def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
  261. (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
  262. batch_size, max_length, token_nums,
  263. no_lora) = compute_meta(token_lora_tensor)
  264. self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
  265. b_seq_start_tensor)
  266. self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
  267. self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
  268. lora_indices_tensor)
  269. self.batch_size = batch_size
  270. self.max_length = max_length
  271. self.token_nums = token_nums
  272. self.no_lora = no_lora
  273. @property
  274. def prefill_metadata(
  275. self
  276. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
  277. """
  278. This property provides a convenient way to access the necessary
  279. metadata for prefill-related kernel computations.
  280. 1. seq_start_locs: Tensor of sequence start positions.
  281. 2. seq_lengths: Tensor of sequence lengths.
  282. 3. lora_indices_per_batch: Tensor of lora indices, and an index of
  283. -1 means no lora should be applied.
  284. 4. batch_size: Batch size after clustering identical lora indices.
  285. 5. max_length: The maximum sequence length in the batch.
  286. 6. token_nums: The token numbers in the batch.
  287. """
  288. return (self._seq_start_locs[:self.batch_size],
  289. self._seq_lengths[:self.batch_size],
  290. self._lora_indices_per_batch[:self.batch_size],
  291. self.batch_size, self.max_length, self.token_nums)
  292. @property
  293. def token_lora_indices(self) -> torch.Tensor:
  294. """
  295. This property provides the lora indices corresponding to each token
  296. in the batch. An index of -1 means no lora should be applied.
  297. """
  298. token_lora_len = self.indices_len[0]
  299. return self._token_lora_indices[:token_lora_len]
  300. @property
  301. def sampler_indices(self) -> torch.Tensor:
  302. """
  303. This property is used to access the lora indices specifically for
  304. LogitsProcessorWithLoRA
  305. """
  306. sampler_indices_len = self.indices_len[1]
  307. return self._sampler_indices[:sampler_indices_len]
  308. @property
  309. def sampler_indices_padded(self) -> torch.Tensor:
  310. """
  311. This property provides access to padded sampler indices
  312. """
  313. indices_padded_len = self.indices_len[2]
  314. return self._sampler_indices_padded[:indices_padded_len]
  315. @property
  316. def embeddings_indices(self) -> torch.Tensor:
  317. """
  318. This property provides access to the indices used for lora embeddings,
  319. specifically for VocabParallelEmbeddingWithLoRA
  320. """
  321. embeddings_indices_len = self.indices_len[3]
  322. return self._embeddings_indices[:, :embeddings_indices_len]
  323. @property
  324. def long_lora_indices(self) -> torch.Tensor:
  325. """
  326. This property provides access to the indices used for long context
  327. lora, specifically for LinearScalingRotaryEmbeddingWithLora
  328. """
  329. long_lora_len = self.indices_len[4]
  330. return self._long_lora_indices[:long_lora_len]
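
    # The shrink_*/expand_* methods below are thin dispatchers over the
    # Triton kernels: the sgmv ops handle grouped prefill batches using
    # `prefill_metadata`, while the bgmv ops handle per-token decode using
    # `token_lora_indices`.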

    def shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_input,
        )

    def expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        # No LoRA request, so return directly.
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_input,
        )

    def expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `shrink_decode` function
        should be called.
        """
        shrink_fun: Callable = (self.shrink_prefill
                                if self.is_prefill else self.shrink_decode)
        shrink_fun(y, x, w_t_all, scale)

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool = True,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_b.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `expand_prefill` function should be called.
        Otherwise, it is the decode stage, and the `expand_decode` function
        should be called.
        """
        expand_fun: Callable = (self.expand_prefill
                                if self.is_prefill else self.expand_decode)
        expand_fun(y, x, w_t_all, add_input)

    def add_expand_slice(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
                         w_t_all: torch.Tensor,
                         y_offset: Optional[int],
                         y_slice_size: Optional[int],
                         add_input: bool = True):
        """
        Similar to `add_expand`, but the result is written to the column
        slice of `y` given by `y_offset` and `y_slice_size`.
        """
        expand_slice_fun: Callable = (self.expand_slice_prefill
                                      if self.is_prefill else
                                      self.expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)

    def add_lora(self,
                 y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 scale: float,
                 y_offset: Optional[int] = None,
                 y_slice_size: Optional[int] = None,
                 *,
                 buffer: Optional[torch.Tensor] = None) -> None:
        """
        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor.
            wa_t_all (torch.Tensor): lora_a's weight.
            wb_t_all (torch.Tensor): lora_b's weight.
            scale (float): Scaling factor.
            y_offset (Optional[int], optional): Offset to apply to the
                starting column of y.
            y_slice_size (Optional[int], optional): Size of the y column
                slice.
            buffer (Optional[torch.Tensor], optional): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default; refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
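        # Two-step LoRA application: `add_shrink` projects x down to the
        # rank-r buffer with lora_a, then `add_expand`/`add_expand_slice`
        # projects the buffer back up with lora_b and accumulates into y.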
        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)
        y = y.view_as(y_org)

    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
                               lora_a_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               lora_b_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               scale: float,
                               output_slices: Tuple[int, ...]) -> None:
        """
        Applies lora to each input. Similar to add_lora, this method is
        used for layers that are composed of multiple sublayers
        (slices) packed together.
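
        For example (illustrative; the sizes are hypothetical), a packed QKV
        projection with output widths (q_size, kv_size, kv_size) would pass
        output_slices=(q_size, kv_size, kv_size); each slice is updated with
        its own (lora_a, lora_b) pair at an increasing column offset into y.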
  526. """
  527. y_org = y
  528. x = x.view(-1, x.shape[-1])
  529. y = y.view(-1, y.shape[-1])
  530. offset_left = 0
  531. # TODO fuse these kernels
  532. for slice_idx in range(len(output_slices)):
  533. self.add_lora(y, x, lora_a_stacked[slice_idx],
  534. lora_b_stacked[slice_idx], scale, offset_left,
  535. output_slices[slice_idx])
  536. offset_left += output_slices[slice_idx]
  537. y = y.view_as(y_org)
  538. def add_lora_logits(self,
  539. y: torch.Tensor,
  540. x: torch.Tensor,
  541. wa_t_all: torch.Tensor,
  542. wb_t_all: torch.Tensor,
  543. scale,
  544. *,
  545. buffer: Optional[torch.Tensor] = None) -> None:
  546. """
  547. LogitsProcessorWithLoRA always using bgmv
  548. """
  549. y_org = y
  550. y = y.view(-1, y.shape[-1])
  551. x = x.view(-1, x.shape[-1])
  552. r = wb_t_all.size(-1)
  553. if buffer is None:
  554. # We set the buffer to be float32 by default ,refer to:
  555. # https://github.com/triton-lang/triton/issues/1387
  556. buffer = torch.zeros((x.size(0), r),
  557. dtype=torch.float32,
  558. device=x.device)
  559. bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
  560. bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
  561. y = y.view_as(y_org)