# Copyright (c) 2023, Tri Dao.

import os
import time
from pathlib import Path

import torch
import pytest
from einops import rearrange
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

from flash_attn.models.gpt import (
    GPTLMHeadModel,
    combine_state_dicts_tp,
    shard_state_dict_tp,
)
from flash_attn.models.baichuan import (
    remap_state_dict_hf_baichuan,
    baichuan_config_to_gpt2_config,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
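
# The tests below compare the flash-attn GPT implementation of Baichuan against the
# Hugging Face reference models: state-dict remapping, single-GPU and tensor-parallel
# forward passes, and generation with and without CUDA graphs. The tensor-parallel
# tests are meant to be launched via torchrun (see the command line above each one).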


@pytest.mark.parametrize(
    "model_name",
    [
        "baichuan-inc/Baichuan-7B",
        "baichuan-inc/Baichuan-13B-Base",
        "baichuan-inc/Baichuan2-7B-Base",
        "baichuan-inc/Baichuan2-13B-Base",
    ],
)
def test_baichuan_state_dict(model_name):
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize(
    "model_name",
    [
        "baichuan-inc/Baichuan-7B",
        "baichuan-inc/Baichuan-13B-Base",
        "baichuan-inc/Baichuan2-7B-Base",
        "baichuan-inc/Baichuan2-13B-Base",
    ],
)
def test_baichuan_optimized(model_name):
    """Check that our implementation of Baichuan (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map={"": device},
        trust_remote_code=True,
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.model(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf
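
    # The fp32 HF forward pass is the reference; HF's own fp16 error against it sets the
    # scale, and our fp16 outputs and logits must stay within 3x of that error.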
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
    "model_name",
    [
        "baichuan-inc/Baichuan-7B",
        "baichuan-inc/Baichuan-13B-Base",
        "baichuan-inc/Baichuan2-7B-Base",
        "baichuan-inc/Baichuan2-13B-Base",
    ],
)
def test_baichuan_parallel_forward(model_name, world_size):
    """Check that our implementation of Baichuan (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
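
    # The transformer output comes back flattened and sharded across tensor-parallel ranks,
    # and the logits are sharded along the vocab dimension, so gather across ranks and
    # reshape both before comparing against the HF reference.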
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model
    parallel_state.destroy_model_parallel()
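
    # Only rank 0 loads the HF reference models and runs the comparison.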
    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()


@pytest.mark.parametrize(
    "model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
)
def test_baichuan_generation(model_name):
    dtype = torch.float16
    device = "cuda"
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 2048
    max_length = 2048 + 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    # Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
    del model_ref

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    model(input_ids)  # Warm up
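
    # teacher_outputs pins the sampled tokens to the HF-generated sequence, so the
    # returned scores can be compared position-by-position against the HF logits.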
    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
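    # update_graph_cache pre-captures the single-token decoding step as a CUDA graph for this
    # batch size and max length; generate(..., cg=True) below then reuses the captured graph.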
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
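    # Decoding with CUDA graphs must reproduce the eager decoding results exactly.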
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
def test_baichuan_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU 0 and GPU 1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref

        logits_hf = torch.stack(out_hf.scores, dim=1)
        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        assert torch.equal(logits_cg, logits)