# test_falcon.py
# Copyright (c) 2023, Tri Dao.
import os
import time
from pathlib import Path

current_dir = Path(__file__).parent.absolute()

import pytest
import torch
from einops import rearrange

from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape
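

# The test above only checks parameter names and shapes on a meta-device model. A minimal
# sketch of actually materializing the remapped Falcon weights (an illustration only, assuming
# a CUDA device with enough memory; it is roughly what GPTLMHeadModel.from_pretrained does for
# us in the tests below):
#
#     config = falcon_config_to_gpt2_config(
#         AutoConfig.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
#     )
#     model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
#     model.load_state_dict(
#         remap_state_dict_hf_falcon(state_dict_from_pretrained("tiiuae/falcon-7b"), config)
#     )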


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()
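

# The tolerance pattern above (our fp16 error against an HF fp32 reference must stay within a
# small multiple of HF's own fp16 error) is reused by every test in this file. A hypothetical
# helper that captures it -- not part of flash_attn, shown only to make the criterion explicit:
#
#     def assert_close_to_ref(ours, ref_fp32, baseline_fp16, factor=3.0):
#         err = (ours - ref_fp32).abs().max().item()
#         baseline_err = (baseline_fp16 - ref_fp32).abs().max().item()
#         assert err < factor * baseline_err, f"{err:.4g} vs {factor} * {baseline_err:.4g}"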


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)
        logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()
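

# A note on the gather/rearrange sequence in test_falcon_parallel_forward above: the hidden
# states returned by model.transformer are gathered across ranks along their first, flattened
# (batch x sequence) dimension and then reshaped with "(b s) d -> b s d", while the gathered
# logits are re-concatenated along the vocab dimension with "(n b) ... d -> b ... (n d)",
# reassembling the n per-rank vocab shards produced by the column-parallel LM head. This
# reading of the layout is inferred from the rearrange patterns, not stated in the test itself.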


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the scores in fp16 should be around the same as the HF scores in fp16,
    when compared to the HF scores in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
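    # update_graph_cache comes from flash_attn.utils.generation; as used here, it pre-captures
    # the decoding step as CUDA graphs and stores them in model._decoding_cache so that
    # generate(..., cg=True) below can replay the captured graphs instead of re-launching
    # individual kernels at every decoding step.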
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error

    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation
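    # Padding the vocab size to a multiple of 8 * world_size presumably lets the embedding and
    # LM head be split evenly across the tensor-parallel ranks (with each shard a multiple of 8);
    # sequence parallelism is disabled here, as the comment above notes, likely because decoding
    # proceeds one token at a time and cannot shard activations along the sequence dimension.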
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref

        logits_hf = torch.stack(out_hf.scores, dim=1)
        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
        assert torch.equal(logits_cg, logits)