# Copyright (c) 2023, Tri Dao. # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 import gc import time from collections import namedtuple from dataclasses import dataclass, field from functools import partial from typing import Callable, Optional, Sequence, Union import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import Tensor from torch.profiler import ProfilerActivity, profile, record_function try: from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput except ImportError: GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"]) SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"]) @dataclass class InferenceParams: """Inference parameters that are passed to the main model in order to efficienly calculate and store the context during inference.""" max_seqlen: int max_batch_size: int seqlen_offset: int = 0 batch_size_offset: int = 0 key_value_memory_dict: dict = field(default_factory=dict) lengths_per_sample: Optional[Tensor] = None def reset(self, max_seqlen, max_batch_size): self.max_seqlen = max_seqlen self.max_batch_size = max_batch_size self.seqlen_offset = 0 if self.lengths_per_sample is not None: self.lengths_per_sample.zero_() # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 def modify_logits_for_top_k_filtering(logits, top_k): """Set the logits for none top-k values to -inf. Done in-place.""" indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits.masked_fill_(indices_to_remove, float("-Inf")) # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 def modify_logits_for_top_p_filtering(logits, top_p): """Set the logits for none top-p values to -inf. Done in-place.""" if top_p <= 0.0 or top_p >= 1.0: return # First sort and calculate cumulative sum of probabilities. sorted_logits, sorted_indices = torch.sort(logits, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs <= (1 - top_p) # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) logits.masked_fill_(indices_to_remove, float("-inf")) def sample(logits, top_k=1, top_p=0.0, temperature=1.0): """Sample from top-k logits. Arguments: logits: Tensor of shape (batch_size, vocab_size) """ if top_k == 1: # Short-circuit for greedy decoding return logits.argmax(dim=-1) else: if top_p > 0.0: assert top_p <= 1.0, "top-p should be in (0, 1]." if top_k > 0: top_k = min(top_k, logits.size(-1)) # Safety check logits_top, indices = torch.topk(logits, top_k, dim=-1) if temperature != 1.0: logits_top /= temperature modify_logits_for_top_p_filtering(logits_top, top_p) return indices[ torch.arange(indices.shape[0], device=indices.device), torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), ] else: # Clone so that when we modify for top_p we don't change the original logits logits_top = logits / temperature if temperature != 1.0 else logits.clone() modify_logits_for_top_p_filtering(logits_top, top_p) return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( dim=-1 ) @torch.inference_mode() def decode( input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1, cg=False, enable_timing=False, ): """Decoding, either greedy or with top-k or top-p sampling. If top-k = 0, don't limit the number of candidates (pure sampling). Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, then top-p. We assume that all sequences in the same batch have the same length. Arguments: input_ids: (batch, seq_len) max_length: int teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the logits, the next token is taken from the teacher_outputs. Useful for testing. Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: sequences: (batch, max_length) scores: tuples of (batch, vocab_size) """ batch_size, seqlen_og = input_ids.shape teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 if cg: if not hasattr(model, "_decoding_cache"): model._decoding_cache = None model._decoding_cache = update_graph_cache( model, model._decoding_cache, batch_size, seqlen_og, max_length, tensor_parallel=tensor_parallel, ) inference_params = model._decoding_cache.inference_params inference_params.reset(max_length, batch_size) else: inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) def get_logits(input_ids, inference_params): decoding = inference_params.seqlen_offset > 0 if decoding: position_ids = torch.full( (batch_size, 1), inference_params.seqlen_offset, dtype=torch.long, device=input_ids.device, ) else: position_ids = None if not cg or not decoding: logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=1, ).logits.squeeze(dim=1) else: logits = model._decoding_cache.run( input_ids, position_ids, inference_params.seqlen_offset ).squeeze(dim=1) return logits[..., :vocab_size] if vocab_size is not None else logits def sample_tokens(logits, inference_params): if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) else: token = teacher_outputs[:, inference_params.seqlen_offset] # return rearrange(token, "b -> b 1") return token.unsqueeze(1) def should_stop(current_token, inference_params): if inference_params.seqlen_offset == 0: return False if eos_token_id is not None and (current_token == eos_token_id).all(): return True if inference_params.seqlen_offset >= max_length - 1: return True return False start = torch.cuda.Event(enable_timing=enable_timing) end = torch.cuda.Event(enable_timing=enable_timing) if enable_timing: if tensor_parallel > 1: torch.distributed.barrier() start.record() scores, sequences = [], [input_ids] while not should_stop(sequences[-1], inference_params): scores.append(get_logits(sequences[-1], inference_params)) inference_params.seqlen_offset += sequences[-1].shape[1] sequences.append(sample_tokens(scores[-1], inference_params)) if enable_timing: end.record() if tensor_parallel > 1: torch.distributed.barrier() torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0): """Algorithm 1 from [1] [1] Fast Inference from Transformers via Speculative Decoding Yaniv Leviathan, Matan Kalman, Yossi Matias https://arxiv.org/abs/2211.17192 Arguments: logits: Tensor of shape (batch_size, seqlen + 1, vocab_size) logits_draft: Tensor of shape (batch_size, seqlen, vocab_size) tokens_draft: Tensor of shape (batch_size, seqlen) Return: tokens: Tensor of shape (batch_size, seqlen + 1) num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1]. For each sequence in the batch, the number of valid tokens that were sampled by speculative sampling. """ batch, seqlen_p_1, vocab_size = logits.shape seqlen = seqlen_p_1 - 1 assert logits_draft.shape == (batch, seqlen, vocab_size) assert tokens_draft.shape == (batch, seqlen) assert tokens_draft.dtype in [torch.int64, torch.int32] # TODO: if top_k = 1 we can simplify things and only work with indices if top_p > 0.0: assert top_p <= 1.0, "top-p should be in (0, 1]." # Clone so that when we modify for top_p we don't change the original logits logits = logits / temperature if temperature != 1.0 else logits.clone() logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone() if top_k > 0: top_k = min(top_k, logits.size(-1)) # Safety check modify_logits_for_top_k_filtering(logits, top_k) modify_logits_for_top_k_filtering(logits_draft, top_k) modify_logits_for_top_p_filtering(logits, top_p) modify_logits_for_top_p_filtering(logits_draft, top_p) probs = torch.softmax(logits, dim=-1) probs_draft = torch.softmax(logits_draft, dim=-1) gather = lambda probs, tokens: rearrange( probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..." ) # (batch, seqlen) accepted = torch.rand(batch, seqlen, device=probs.device) * gather( probs_draft, tokens_draft ) <= gather(probs[:, :-1], tokens_draft) accepted_all = accepted.all(dim=-1) # (batch,) first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1)) probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0) # torch.multinomial can deal with unnormalized probabilities # probs_diff /= probs_diff.sum(dim=-1, keepdim=True) resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1) resample_probs = rearrange( resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)), "b 1 d -> b d", ) resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1) # (batch,) tokens = F.pad(tokens_draft, (0, 1)) tokens[:, first_rejected_idx] = resample return tokens, first_rejected_idx + 1 @torch.inference_mode() def decode_speculative( input_ids, model, model_draft, max_length, speculative_lookahead=3, top_k=1, top_p=0.0, temperature=1.0, eos_token_id=None, vocab_size=None, tensor_parallel=1, cg=False, enable_timing=False, debug=False, ): """ TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now. Speculative decoding, either greedy or with top-k or top-p sampling. If top-k = 0, don't limit the number of candidates (pure sampling). Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, then top-p. We assume that all sequences in the same batch have the same length. Arguments: input_ids: (batch, seq_len) max_length: int Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: sequences: (batch, max_length) scores: tuples of (batch, vocab_size) """ batch_size, seqlen_og = input_ids.shape assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1" assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id" if cg: if not hasattr(model_draft, "_decoding_cache"): model_draft._decoding_cache = None model_draft._decoding_cache = update_graph_cache( model_draft, model_draft._decoding_cache, batch_size, seqlen_og, max_length, # draft model needs to process either 1 or 2 tokens at a time decoding_seqlens=(1, 2), tensor_parallel=tensor_parallel, ) inference_params_draft = model_draft._decoding_cache.inference_params inference_params_draft.reset(max_length, batch_size) if not hasattr(model, "_decoding_cache"): model._decoding_cache = None model._decoding_cache = update_graph_cache( model, model._decoding_cache, batch_size, seqlen_og, max_length, decoding_seqlens=range(1, speculative_lookahead + 2), tensor_parallel=tensor_parallel, ) inference_params = model._decoding_cache.inference_params inference_params.reset(max_length, batch_size) else: inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False): decoding = inference_params.seqlen_offset > 0 if decoding: seqlen = input_ids.shape[1] # if inference_params.lengths_per_sample is None: # TODO: in the case of batched decoding where each sequence has a different length, # we need to compute the position_ids for each sequence using lengths_per_sample if True: cache_seqlens = torch.full( (input_ids.shape[0],), inference_params.seqlen_offset, dtype=torch.int32, device=input_ids.device, ) else: cache_seqlens = inference_params.lengths_per_sample position_ids = cache_seqlens[:, None] + torch.arange( seqlen, dtype=torch.long, device=input_ids.device ) else: position_ids = None if not cg or not decoding: logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=num_last_tokens, ).logits else: # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1]. # This might not be compatible the num_last_tokens used here. assert num_last_tokens <= input_ids.shape[1] logits = model._decoding_cache.run( input_ids, position_ids, inference_params.seqlen_offset )[:, -num_last_tokens:] return logits[..., :vocab_size] if vocab_size is not None else logits def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1): """Sample `num_tokens` tokens from the model, given the previous logits. Also return the logits of the sampled tokens. Arguments: input_ids: (batch, seqlen) Return: tokens: (batch, num_tokens) scores: (batch, num_tokens), which contains @previous_logits and the logits of the next (num_tokens - 1) tokens. The logits of the last token isn't computed. """ assert num_tokens >= 1 sequences, scores = [input_ids], [] for i in range(num_tokens): scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1]) inference_params.seqlen_offset += sequences[-1].shape[1] sequences.append(sample_fn(scores[-1]).unsqueeze(1)) return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1) sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature) sample_fn = partial(sample, **sampling_kwargs) get_logits_main = partial(get_logits, model=model, cg=cg) get_logits_draft = partial(get_logits, model=model_draft, cg=cg) sample_tokens_main = partial( sample_tokens, get_logits_fn=get_logits_main, sample_fn=sample_fn, inference_params=inference_params, ) sample_tokens_draft = partial( sample_tokens, get_logits_fn=get_logits_draft, sample_fn=sample_fn, inference_params=inference_params_draft, ) if debug: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("gpt2") if enable_timing: if tensor_parallel > 1: torch.distributed.barrier() torch.cuda.synchronize() start = time.time() sequences, scores = [input_ids], [] num_main_model_calls = 0 num_draft_tokens = 0 num_accepted_tokens_history = [] if seqlen_og >= max_length - 1: # Don't do speculative sampling, just sample 1 token from the model tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1) sequences.append(tokens) scores.append(scores_new) else: # Sample from draft model, which produces @n_spec_tokens, and @model # will then use to produce between 1 and 1 + @n_spec_tokens tokens. # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length. n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1) tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens) num_draft_tokens += n_spec_tokens if debug: scores_draft_ref = model_draft( torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 ).logits print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) # Evaluate the draft tokens with the model logits = get_logits_main( torch.cat([input_ids, tokens_draft], dim=1), inference_params, num_last_tokens=n_spec_tokens + 1, ) num_main_model_calls += 1 if debug: logits_ref = model( torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 ).logits print((logits - logits_ref).abs().max()) # breakpoint() tokens, num_generated_tokens = sample_speculative( logits, scores_draft, tokens_draft, **sampling_kwargs ) num_accepted_tokens_history.append(num_generated_tokens - 1) if debug: print(tokens) print(num_generated_tokens) # breakpoint() # TODO: we're using the fact that batch_size == 1 # TODO: check eos_token_id sequences.append(tokens[:1, : num_generated_tokens[0]]) scores.append(logits[:1, : num_generated_tokens[0]]) # Note that @model has not evaluated the last sampled token yet, so we'll need to pass # that in the next time we call @model. num_generated = num_generated_tokens[0].item() inference_params.seqlen_offset = seqlen_og + num_generated - 1 inference_params_draft.seqlen_offset = ( inference_params.seqlen_offset - 1 if num_generated > 1 else inference_params.seqlen_offset ) if debug: cur_ids = torch.cat([input_ids, sequences[-1]], dim=1) scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits print((scores[-1] - scores_ref[:, :-1]).abs().max()) # breakpoint() while True: # seqlen_offset is total length generated - 1 if inference_params.seqlen_offset >= max_length - 1: break if inference_params.seqlen_offset >= max_length - 2: # Don't do speculative sampling, just sample 1 token from the model tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1) sequences.append(tokens) scores.append(scores_new) break # Sample from draft model n_spec_tokens = min( speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2 ) # If the main model accepts all the draft tokens, plus it samples one new token, # then at the next iteration the draft model need to evaluate the logits of the last draft # token and the logits of the newly sampled token. So here we pass in the last 2 tokens # of sequences[-1]. # This exception is when the main model rejects all the draft tokens, in which case we # will only have 1 token to pass in. tokens_draft, scores_draft = sample_tokens_draft( sequences[-1][:, -2:], num_tokens=n_spec_tokens ) num_draft_tokens += n_spec_tokens if debug: scores_draft_ref = model_draft( torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 ).logits print((scores_draft - scores_draft_ref[:, :-1]).abs().max()) # breakpoint() # Evaluate the draft tokens with the model logits = get_logits_main( torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1), inference_params, num_last_tokens=n_spec_tokens + 1, ) # (batch, n_spec_tokens + 1, vocab_size) num_main_model_calls += 1 if debug: logits_ref = model( torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1 ).logits print((logits - logits_ref).abs().max()) # breakpoint() tokens, num_generated_tokens = sample_speculative( logits, scores_draft, tokens_draft, **sampling_kwargs ) num_accepted_tokens_history.append(num_generated_tokens - 1) if debug: print(tokens) print(num_generated_tokens) # breakpoint() sequences.append(tokens[:1, : num_generated_tokens[0]]) scores.append(logits[:1, : num_generated_tokens[0]]) # We've evaluated 1 token from sequences[-1][:, -1:] above, plus # num_generated_tokens[0].item() - 1 tokens from the draft model. num_generated = num_generated_tokens[0].item() inference_params.seqlen_offset += num_generated inference_params_draft.seqlen_offset = ( inference_params.seqlen_offset - 1 if num_generated > 1 else inference_params.seqlen_offset ) if debug: cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1) scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits print((scores[-1] - scores_ref[:, :-1]).abs().max()) # breakpoint() if enable_timing: if tensor_parallel > 1: torch.distributed.barrier() torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") print(f"Number of calls to main model: {num_main_model_calls}") print( f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%" ) sequences = torch.cat(sequences, dim=1) scores = torch.cat(scores, dim=1) if debug: scores_ref = model(sequences).logits print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max()) output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput return output_cls(sequences=sequences, scores=scores) class GenerationMixin: def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): raise NotImplementedError def generate( self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0, return_dict_in_generate=False, output_scores=False, **kwargs, ): output = decode( input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs ) if not output_scores: output.scores = None return output if return_dict_in_generate else output.sequences def allocate_inference_cache( max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence], device, dtype=torch.float16, ): assert dtype in [torch.float16, torch.bfloat16, torch.float32] kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) if isinstance(layers, int): layers = range(layers) return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} @dataclass class DecodingCGCache: max_batch_size: int = 0 max_seqlen: int = 0 device = None dtype = None callables: dict = field(default_factory=dict) mempool = None inference_params: Optional[InferenceParams] = None run: Optional[Callable] = None @torch.inference_mode() def update_graph_cache( model, cache, batch_size, seqlen_og, max_seqlen, decoding_seqlens=(1,), tensor_parallel=1, dtype=None, n_warmups=2, ): if cache is None: cache = DecodingCGCache() param_example = next(iter(model.parameters())) device = param_example.device if dtype is None: dtype = param_example.dtype if ( (device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size or max_seqlen > cache.max_seqlen ): # Invalidate the cache cache.callables = {} cache.mempool = None cache.inference_params = None gc.collect() cache.device, cache.dtype = device, dtype cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen if hasattr(model, "allocate_inference_cache"): inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) else: headdim = getattr( model.config, "head_dim", model.config.hidden_size // model.config.num_attention_heads, ) inf_cache = allocate_inference_cache( batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim, model.config.num_hidden_layers, device, dtype, ) lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) cache.inference_params = InferenceParams( max_seqlen=max_seqlen, max_batch_size=batch_size, seqlen_offset=seqlen_og, key_value_memory_dict=inf_cache, lengths_per_sample=lengths_per_sample, ) cache.mempool = torch.cuda.graphs.graph_pool_handle() for decoding_seqlen in decoding_seqlens: if (batch_size, decoding_seqlen) not in cache.callables: cache.callables[batch_size, decoding_seqlen] = capture_graph( model, cache.inference_params, batch_size, max_seqlen, decoding_seqlen=decoding_seqlen, mempool=cache.mempool, n_warmups=n_warmups, ) def dispatch(input_ids, position_ids, seqlen): batch_size, decoding_seqlen = input_ids.shape[:2] return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) cache.run = dispatch cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing return cache def capture_graph( model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 ): device = next(iter(model.parameters())).device input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) seqlen_offset_og = inference_params.seqlen_offset inference_params.seqlen_offset = max_seqlen - decoding_seqlen inference_params.lengths_per_sample[:] = inference_params.seqlen_offset # Warmup before capture s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(n_warmups): logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=decoding_seqlen, ).logits s.synchronize() # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, # which requires that graph launch and non-captured launch to not overlap (I think, # that's how I interpret the documentation). I'm not sure if this is required. if torch.distributed.is_initialized(): torch.distributed.barrier() torch.cuda.current_stream().wait_stream(s) # Captures the graph # To allow capture, automatically sets a side stream as the current stream in the context graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, pool=mempool): logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=decoding_seqlen, ).logits def run(new_input_ids, new_position_ids, seqlen): inference_params.lengths_per_sample[:] = seqlen input_ids.copy_(new_input_ids) position_ids.copy_(new_position_ids) graph.replay() return logits.clone() inference_params.seqlen_offset = seqlen_offset_og return run