# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function

try:
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except ImportError:
    GreedySearchDecoderOnlyOutput = namedtuple(
        "GreedySearchDecoderOnlyOutput", ["sequences", "scores"]
    )
    SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"])


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative probability mass below 1 - top_p
    # (the token with the highest probability is always kept).
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter the sorted mask back to the original indexing
    # (scatter along the last dim so this also works for 3-D logits).
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
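

# Illustrative sketch (not part of the original module): top-p filtering keeps the smallest set
# of tokens whose probabilities sum to at least `top_p` and sets the rest to -inf, in place.
# The logit values below are arbitrary.
def _example_top_p_filtering():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
    modify_logits_for_top_p_filtering(logits, top_p=0.9)
    return logits  # the low-probability entries are now -inf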


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from logits, with optional top-k / top-p filtering and temperature scaling.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )
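

# Illustrative usage sketch (not part of the original module): calling `sample` on a batch of
# random logits. The batch and vocab sizes here are arbitrary.
def _example_sample_usage():
    logits = torch.randn(2, 1000)  # (batch_size, vocab_size)
    greedy = sample(logits, top_k=1)                    # greedy decoding
    top_k = sample(logits, top_k=50, temperature=0.8)   # top-k sampling with temperature
    nucleus = sample(logits, top_k=0, top_p=0.9)        # pure top-p (nucleus) sampling
    return greedy, top_k, nucleus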


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors of shape (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)
    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        start.record()
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    if enable_timing:
        end.record()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
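

# Illustrative usage sketch (not part of the original module). `model` is assumed to be a causal
# LM whose forward accepts `position_ids`, `inference_params`, and `num_last_tokens` and returns
# an object with a `.logits` attribute (e.g. the GPT models in this repo); `input_ids` is a
# (batch, seq_len) tensor of prompt tokens on the same device as the model.
def _example_decode_usage(model, input_ids, max_length=64):
    out = decode(
        input_ids, model, max_length, top_k=40, top_p=0.9, temperature=0.8, cg=False
    )
    return out.sequences  # (batch, max_length): the prompt followed by the sampled tokens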


def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192

    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that when we modify for top_p we don't change the original logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    modify_logits_for_top_p_filtering(logits, top_p)
    modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)
    gather = lambda probs, tokens: rearrange(
        probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..."
    )
    # (batch, seqlen)
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(
        probs_draft, tokens_draft
    ) <= gather(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,)
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)
    resample_probs = rearrange(
        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)),
        "b 1 d -> b d",
    )
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    tokens[:, first_rejected_idx] = resample
    return tokens, first_rejected_idx + 1
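

# Illustrative sketch (not part of the original module): calling `sample_speculative` on random
# draft/main logits and draft tokens. Shapes follow the docstring above; the sizes themselves
# are arbitrary.
def _example_sample_speculative():
    batch, seqlen, vocab_size = 1, 4, 32
    logits = torch.randn(batch, seqlen + 1, vocab_size)    # main model: one extra position
    logits_draft = torch.randn(batch, seqlen, vocab_size)  # draft model
    tokens_draft = torch.randint(0, vocab_size, (batch, seqlen))
    tokens, num_generated = sample_speculative(logits, logits_draft, tokens_draft, top_k=0)
    # tokens[:, :num_generated] are the accepted draft tokens plus one corrected / bonus token.
    return tokens, num_generated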


@torch.inference_mode()
def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only supports batch_size == 1 for now.
    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: (batch, num_generated_tokens, vocab_size), the logits used to sample the
            generated tokens
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            # draft model needs to process either 1 or 2 tokens at a time
            decoding_seqlens=(1, 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.reset(max_length, batch_size)
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            decoding_seqlens=range(1, speculative_lookahead + 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            seqlen = input_ids.shape[1]
            # if inference_params.lengths_per_sample is None:
            # TODO: in the case of batched decoding where each sequence has a different length,
            # we need to compute the position_ids for each sequence using lengths_per_sample
            if True:
                cache_seqlens = torch.full(
                    (input_ids.shape[0],),
                    inference_params.seqlen_offset,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            else:
                cache_seqlens = inference_params.lengths_per_sample
            position_ids = cache_seqlens[:, None] + torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=num_last_tokens,
            ).logits
        else:
            # NOTE: careful, the CUDA graph is captured with num_last_tokens=input_ids.shape[1].
            # This might not be compatible with the num_last_tokens used here.
            assert num_last_tokens <= input_ids.shape[1]
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            )[:, -num_last_tokens:]
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.

        Arguments:
            input_ids: (batch, seqlen)
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens, vocab_size), the logits used to sample each returned
                token. The logits at the position of the last sampled token aren't computed.
        """
        assert num_tokens >= 1
        sequences, scores = [input_ids], []
        for i in range(num_tokens):
            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    get_logits_main = partial(get_logits, model=model, cg=cg)
    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
    sample_tokens_main = partial(
        sample_tokens,
        get_logits_fn=get_logits_main,
        sample_fn=sample_fn,
        inference_params=inference_params,
    )
    sample_tokens_draft = partial(
        sample_tokens,
        get_logits_fn=get_logits_draft,
        sample_fn=sample_fn,
        inference_params=inference_params_draft,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()

    sequences, scores = [input_ids], []
    num_main_model_calls = 0
    num_draft_tokens = 0
    num_accepted_tokens_history = []
    if seqlen_og >= max_length - 1:
        # Don't do speculative sampling, just sample 1 token from the model
        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
        sequences.append(tokens)
        scores.append(scores_new)
    else:
        # Sample from the draft model, which produces @n_spec_tokens, which @model
        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([input_ids, tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        # TODO: we're using the fact that batch_size == 1
        # TODO: check eos_token_id
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
        # that in the next time we call @model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset = seqlen_og + num_generated - 1
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    while True:
        # seqlen_offset is the total length generated - 1
        if inference_params.seqlen_offset >= max_length - 1:
            break
        if inference_params.seqlen_offset >= max_length - 2:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
            sequences.append(tokens)
            scores.append(scores_new)
            break
        # Sample from the draft model
        n_spec_tokens = min(
            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
        )
        # If the main model accepts all the draft tokens, plus it samples one new token,
        # then at the next iteration the draft model needs to evaluate the logits of the last
        # draft token and the logits of the newly sampled token. So here we pass in the last 2
        # tokens of sequences[-1].
        # The exception is when the main model rejects all the draft tokens, in which case we
        # will only have 1 token to pass in.
        tokens_draft, scores_draft = sample_tokens_draft(
            sequences[-1][:, -2:], num_tokens=n_spec_tokens
        )
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
            # breakpoint()
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )  # (batch, n_spec_tokens + 1, vocab_size)
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
        # num_generated_tokens[0].item() - 1 tokens from the draft model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset += num_generated
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        print(f"Number of calls to main model: {num_main_model_calls}")
        print(
            f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
        )
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)
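

# Illustrative usage sketch (not part of the original module). `model` and `model_draft` are
# assumed to be compatible causal LMs sharing a tokenizer (e.g. a large and a small GPT from
# this repo), with the draft model much cheaper to run per token.
def _example_decode_speculative_usage(model, model_draft, input_ids, max_length=64):
    out = decode_speculative(
        input_ids,
        model,
        model_draft,
        max_length,
        speculative_lookahead=4,  # number of tokens proposed by the draft model per round
        top_k=1,  # greedy; speculative sampling preserves the main model's output distribution
    )
    return out.sequences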


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
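

# Illustrative sketch (not part of the original module): pre-allocating a per-layer KV cache.
# Each layer gets a tensor of shape (batch, seqlen, 2, nheads, headdim), where the "2" holds the
# key and value slots. The sizes below are arbitrary.
def _example_allocate_inference_cache():
    kv_cache = allocate_inference_cache(
        max_batch_size=2,
        max_seqlen=128,
        nheads=12,
        headdim=64,
        layers=24,
        device="cpu",  # typically a CUDA device during real inference
        dtype=torch.float16,
    )
    return kv_cache  # dict: layer index -> (2, 128, 2, 12, 64) float16 tensor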


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launches and non-captured launches don't overlap (I think
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Capture the graph. To allow capture, torch.cuda.graph automatically sets a side stream
    # as the current stream inside the context.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run
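

# Illustrative sketch (not part of the original module) of the capture/replay pattern that
# `capture_graph` relies on: allocate static input/output buffers, warm up on a side stream,
# capture one forward pass into a CUDA graph, then replay it after copying new data into the
# static inputs. Requires a CUDA device; the tiny nn.Linear stands in for the real model.
def _example_cuda_graph_pattern():
    device = "cuda"
    layer = torch.nn.Linear(16, 16).to(device)
    static_input = torch.zeros(1, 16, device=device)

    # Warmup on a side stream so allocations/kernels are initialized before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_output = layer(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Record one forward pass; subsequent replays re-launch the same kernels on the same buffers
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = layer(static_input)

    def run(new_input):
        static_input.copy_(new_input)  # write into the captured input buffer
        graph.replay()                 # re-launch the recorded kernels
        return static_output.clone()

    return run(torch.randn(1, 16, device=device))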