  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
  8. import torch
  9. from aphrodite.common.utils import is_xpu
  10. from aphrodite.triton_utils import HAS_TRITON
  11. if HAS_TRITON and not is_xpu():
  12. from aphrodite.lora.ops.bgmv_expand import bgmv_expand
  13. from aphrodite.lora.ops.bgmv_expand_slice import bgmv_expand_slice
  14. from aphrodite.lora.ops.bgmv_shrink import bgmv_shrink
  15. from aphrodite.lora.ops.sgmv_expand import sgmv_expand
  16. from aphrodite.lora.ops.sgmv_expand_slice import sgmv_expand_slice
  17. from aphrodite.lora.ops.sgmv_shrink import sgmv_shrink
  18. if TYPE_CHECKING:
  19. # avoid circuit import
  20. from aphrodite.lora.layers import LoRAMapping
  21. from aphrodite.lora.models import LongContextLoRAContext


def compute_meta(
    token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
    """
    Get the information required for the sgmv kernel, with two properties:
    1. If consecutive requests in the batch use the same LoRA, they are
       combined into a single segment, which improves sgmv kernel
       performance.
    2. This only needs to be recomputed once at the beginning of each
       prefill step, based on the input.
    """
    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    cum_result = torch.cumsum(seq_length_tensor, dim=0)
    b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
    b_seq_start_tensor[1:].copy_(cum_result[:-1])
    max_length = seq_length_tensor.max().item()
    batch_size = lora_indices_tensor.size(0)
    no_lora = False
    # -1 means no LoRA should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the Triton kernel, which improves performance.
    if batch_size == 1 and lora_indices_tensor == -1:
        no_lora = True
    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
            batch_size, max_length, no_lora)
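
# Illustrative example (hypothetical values, not part of the module's API;
# this simply mirrors what torch.unique_consecutive produces):
# for token_lora_tensor = [0, 0, 0, 1, 1, -1], compute_meta returns
#   b_seq_start_tensor  = [0, 3, 5]   # start offset of each segment
#   seq_length_tensor   = [3, 2, 1]   # length of each segment
#   lora_indices_tensor = [0, 1, -1]  # LoRA slot per segment (-1 = no LoRA)
#   batch_size = 3, max_length = 3, no_lora = False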


# TODO: see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long-context LoRAs in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for the sampler. For generation, this will be
                the same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for the sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long-context LoRAs don't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices).
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device="cuda",
                                        dtype=torch.long)
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO: index() can be slow. Optimize.
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset

    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device="cuda",
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contains the length of each indices tensor. Used to index into each
    # tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    else:
        # If long-context LoRAs don't exist, append None.
        indices_len.append(None)

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
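
# Illustrative example (hypothetical values, not part of the module's API):
# with lora_index_to_id = [2, 1] (slot 0 holds LoRA id 2, slot 1 holds id 1),
# mapping.index_mapping = (1, 1, 2, 2, 0) and mapping.prompt_mapping =
# (1, 2, 0), where id 0 means "no LoRA", convert_mapping would produce
#   base_indices    = [1, 1, 0, 0, -1]  # per-token LoRA slot (-1 = none)
#   sampler_indices = [1, 0, -1]        # per-request LoRA slot for sampling
# and indices_len = [5, 3, 3, 5, None].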


class PunicaWrapper:
    """
    PunicaWrapper is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica kernel.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: str):
        self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._sampler_indices = torch.empty(max_num_batched_tokens,
                                            dtype=torch.long,
                                            device=device)
        self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
                                                   dtype=torch.long,
                                                   device=device)
        self._embeddings_indices = torch.empty(2,
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._long_lora_indices = torch.empty(max_num_batched_tokens,
                                              dtype=torch.long,
                                              device=device)

        # 5 is the number of indices tensors:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices, long_lora_indices
        self.indices_len: List[Optional[int]] = [None] * 5

        # These attributes are the information required for the sgmv kernel.
        self._seq_start_locs = torch.empty(max_batches,
                                           dtype=torch.long,
                                           device=device)
        self._seq_lengths = torch.empty(max_batches,
                                        dtype=torch.long,
                                        device=device)
        self._lora_indices_per_batch = torch.empty(max_batches,
                                                   dtype=torch.long,
                                                   device=device)
        self.max_length: int = 0
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False

    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metadata(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            long_lora_context,
        )
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self._long_lora_indices.zero_()
        self.indices_len[:] = indices_len

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)

        self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
            b_seq_start_tensor)
        self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
        self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
            lora_indices_tensor)
        self.batch_size = batch_size
        self.max_length = max_length
        self.no_lora = no_lora

    @property
    def prefill_metadata(
        self
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
        """
        This property provides a convenient way to access the necessary
        metadata for prefill-related kernel computations.
        1. seq_start_locs: Tensor of sequence start positions
        2. seq_lengths: Tensor of sequence lengths
        3. lora_indices_per_batch: Tensor of lora indices, and an index of
           -1 means no lora should be applied.
        4. batch_size: batch size after clustering identical lora indices
        5. max_length: The maximum sequence length in the batch
        """
        return (self._seq_start_locs[:self.batch_size],
                self._seq_lengths[:self.batch_size],
                self._lora_indices_per_batch[:self.batch_size],
                self.batch_size, self.max_length)

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the lora indices corresponding to each token
        in the batch. An index of -1 means no lora should be applied.
        """
        token_lora_len = self.indices_len[0]
        return self._token_lora_indices[:token_lora_len]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        This property is used to access the lora indices specifically for
        LogitsProcessorWithLoRA.
        """
        sampler_indices_len = self.indices_len[1]
        return self._sampler_indices[:sampler_indices_len]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices.
        """
        indices_padded_len = self.indices_len[2]
        return self._sampler_indices_padded[:indices_padded_len]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        embeddings_indices_len = self.indices_len[3]
        return self._embeddings_indices[:, :embeddings_indices_len]

    @property
    def long_lora_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for long context
        lora, specifically for LinearScalingRotaryEmbeddingWithLora.
        """
        long_lora_len = self.indices_len[4]
        return self._long_lora_indices[:long_lora_len]

    def shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_input,
        )

    def expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_input,
        )

    def expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for
        the GEMM of lora_a.
        When `is_prefill` is true, it indicates the prefill stage, and the
        `shrink_prefill` function should be called. Otherwise, it is the
        decode stage, and the `shrink_decode` function should be called.
        """
        shrink_fun: Callable = (self.shrink_prefill
                                if self.is_prefill else self.shrink_decode)
        shrink_fun(y, x, w_t_all, scale)

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool = True,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for
        the GEMM of lora_b.
        When `is_prefill` is true, it indicates the prefill stage, and the
        `expand_prefill` function should be called. Otherwise, it is the
        decode stage, and the `expand_decode` function should be called.
        """
        expand_fun: Callable = (self.expand_prefill
                                if self.is_prefill else self.expand_decode)
        expand_fun(y, x, w_t_all, add_input)

    def add_expand_slice(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
                         w_t_all: torch.Tensor,
                         y_offset: Optional[int],
                         y_slice_size: Optional[int],
                         add_input: bool = True):
        """
        Similar to `add_expand`, but writes into the column slice of `y`
        starting at `y_offset` with width `y_slice_size`.
        """
        expand_slice_fun: Callable = (self.expand_slice_prefill
                                      if self.is_prefill else
                                      self.expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)

    def add_lora(self,
                 y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 scale: float,
                 y_offset: Optional[int] = None,
                 y_slice_size: Optional[int] = None,
                 *,
                 buffer: Optional[torch.Tensor] = None) -> None:
        """
        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)

        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor.
            wa_t_all (torch.Tensor): lora_a's weight.
            wb_t_all (torch.Tensor): lora_b's weight.
            scale (float): Scaling factor.
            y_offset (Optional[int], optional): Offset to apply to the
                starting column of y.
            y_slice_size (Optional[int], optional): Size of the y column
                slice.
            buffer (Optional[torch.Tensor], optional): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default. Refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)
        y = y.view_as(y_org)
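
    # A naive per-row sketch of the semantics above (illustrative only;
    # `indices[i]` stands for the token -> LoRA-slot mapping and `layer_idx`
    # for the layer dimension, as in the docstring). The fused path splits
    # this into shrink (x @ lora_a^T, scaled) and expand (buffer @ lora_b^T,
    # accumulated into y):
    #
    #   shrunk = x[i] @ wa_t_all[indices[i], layer_idx].transpose(-1, -2)
    #   y[i] += (shrunk @ wb_t_all[indices[i], layer_idx].transpose(-1, -2)
    #            * scale)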

    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
                               lora_a_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               lora_b_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               scale: float,
                               output_slices: Tuple[int, ...]) -> None:
        """
        Applies LoRA to each input. Similar to add_lora, but this method is
        used for layers that are composed of multiple sublayers (slices)
        packed together.
        """
        y_org = y
        x = x.view(-1, x.shape[-1])
        y = y.view(-1, y.shape[-1])
        offset_left = 0
        # TODO: fuse these kernels
        for slice_idx in range(len(output_slices)):
            self.add_lora(y, x, lora_a_stacked[slice_idx],
                          lora_b_stacked[slice_idx], scale, offset_left,
                          output_slices[slice_idx])
            offset_left += output_slices[slice_idx]

        y = y.view_as(y_org)
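
    # Illustrative example (hypothetical sizes): for a packed projection
    # with output_slices = (q_size, kv_size, kv_size), the three LoRA pairs
    # are applied to the column ranges [0, q_size),
    # [q_size, q_size + kv_size) and [q_size + kv_size, q_size + 2 * kv_size)
    # of y, matching how offset_left advances in the loop above.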

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        wa_t_all: torch.Tensor,
                        wb_t_all: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None) -> None:
        """
        LogitsProcessorWithLoRA always uses bgmv.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default. Refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
        y = y.view_as(y_org)
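

# A rough usage sketch (illustrative only; `mapping`, `lora_index_to_id`,
# `max_loras`, `vocab_size`, `extra_vocab_size`, `lora_a_stacked`,
# `lora_b_stacked` and `hidden_states` are assumed to come from the LoRA
# model manager and the layer's forward pass, and the Triton kernels need
# a CUDA device):
#
#   punica_wrapper = PunicaWrapper(max_num_batched_tokens=8192,
#                                  max_batches=256,
#                                  device="cuda")
#   # Once per scheduling step, before the LoRA layers run:
#   punica_wrapper.update_metadata(mapping, lora_index_to_id, max_loras,
#                                  vocab_size, extra_vocab_size)
#   # Inside a LoRA layer's forward pass:
#   punica_wrapper.add_lora(output, hidden_states, lora_a_stacked,
#                           lora_b_stacked, scale=1.0)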