
[GPT] Implement Falcon

Tri Dao, 1 year ago
commit d38357dd2f
3 changed files with 495 additions and 0 deletions
  1. flash_attn/models/falcon.py   +122 -0
  2. flash_attn/models/gpt.py      +3 -0
  3. tests/models/test_falcon.py   +370 -0

flash_attn/models/falcon.py   +122 -0

@@ -0,0 +1,122 @@
+# Copyright (c) 2023, Tri Dao.
+
+import math
+import re
+
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+
+from einops import rearrange
+
+from transformers import GPT2Config, FalconConfig
+
+
+def remap_state_dict_hf_falcon(state_dict, config):
+    def key_mapping_layers(key):
+        return re.sub(r'^transformer.h.', 'transformer.layers.', key)
+    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
+    # Word embedding
+    def key_mapping_emb(key):
+        return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
+    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
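+    # (e.g. vocab_size=65023 with pad_vocab_size_multiple=8 would be padded to 65024;
+    # the extra embedding rows added by F.pad below are zero-initialized.)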
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
+    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
+        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
+    )
+    if getattr(config, 'tie_word_embeddings'):
+        state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
+    else:
+        output_embeddings = state_dict.pop('lm_head.weight')
+        # It's possible that vocab_size is padded to be a multiple of 8, for example.
+        state_dict['lm_head.weight'] = F.pad(
+            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
+        )
+        output_embeddings_bias = state_dict.pop('lm_head.bias')
+        state_dict['lm_head.bias'] = F.pad(
+            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
+        )
+
+    # LayerNorm
+    def key_mapping_ln(key):
+        key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
+                     r'transformer.layers.\1.norm1.', key)
+        key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
+                     r'transformer.layers.\1.norm2.', key)
+        key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
+        key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
+
+    # MLP
+    def key_mapping_mlp(key):
+        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
+                     r'transformer.layers.\1.mlp.fc1.', key)
+        key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
+                     r'transformer.layers.\1.mlp.fc2.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
+
+    def key_mapping_attn(key):
+        key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
+                      r'transformer.layers.\1.mixer.Wqkv.', key)
+        key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
+                      r'transformer.layers.\1.mixer.out_proj.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+    n_head = config.n_head
+    n_head_kv = getattr(config, "n_head_kv", 1)
+    headdim = config.hidden_size // n_head
+    for l in range(config.n_layer):
+        # The weights are stored in a different layout compared to our implementation
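+        # Falcon fuses QKV per key/value-head group as [q_0, ..., q_{ratio-1}, k, v]
+        # with ratio = n_head // n_head_kv, whereas our Wqkv layout is all query heads,
+        # then all key heads, then all value heads.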
+        Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
+                         "(group ratio headdim) ... -> group ratio headdim ...",
+                         ratio=n_head // n_head_kv + 2, headdim=headdim)
+        Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
+        Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
+        Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
+        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
+
+    return state_dict
+
+
+def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
+    # The 40b config uses "n_head_kv" instead of "num_kv_heads"
+    n_head_kv = getattr(falcon_config, "n_head_kv",
+                        1 if getattr(falcon_config, "multi_query", False)
+                        else falcon_config.n_head)
+    # HACK: the 40b config uses 2 LayerNorms per layer instead of 1, but that's not reflected
+    # in the config, so we have to infer it from the number of key/value heads.
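+    # (e.g. the 7b checkpoints set multi_query=True, giving n_head_kv == 1 and a single shared
+    # LayerNorm in the parallel block, while the 40b checkpoints have n_head_kv > 1 and separate
+    # ln_attn / ln_mlp, mapped to norm1 / norm2 in remap_state_dict_hf_falcon.)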
+    parallel_block_tied_norm = n_head_kv == 1
+    return GPT2Config(
+        vocab_size=falcon_config.vocab_size,
+        n_positions=0,  # No absolute position embedding
+        n_embd=falcon_config.hidden_size,
+        n_layer=falcon_config.n_layer,
+        n_head=falcon_config.n_head,
+        n_inner=falcon_config.hidden_size * 4,
+        activation_function="gelu",
+        resid_pdrop=falcon_config.hidden_dropout,
+        embd_pdrop=0.0,  # There doesn't seem to be any embedding dropout
+        attn_pdrop=falcon_config.attention_dropout,
+        layer_norm_epsilon=falcon_config.layer_norm_epsilon,
+        initializer_range=falcon_config.initializer_range,
+        bos_token_id=falcon_config.bos_token_id,
+        eos_token_id=falcon_config.eos_token_id,
+        # These are new arguments not in the original GPT2Config
+        parallel_block=falcon_config.parallel_attn,
+        n_head_kv=n_head_kv,
+        parallel_block_tied_norm=parallel_block_tied_norm,
+        rotary_emb_fraction=1.0,
+        rotary_emb_interleaved=False,
+        tie_word_embeddings=True,
+        qkv_proj_bias=falcon_config.bias,
+        out_proj_bias=falcon_config.bias,
+        mlp_fc1_bias=falcon_config.bias,
+        mlp_fc2_bias=falcon_config.bias,
+        lm_head_bias=False,
+    )
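
Taken together, the two helpers above let a Falcon checkpoint be loaded straight into flash-attn's GPT model. A minimal usage sketch, mirroring the settings used in tests/models/test_falcon.py below (it assumes a CUDA GPU and access to the tiiuae/falcon-7b checkpoint):

    import torch
    from transformers import AutoConfig
    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.models.falcon import falcon_config_to_gpt2_config

    model_name = "tiiuae/falcon-7b"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # no fused MLP for the "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    # from_pretrained routes through remap_state_dict_hf_falcon (see the gpt.py change below)
    model = GPTLMHeadModel.from_pretrained(model_name, config, device="cuda", dtype=torch.float16)
    model.eval()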

flash_attn/models/gpt.py   +3 -0

@@ -27,6 +27,7 @@ from flash_attn.utils.generation import GenerationMixin
 from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.models.gptj import remap_state_dict_hf_gptj
 from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
+from flash_attn.models.falcon import remap_state_dict_hf_falcon
 
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -241,6 +242,8 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
         elif model_name.startswith('EleutherAI/gpt-neox-'):
             state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
+        elif model_name.startswith('tiiuae/falcon-'):
+            state_dict = remap_state_dict_hf_falcon(state_dict, config)
         else:
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:

tests/models/test_falcon.py   +370 -0

@@ -0,0 +1,370 @@
+# Copyright (c) 2023, Tri Dao.
+
+import os
+import time
+from pathlib import Path
+current_dir = Path(__file__).parent.absolute()
+
+import torch
+import pytest
+
+from einops import rearrange
+
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
+from flash_attn.models.falcon import remap_state_dict_hf_falcon, falcon_config_to_gpt2_config
+from flash_attn.utils.distributed import all_gather_raw
+from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import update_graph_cache
+
+
+@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
+def test_falcon_state_dict(model_name):
+    config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
+                                                                     trust_remote_code=True))
+    pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
+    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
+    state_dict = model.state_dict()
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
+        assert state_dict[k].shape == pretrained_state_dict[k].shape
+
+
+@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b"])
+def test_falcon_optimized(model_name):
+    """Check that our implementation (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
+                                                                     trust_remote_code=True))
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
+                              device=device)
+    with torch.no_grad():
+        out = model.transformer(input_ids)
+        logits = model(input_ids).logits
+    del model
+
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = AutoModelForCausalLM.from_pretrained(
+        model_name, device_map={"": device}, trust_remote_code=True
+    )
+    model_ref.eval()
+    with torch.no_grad():
+        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
+        logits_ref = model_ref(input_ids).logits.to(device=device)
+    del model_ref
+
+    model_hf = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
+    )
+    model_hf.eval()
+    with torch.no_grad():
+        out_hf = model_hf.transformer(input_ids).last_hidden_state
+        logits_hf = model_hf(input_ids).logits
+    del model_hf
+
+    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
+    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
+    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
+
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
+    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
+    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
+    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+
+
+# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
+# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
+# memory to run the model in fp32.
+@pytest.mark.parametrize('world_size', [4])
+@pytest.mark.parametrize('model_name', ["tiiuae/falcon-40b"])
+def test_falcon_parallel_forward(model_name, world_size):
+    from apex.transformer import parallel_state
+
+    dtype = torch.float16
+    config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
+                                                                     trust_remote_code=True))
+    config.use_flash_attn = False
+    config.fused_bias_fc = True
+    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
+    config.fused_dropout_add_ln = False
+    config.residual_in_fp32 = True
+
+    if not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    device = f'cuda:{torch.distributed.get_rank()}'
+    assert world_size <= torch.distributed.get_world_size()
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+    rank = parallel_state.get_tensor_model_parallel_rank()
+    process_group = parallel_state.get_tensor_model_parallel_group()
+
+    pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
+
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
+                              device=device)
+    with torch.no_grad():
+        out = model.transformer(input_ids)
+        out, _ = all_gather_raw(out, process_group=process_group)
+        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
+        logits = model(input_ids).logits
+        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
+        logits, _ = all_gather_raw(logits, process_group)
+        logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
+    del model
+
+    if rank == 0:
+        model_hf = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
+        )
+        model_hf.eval()
+        out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)
+        logits_hf = model_hf(input_ids).logits.to(device=device)
+        del model_hf
+
+        # Without device_map, the model is loaded on the CPU, which is very slow
+        model_ref = AutoModelForCausalLM.from_pretrained(
+            model_name, device_map="auto", trust_remote_code=True
+        )
+        model_ref.eval()
+        with torch.no_grad():
+            out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
+            logits_ref = model_ref(input_ids).logits.to(device=device)
+        del model_ref
+
+        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
+        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
+
+        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+        print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
+        assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["tiiuae/falcon-7b"])
+def test_falcon_generation(model_name):
+    """Check that our implementation (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
+                                                                     trust_remote_code=True))
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    eos_token_id = tokenizer.eos_token_id
+
+    torch.manual_seed(0)
+    batch_size = 1
+    seqlen = 100
+    max_length = 150
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+
+    model_hf = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
+    )
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = AutoModelForCausalLM.from_pretrained(
+        model_name, device_map={"": device}, trust_remote_code=True
+    )
+    model_ref.eval()
+    with torch.no_grad():
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    del model_ref
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
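+    # teacher_outputs forces our generation to emit the same tokens as the HF run above,
+    # so the per-step scores can be compared position by position against logits_ref.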
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=True,
+                         return_dict_in_generate=True, output_scores=True, timing=True,
+                         teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    # Capture graph outside the timing loop
+    batch_size, seqlen_og = input_ids.shape
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    print('With CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                            fused_ft_kernel=True, cg=True,
+                            return_dict_in_generate=True, output_scores=True, timing=True,
+                            teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    with torch.no_grad():
+        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    logits_hf = torch.stack(out_hf.scores, dim=1)
+    logits = torch.stack(out.scores, dim=1)
+    logits_cg = torch.stack(out_cg.scores, dim=1)
+
+    del model
+
+    hf_error = (logits_hf - logits_ref).abs().max().item()
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
+
+    print(f'HF fp16 logits max diff: {hf_error}')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
+    assert torch.equal(logits_cg, logits)
+
+
+# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
+# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
+# memory to run the model in fp32.
+@pytest.mark.parametrize('world_size', [4])
+@pytest.mark.parametrize('model_name', ["tiiuae/falcon-40b"])
+def test_falcon_parallel_generation(model_name, world_size):
+    """Check that our implementation matches the HF implementation:
+    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
+    the HF scores in fp32.
+    """
+    from apex.transformer import parallel_state
+
+    dtype = torch.float16
+    config = falcon_config_to_gpt2_config(AutoConfig.from_pretrained(model_name,
+                                                                     trust_remote_code=True))
+    config.use_flash_attn = False
+    config.fused_bias_fc = True
+    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
+    config.fused_dropout_add_ln = False
+    config.residual_in_fp32 = True
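+    # Pad the vocab so that each tensor-parallel shard of the embedding / lm_head
+    # still has a vocab dimension that is a multiple of 8.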
+    config.pad_vocab_size_multiple = 8 * world_size
+    config.sequence_parallel = False  # Need to set this to False for generation
+
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
+    if not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    device = f'cuda:{torch.distributed.get_rank()}'
+    assert world_size <= torch.distributed.get_world_size()
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+    rank = parallel_state.get_tensor_model_parallel_rank()
+    process_group = parallel_state.get_tensor_model_parallel_group()
+
+    torch.manual_seed(0)
+    batch_size = 1
+    seqlen = 100
+    max_length = 150
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+
+    torch.distributed.barrier()
+
+    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
+    # GPU0 and GPU1 and things would hang
+    torch.cuda.set_device(device)
+
+    pretrained_state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)
+
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
+    model.eval()
+
+    print('Without CUDA graph')
+    out = model.generate(
+        input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
+        vocab_size=config.vocab_size, fused_ft_kernel=True,
+        # teacher_outputs=out_hf.sequences,
+        return_dict_in_generate=True, output_scores=True, timing=True
+    )
+
+    # Capture graph outside the timing loop
+    batch_size, seqlen_og = input_ids.shape
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    print('With CUDA graph')
+    out_cg = model.generate(
+        input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
+        vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True,
+        # teacher_outputs=out_hf.sequences,
+        return_dict_in_generate=True, output_scores=True, timing=True
+    )
+    del model
+    parallel_state.destroy_model_parallel()
+
+    if rank == 0:
+        model_hf = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
+        )
+        model_hf.eval()
+        print("HF fp16")
+        torch.cuda.synchronize()
+        start = time.time()
+        with torch.inference_mode():
+            out_hf = model_hf.generate(
+                input_ids=input_ids, max_length=max_length, return_dict_in_generate=True,
+                output_scores=True
+            )
+        torch.cuda.synchronize()
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+        del model_hf
+
+        model_ref = AutoModelForCausalLM.from_pretrained(
+            model_name, device_map="auto", trust_remote_code=True
+        )
+        model_ref.eval()
+        with torch.inference_mode():
+            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+        del model_ref
+        logits_hf = torch.stack(out_hf.scores, dim=1)
+
+        logits = torch.stack(out.scores, dim=1)
+        logits_cg = torch.stack(out_cg.scores, dim=1)
+
+        hf_error = (logits_hf - logits_ref).abs().max().item()
+        print(f'HF fp16 logits max diff: {hf_error}')
+        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+        print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
+        assert torch.equal(logits_cg, logits)
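
The single-GPU tests can be run with plain pytest, e.g. pytest -q -s tests/models/test_falcon.py -k "falcon_optimized or falcon_generation" (this needs a CUDA GPU with enough memory to hold Falcon-7B in both fp16 and fp32); the tensor-parallel tests use the torchrun invocations given in the comments above.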