import re

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import (
    GPTLMHeadModel,
    remap_state_dict_hf_gpt2,
    shard_state_dict_tp,
    combine_state_dicts_tp,
)
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
    config = GPT2Config.from_pretrained(model_name)
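    # Remap the HF checkpoint's parameter names and layouts to this repo's GPTLMHeadModel
    # format, then check that every key and shape matches a freshly constructed model.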
    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape

@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
    """Check that our implementation of GPT2 (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()

@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
    """Check that our implementation of GPT2 (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    vocab_size_og = config.vocab_size
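    # Enable the optimized paths: FlashAttention, fused bias FC, fused MLP,
    # fused dropout + residual add + LayerNorm, and pad the vocab to a multiple of 8.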
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8
    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
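    # The vocab was padded to a multiple of 8, so slice the logits back to the
    # original vocab size before comparing against HF.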
    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()
- @pytest.mark.parametrize("optimized", [False, True])
- # @pytest.mark.parametrize('optimized', [True])
- @pytest.mark.parametrize("rotary", [False, True])
- # @pytest.mark.parametrize('rotary', [False])
- @pytest.mark.parametrize("model_name", ["gpt2"])
- def test_gpt2_generation(model_name, rotary, optimized):
- """Check that our implementation of GPT2 generation matches the HF implementation:
- the scores in fp16 should be around the same as the HF scores in fp16, when compared to
- the HF scores in fp32.
- """
- dtype = torch.float16
- device = "cuda"
- rtol, atol = 3e-3, 3e-1
- config = GPT2Config.from_pretrained(model_name)
- if rotary:
- config.n_positions = 0
- config.rotary_emb_fraction = 0.5
- config.rotary_emb_base = 24000
- config.residual_in_fp32 = True
- if optimized:
- config.use_flash_attn = True
- config.fused_bias_fc = True
- config.fused_mlp = True
- config.fused_dropout_add_ln = True
    # If rotary, we load the weights from HF but ignore the position embeddings.
    # The model would be nonsense but it doesn't matter for the test.
    model = GPTLMHeadModel.from_pretrained(
        model_name, config, strict=not rotary, device=device, dtype=dtype
    )
    model.eval()

    if not rotary:
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
            device=device
        )
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
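    # CUDA-graph decoding (cg=True) is only exercised when flash attention is enabled;
    # its scores must match eager generation exactly.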
    if getattr(config, "use_flash_attn", False):
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        print(out_cg.sequences)
        assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))

    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ref = model_ref.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(tokenizer.batch_decode(out_ref.sequences.tolist()))
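    # Greedy generation should reproduce the slow token-by-token reference exactly,
    # and its scores should stay within rtol/atol of the reference scores.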
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)
        assert (
            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item() < 3 * (
            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item()
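
# Helper: run generation (optionally with CUDA graph via cg=True) and stack the
# per-step scores into a single (batch, num_generated_steps, vocab) tensor.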
def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        teacher_outputs=teacher_outputs,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        **kwargs,
    )
    return torch.stack(out.scores, dim=1)
- @pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
- # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
- @pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
- # @pytest.mark.parametrize('rotary', [None])
- @pytest.mark.parametrize("model_name", ["gpt2"])
- def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
- """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
- dtype = torch.float16
- device = "cuda"
- rtol, atol = 3e-3, 3e-1
- config = GPT2Config.from_pretrained(model_name)
- config.n_positions = 16 * 1024
- assert seqlen <= maxlen <= config.n_positions
- if rotary is not None:
- config.n_positions = 0
- config.rotary_emb_dim = 32
- config.rotary_emb_interleaved = rotary == "interleaved"
- config.residual_in_fp32 = True
- config.use_flash_attn = True
- config.fused_bias_fc = True
- config.fused_mlp = True
- config.fused_dropout_add_ln = True
- model = GPTLMHeadModel(config, device=device, dtype=dtype)
- model.eval()
- torch.manual_seed(0)
- batch_size = 1
- input_ids = torch.randint(
- 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
- )
- teacher_outputs = torch.randint(
- 0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
- )
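    # teacher_outputs fixes the tokens fed back at each decoding step (teacher forcing),
    # so the logits from the two runs are directly comparable.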
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)

    # Try increasing the batch size and maxlen, then decreasing them, to check it's still correct
    batch_size = 3
    maxlen += 30
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)

    batch_size = 2
    maxlen -= 35
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)
- @pytest.mark.parametrize("optimized", [False, True])
- # @pytest.mark.parametrize("optimized", [False])
- @pytest.mark.parametrize("model_name", ["gpt2"])
- def test_gpt2_multiple_token_generation(model_name, optimized):
- """Generation when we pass in multiple tokens at a time, not just one."""
- dtype = torch.float16
- device = "cuda"
- rtol, atol = 3e-3, 3e-1
- config = GPT2Config.from_pretrained(model_name)
- config.residual_in_fp32 = True
- if optimized:
- config.use_flash_attn = True
- config.fused_bias_fc = True
- config.fused_mlp = True
- config.fused_dropout_add_ln = True
- model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
- model.eval()
- torch.manual_seed(0)
- input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)
- # Reference logits
- logits_ref = model(input_ids).logits
- # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
- inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
- logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
- inference_params.seqlen_offset += 10
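    # When continuing from the KV cache, pass explicit position_ids starting at the current offset.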
    position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
    logits_1014 = model(
        input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
    ).logits
    inference_params.seqlen_offset += 4
    position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
    logits_1420 = model(
        input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
    ).logits
    logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
- @pytest.mark.parametrize("cg", [False, True])
- # @pytest.mark.parametrize("cg", [True])
- @pytest.mark.parametrize("optimized", [False, True])
- # @pytest.mark.parametrize("optimized", [True])
- # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
- @pytest.mark.parametrize("model_name", ["gpt2-xl"])
- def test_gpt2_speculative_decoding(model_name, optimized, cg):
- if cg and not optimized:
- pytest.skip() # CG requires use_flash_attn
- dtype = torch.float16
- device = "cuda"
- rtol, atol = 3e-3, 3e-1
- config = GPT2Config.from_pretrained(model_name)
- config.residual_in_fp32 = True
- if optimized:
- config.use_flash_attn = True
- config.fused_bias_fc = True
- config.fused_mlp = True
- config.fused_dropout_add_ln = True
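    # The small gpt2 acts as the draft model for speculative decoding;
    # model_name (e.g. gpt2-xl) is the larger target model.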
    config_draft = GPT2Config.from_pretrained("gpt2")
    config_draft.residual_in_fp32 = True
    if optimized:
        config_draft.use_flash_attn = True
        config_draft.fused_bias_fc = True
        config_draft.fused_mlp = True
        config_draft.fused_dropout_add_ln = True
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()
    model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype)
    model_draft.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 100

    from flash_attn.utils.generation import decode_speculative

    torch.manual_seed(42)
    print(f"Speculative decoding, {optimized = }")
    out = decode_speculative(
        input_ids,
        model,
        model_draft,
        max_length=max_length,
        top_k=5,
        cg=cg,
        speculative_lookahead=4,
        enable_timing=True,
        # debug=True,
    )
    print(tokenizer.batch_decode(out.sequences))
    print(f"Without speculative decoding, {cg = }")
    out_og = model.generate(
        input_ids,
        max_length=max_length,
        top_k=5,
        cg=cg,
        enable_timing=True,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out_og.sequences))

@pytest.mark.parametrize(
    "n_heads_q_kv",
    [
        (8, 8),  # Regular attention
        (8, 4),  # GQA
        (8, 2),  # MQA
    ],
)
def test_gpt2_shard_unshard(n_heads_q_kv):
    world_size = 2
    config = GPT2Config.from_pretrained("gpt2")
    config.vocab_size = 1024
    config.n_head, config.n_head_kv = n_heads_q_kv
    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
    state_dict = model.state_dict()
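    # Shard the state dict across `world_size` tensor-parallel ranks, recombine the shards,
    # and check that the round trip reproduces the original tensors exactly.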
    shards = [
        # NOTE: Shallow copy as `state_dict` is modified in-place
        shard_state_dict_tp(dict(state_dict), config, world_size, rank)
        for rank in range(world_size)
    ]
    state_dict2 = combine_state_dicts_tp(shards, config)
    assert state_dict2.keys() == state_dict.keys()
    for k in state_dict.keys():
        ref = state_dict[k]
        new = state_dict2[k]
        assert torch.allclose(ref, new, atol=0.0, rtol=0.0)