"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from aphrodite import _custom_ops as ops
from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionMetadataBuilder,
                                                   AttentionType)
from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
                                                compute_slot_mapping,
                                                compute_slot_mapping_start_idx,
                                                is_block_tables_empty)
from aphrodite.common.sequence import SequenceGroupMetadata
from aphrodite.common.utils import make_tensor_with_pad

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
                                                     ModelInputForGPUBuilder)


class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
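
# Example (illustrative sketch, not part of the upstream API): the backend's
# KV cache is a single tensor holding both key and value blocks. The sizes
# below (1024 blocks, 16-token blocks, 8 KV heads, head size 128, fp16 on
# CUDA) are assumptions chosen only to show how `get_kv_cache_shape` is
# intended to be used:
#
#     shape = FlashAttentionBackend.get_kv_cache_shape(
#         num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128)
#     # shape == (2, 1024, 16, 8, 128); index 0 is the key cache and index 1
#     # the value cache, matching how swap_blocks/copy_blocks slice
#     # kv_cache[0] / kv_cache[1] above.
#     kv_cache = torch.zeros(shape, dtype=torch.float16, device="cuda")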


@dataclass
class FlashAttentionMetadata(AttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks.)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = FlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata
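
# Example (illustrative sketch with made-up numbers): for a chunked-prefill
# request that has already computed 8 tokens and is now processing 4 new
# ones, context_len == 8, query_len == 4, and
# seq_len == context_len + query_len == 12. In a mixed batch the prefill
# entries always come first, so `prefill_metadata` / `decode_metadata` simply
# slice the shared tensors:
#
#     meta: FlashAttentionMetadata = ...   # built by the builder below
#     prefill = meta.prefill_metadata      # first num_prefills sequences,
#                                          # first num_prefill_tokens slots
#     decode = meta.decode_metadata        # the remaining sequences/slots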


class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int],
                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = seq_group_metadata.is_prompt
        block_tables = seq_group_metadata.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
                 curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO: Combine chunked prefill and prefix caching by
            # only allowing multiple-of-block_size chunk sizes.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE: For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size,
                                 seq_group_metadata.block_tables)

    def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
              query_lens: List[int], cuda_graph_pad_size: int,
              batch_size: int):
        """Build attention metadata with on-device tensors."""
        device = runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        logits_soft_cap = getattr(runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use the FlashInfer backend for models with "
                "logits_soft_cap (e.g., Gemma-2); otherwise, the output may "
                "be wrong. Select the FlashInfer backend with "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Padded entries get empty block tables.
            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size + cuda_graph_pad_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in self.block_tables)
            block_tables = make_tensor_with_pad(
                self.block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
        slot_mapping_tensor = torch.tensor(self.slot_mapping,
                                           dtype=torch.long,
                                           device=device)

        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
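
# Example (illustrative sketch with made-up lengths): two prefill requests
# with query lengths [4, 6] and sequence lengths [9, 6] (the first one has a
# 5-token prefix already in the cache) produce cumulative start locations
#
#     query_lens = torch.tensor([4, 6]); seq_lens = torch.tensor([9, 6])
#     torch.cumsum(query_lens, dim=0)   # -> [4, 10]; with the leading zero,
#                                       #    query_start_loc == [0, 4, 10]
#     torch.cumsum(seq_lens, dim=0)     # -> [9, 15]; seq_start_loc ==
#                                       #    [0, 9, 15]
#
# which is exactly the cu_seqlens_q / cu_seqlens_k format that
# flash_attn_varlen_func expects for a flattened, variable-length batch.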


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, (
            "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if sliding_window is not None:
            # NOTE: flash-attn's sliding window does not work with
            # paged KV cache.
            raise ValueError(
                "Sliding window is not supported in FlashAttention.")

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        # NOTE: FlashAttention does not support FP8 KV cache.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "k_scale and v_scale are not supported in FlashAttention.")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (kv_cache is None or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                # Normal attention.
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention.
                assert prefill_meta.seq_lens is not None
                max_seq_len = max(prefill_meta.seq_lens)
                output[:num_prefill_tokens] = flash_attn_varlen_func(
                    q=query,
                    k=key_cache,
                    v=value_cache,
                    cu_seqlens_q=prefill_meta.query_start_loc,
                    max_seqlen_q=prefill_meta.max_query_len,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_k=max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    block_table=prefill_meta.block_tables,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                decode_query.unsqueeze(1),
                key_cache,
                value_cache,
                block_table=decode_meta.block_tables,
                cache_seqlens=decode_meta.seq_lens_tensor,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
            ).squeeze(1)

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
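

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream API): how the pieces above fit
# together for a pure-prefill batch on the "normal attention" path, i.e. the
# same flash_attn_varlen_func call FlashAttentionImpl.forward makes when no
# block tables are populated. All sizes, the fp16 dtype, and the use of CUDA
# are assumptions for the example; nothing below is called by the backend.
# ---------------------------------------------------------------------------
def _example_prefill_call() -> torch.Tensor:  # pragma: no cover
    num_heads, num_kv_heads, head_size = 8, 8, 64
    seq_lens = [4, 6]  # two prompts, flattened to 10 query rows
    total_tokens = sum(seq_lens)
    q = torch.randn(total_tokens, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
    k = torch.randn(total_tokens, num_kv_heads, head_size,
                    dtype=torch.float16, device="cuda")
    v = torch.randn_like(k)
    # Cumulative start locations in the cu_seqlens format (int32).
    seq_start_loc = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
    return flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=seq_start_loc,
        cu_seqlens_k=seq_start_loc,
        max_seqlen_q=max(seq_lens),
        max_seqlen_k=max(seq_lens),
        softmax_scale=head_size**-0.5,
        causal=True,
    )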