1
0

test_llama.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. # Copyright (c) 2023, Tri Dao.
  2. # To run the huggingface implementation of LLaMa (1), we first need to convert the weights:
  3. # https://github.com/huggingface/transformers/pull/21955
  4. # python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
  5. # and repeat for 13B, 30B, 65B
  6. import os
  7. import time
  8. from pathlib import Path
  9. current_dir = Path(__file__).parent.absolute()
  10. import shutil
  11. import pytest
  12. import torch
  13. from einops import rearrange
  14. from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
  15. from flash_attn.models.llama import (
  16. config_from_checkpoint,
  17. inv_remap_state_dict_hf_llama,
  18. llama_config_to_gpt2_config,
  19. remap_state_dict_hf_llama,
  20. remap_state_dict_meta_llama,
  21. state_dicts_from_checkpoint,
  22. )
  23. from flash_attn.utils.distributed import all_gather_raw
  24. from flash_attn.utils.generation import update_graph_cache
  25. from flash_attn.utils.pretrained import state_dict_from_pretrained
  26. from transformers import LlamaConfig, LlamaTokenizer
  27. from transformers.models.llama.modeling_llama import LlamaForCausalLM
  28. from transformers import AutoConfig
  29. def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
  30. if checkpoint_format == "meta":
  31. ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
  32. pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
  33. pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
  34. else:
  35. pretrained_state_dict = state_dict_from_pretrained(
  36. Path(checkpoint_path) / f"{model_name}-hf"
  37. )
  38. pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
  39. return pretrained_state_dict
  40. @pytest.mark.parametrize("model_name", ["7B"])
  41. def test_llama_state_dict(model_name):
  42. checkpoint_path = (
  43. Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
  44. )
  45. config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
  46. ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
  47. pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
  48. model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
  49. state_dict = model.state_dict()
  50. assert state_dict.keys() == pretrained_state_dict.keys()
  51. for k in state_dict.keys():
  52. assert state_dict[k].shape == pretrained_state_dict[k].shape
  53. # TinyLlama-1.1B is to test MQA
  54. @pytest.mark.parametrize(
  55. "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
  56. )
  57. def test_inv_remap_state_dict_hf_llama(model_name):
  58. config = llama_config_to_gpt2_config(
  59. AutoConfig.from_pretrained(model_name, trust_remote_code=True)
  60. )
  61. state_dict = state_dict_from_pretrained(model_name)
  62. # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
  63. state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
  64. pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
  65. state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
  66. assert set(state_dict_recover.keys()) == set(state_dict.keys())
  67. for key in state_dict_recover.keys():
  68. torch.testing.assert_close(state_dict_recover[key], state_dict[key])
  69. # TinyLlama-1.1B is to test MQA
  70. @pytest.mark.parametrize(
  71. "model_name",
  72. [
  73. "7B", # Llama 1
  74. "13B", # Llama 1
  75. "meta-llama/Llama-2-13b-hf",
  76. "codellama/CodeLlama-7b-hf",
  77. "codellama/CodeLlama-13b-hf",
  78. "codellama/CodeLlama-34b-hf",
  79. "PY007/TinyLlama-1.1B-step-50K-105b",
  80. ],
  81. )
  82. def test_llama_optimized(model_name):
  83. """Check that our implementation of LLaMa (with all optimizations enabled) matches the
  84. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  85. forward pass in fp16, when compared to the HF forward pass in fp32.
  86. """
  87. checkpoint_path = (
  88. Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
  89. )
  90. dtype = torch.float16
  91. device = "cuda"
  92. if "/" in model_name: # Download from HF
  93. config = llama_config_to_gpt2_config(
  94. AutoConfig.from_pretrained(model_name, trust_remote_code=True)
  95. )
  96. else:
  97. config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
  98. config = llama_config_to_gpt2_config(config)
  99. config.use_flash_attn = True
  100. config.fused_bias_fc = True
  101. config.fused_mlp = False # We don't have fused GatedMLP yet
  102. config.fused_dropout_add_ln = True
  103. config.residual_in_fp32 = True
  104. if "/" in model_name: # Download from HF
  105. pretrained_state_dict = remap_state_dict_hf_llama(
  106. state_dict_from_pretrained(model_name), config
  107. )
  108. else:
  109. pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
  110. checkpoint_path, model_name, config, checkpoint_format="meta"
  111. )
  112. model = GPTLMHeadModel(config, device=device, dtype=dtype)
  113. model.load_state_dict(pretrained_state_dict)
  114. model.eval()
  115. torch.manual_seed(0)
  116. batch_size = 2
  117. max_seqlen = 256
  118. seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
  119. input_ids = torch.randint(
  120. 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
  121. )
  122. with torch.no_grad():
  123. out = model.transformer(input_ids)
  124. logits = model(input_ids).logits
  125. del model
  126. # Without device_map, the model is loaded on the CPU, which is very slow
  127. # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
  128. model_ref = LlamaForCausalLM.from_pretrained(
  129. model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
  130. device_map="auto",
  131. )
  132. model_ref.eval()
  133. with torch.no_grad():
  134. out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
  135. logits_ref = model_ref(input_ids).logits.to(device=device)
  136. del model_ref
  137. model_hf = LlamaForCausalLM.from_pretrained(
  138. model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
  139. torch_dtype=dtype,
  140. device_map={"": device},
  141. )
  142. model_hf.eval()
  143. with torch.no_grad():
  144. out_hf = model_hf.model(input_ids).last_hidden_state
  145. logits_hf = model_hf(input_ids).logits
  146. del model_hf
  147. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  148. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  149. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  150. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  151. assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
  152. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  153. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  154. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  155. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  156. assert (logits - logits_ref).abs().max().item() < 3 * (
  157. logits_hf - logits_ref
  158. ).abs().max().item()
  159. # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
  160. @pytest.mark.parametrize("world_size", [2])
  161. @pytest.mark.parametrize(
  162. "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
  163. )
  164. def test_llama_parallel(model_name, world_size):
  165. """Check that our implementation of LLaMa (with all optimizations enabled) matches the
  166. HF implementation: the output of our forward pass in fp16 should be around the same as the HF
  167. forward pass in fp16, when compared to the HF forward pass in fp32.
  168. """
  169. from apex.transformer import parallel_state
  170. checkpoint_path = (
  171. Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
  172. )
  173. dtype = torch.float16
  174. if "/" in model_name: # Download from HF
  175. config = llama_config_to_gpt2_config(
  176. AutoConfig.from_pretrained(model_name, trust_remote_code=True)
  177. )
  178. else:
  179. config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
  180. config = llama_config_to_gpt2_config(config)
  181. config.use_flash_attn = True
  182. config.fused_bias_fc = True
  183. config.fused_mlp = False # We don't have fused GatedMLP yet
  184. config.fused_dropout_add_ln = True
  185. config.residual_in_fp32 = True
  186. if not torch.distributed.is_initialized():
  187. torch.distributed.init_process_group(backend="nccl", init_method="env://")
  188. device = f"cuda:{torch.distributed.get_rank()}"
  189. assert world_size <= torch.distributed.get_world_size()
  190. parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
  191. rank = parallel_state.get_tensor_model_parallel_rank()
  192. process_group = parallel_state.get_tensor_model_parallel_group()
  193. if "/" in model_name: # Download from HF
  194. pretrained_state_dict = remap_state_dict_hf_llama(
  195. state_dict_from_pretrained(model_name), config
  196. )
  197. else:
  198. pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
  199. checkpoint_path, model_name, config, checkpoint_format="meta"
  200. )
  201. model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
  202. model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
  203. model.eval()
  204. torch.manual_seed(0)
  205. batch_size = 2
  206. max_seqlen = 256
  207. seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
  208. input_ids = torch.randint(
  209. 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
  210. )
  211. with torch.no_grad():
  212. out = model.transformer(input_ids)
  213. out, _ = all_gather_raw(out, process_group=process_group)
  214. out = rearrange(out, "(b s) d -> b s d", b=batch_size)
  215. logits = model(input_ids).logits
  216. logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
  217. logits, _ = all_gather_raw(logits, process_group)
  218. logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
  219. del model
  220. if rank == 0:
  221. # Without device_map, the model is loaded on the CPU, which is very slow
  222. model_ref = LlamaForCausalLM.from_pretrained(
  223. model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
  224. device_map="auto",
  225. )
  226. model_ref.eval()
  227. with torch.no_grad():
  228. out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
  229. logits_ref = model_ref(input_ids).logits.to(device=device)
  230. del model_ref
  231. model_hf = LlamaForCausalLM.from_pretrained(
  232. model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
  233. torch_dtype=dtype,
  234. device_map="auto",
  235. )
  236. model_hf.eval()
  237. with torch.no_grad():
  238. out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
  239. logits_hf = model_hf(input_ids).logits.to(device=device)
  240. del model_hf
  241. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  242. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  243. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  244. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  245. assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
  246. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  247. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  248. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  249. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  250. assert (logits - logits_ref).abs().max().item() < 2 * (
  251. logits_hf - logits_ref
  252. ).abs().max().item()
  253. # @pytest.mark.parametrize('model_name', ["7B", "13B"])
  254. @pytest.mark.parametrize("model_name", ["7B"])
  255. @pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
  256. def test_llama_generation(model_name, checkpoint_format):
  257. checkpoint_path = (
  258. Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
  259. )
  260. dtype = torch.float16
  261. device = "cuda"
  262. config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
  263. config = llama_config_to_gpt2_config(config)
  264. config.use_flash_attn = True
  265. config.fused_bias_fc = True
  266. config.fused_mlp = False # We don't have fused GatedMLP yet
  267. config.fused_dropout_add_ln = True
  268. config.residual_in_fp32 = True
  269. tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
  270. eos_token_id = tokenizer.eos_token_id
  271. torch.manual_seed(0)
  272. batch_size = 1
  273. seqlen = 100
  274. max_length = 150
  275. input_ids = torch.randint(
  276. 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
  277. )
  278. model_hf = LlamaForCausalLM.from_pretrained(
  279. Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
  280. )
  281. model_hf.eval()
  282. print("HF fp16")
  283. torch.cuda.synchronize()
  284. start = time.time()
  285. out_hf = model_hf.generate(
  286. input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
  287. )
  288. torch.cuda.synchronize()
  289. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  290. del model_hf
  291. # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
  292. model_ref = LlamaForCausalLM.from_pretrained(
  293. Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
  294. )
  295. model_ref.eval()
  296. with torch.no_grad():
  297. logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
  298. del model_ref
  299. pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
  300. checkpoint_path, model_name, config, checkpoint_format
  301. )
  302. model = GPTLMHeadModel(config, device=device, dtype=dtype)
  303. model.load_state_dict(pretrained_state_dict)
  304. model.eval()
  305. print("Without CUDA graph")
  306. torch.cuda.synchronize()
  307. start = time.time()
  308. out = model.generate(
  309. input_ids=input_ids,
  310. max_length=max_length,
  311. eos_token_id=eos_token_id,
  312. return_dict_in_generate=True,
  313. output_scores=True,
  314. enable_timing=True,
  315. teacher_outputs=out_hf.sequences,
  316. )
  317. torch.cuda.synchronize()
  318. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  319. # Capture graph outside the timing loop
  320. batch_size, seqlen_og = input_ids.shape
  321. model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
  322. print("With CUDA graph")
  323. torch.cuda.synchronize()
  324. start = time.time()
  325. out_cg = model.generate(
  326. input_ids=input_ids,
  327. max_length=max_length,
  328. cg=True,
  329. return_dict_in_generate=True,
  330. output_scores=True,
  331. enable_timing=True,
  332. teacher_outputs=out_hf.sequences,
  333. )
  334. torch.cuda.synchronize()
  335. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  336. with torch.no_grad():
  337. logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
  338. logits_hf = torch.stack(out_hf.scores, dim=1)
  339. logits = torch.stack(out.scores, dim=1)
  340. logits_cg = torch.stack(out_cg.scores, dim=1)
  341. del model
  342. hf_error = (logits_hf - logits_ref).abs().max().item()
  343. print(f"HF fp16 logits max diff: {hf_error}")
  344. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  345. print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
  346. assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
  347. assert (logits - logits_ref).abs().max().item() < 2 * hf_error
  348. assert torch.equal(logits_cg, logits)
  349. # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
  350. @pytest.mark.parametrize("world_size", [2])
  351. @pytest.mark.parametrize(
  352. "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
  353. )
  354. def test_llama_parallel_generation(model_name, world_size):
  355. """Check that our implementation matches the HF implementation:
  356. the scores in fp16 should be around the same as the HF scores in fp16, when compared to
  357. the HF scores in fp32.
  358. """
  359. from apex.transformer import parallel_state
  360. checkpoint_path = (
  361. Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
  362. )
  363. dtype = torch.float16
  364. if "/" in model_name: # Download from HF
  365. config = llama_config_to_gpt2_config(
  366. AutoConfig.from_pretrained(model_name, trust_remote_code=True)
  367. )
  368. else:
  369. config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
  370. config = llama_config_to_gpt2_config(config)
  371. config.use_flash_attn = True
  372. config.fused_bias_fc = True
  373. config.fused_mlp = False # We don't have fused GatedMLP yet
  374. config.fused_dropout_add_ln = True
  375. config.residual_in_fp32 = True
  376. config.pad_vocab_size_multiple = 8 * world_size
  377. config.sequence_parallel = False # Need to set this to False for generation
  378. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
  379. if not torch.distributed.is_initialized():
  380. torch.distributed.init_process_group(backend="nccl", init_method="env://")
  381. device = f"cuda:{torch.distributed.get_rank()}"
  382. assert world_size <= torch.distributed.get_world_size()
  383. parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
  384. rank = parallel_state.get_tensor_model_parallel_rank()
  385. process_group = parallel_state.get_tensor_model_parallel_group()
  386. torch.manual_seed(0)
  387. batch_size = 1
  388. seqlen = 100
  389. max_length = 150
  390. input_ids = torch.randint(
  391. 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
  392. )
  393. # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
  394. # GPU0 and GPU1 and things would hang
  395. torch.cuda.set_device(device)
  396. if "/" in model_name: # Download from HF
  397. pretrained_state_dict = remap_state_dict_hf_llama(
  398. state_dict_from_pretrained(model_name), config
  399. )
  400. else:
  401. pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
  402. checkpoint_path, model_name, config, checkpoint_format="meta"
  403. )
  404. model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
  405. model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
  406. model.eval()
  407. print("Without CUDA graph")
  408. out = model.generate(
  409. input_ids=input_ids,
  410. max_length=max_length,
  411. tensor_parallel=world_size,
  412. vocab_size=config.vocab_size,
  413. # teacher_outputs=out_hf.sequences,
  414. return_dict_in_generate=True,
  415. output_scores=True,
  416. enable_timing=True,
  417. )
  418. # Capture graph outside the timing loop
  419. batch_size, seqlen_og = input_ids.shape
  420. model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
  421. print("With CUDA graph")
  422. out_cg = model.generate(
  423. input_ids=input_ids,
  424. max_length=max_length,
  425. tensor_parallel=world_size,
  426. vocab_size=config.vocab_size,
  427. cg=True,
  428. # teacher_outputs=out_hf.sequences,
  429. return_dict_in_generate=True,
  430. output_scores=True,
  431. enable_timing=True,
  432. )
  433. del model
  434. parallel_state.destroy_model_parallel()
  435. if rank == 0:
  436. # Without device_map, the model is loaded on the CPU, which is very slow
  437. model_hf = LlamaForCausalLM.from_pretrained(
  438. model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
  439. torch_dtype=dtype,
  440. device_map="auto",
  441. )
  442. model_hf.eval()
  443. print("HF fp16")
  444. torch.cuda.synchronize()
  445. start = time.time()
  446. with torch.inference_mode():
  447. out_hf = model_hf.generate(
  448. input_ids=input_ids,
  449. max_length=max_length,
  450. return_dict_in_generate=True,
  451. output_scores=True,
  452. )
  453. torch.cuda.synchronize()
  454. print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
  455. del model_hf
  456. model_ref = LlamaForCausalLM.from_pretrained(
  457. model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
  458. device_map="auto",
  459. )
  460. model_ref.eval()
  461. with torch.inference_mode():
  462. logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
  463. del model_ref
  464. logits_hf = torch.stack(out_hf.scores, dim=1)
  465. logits = torch.stack(out.scores, dim=1)
  466. logits_cg = torch.stack(out_cg.scores, dim=1)
  467. hf_error = (logits_hf - logits_ref).abs().max().item()
  468. print(f"HF fp16 logits max diff: {hf_error}")
  469. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  470. assert (logits - logits_ref).abs().max().item() < 2 * hf_error
  471. print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
  472. assert torch.equal(logits_cg, logits)
  473. @torch.no_grad()
  474. @pytest.mark.parametrize("world_size", [2])
  475. def test_llama_parallel_uneven_num_heads(world_size):
  476. from apex.transformer import parallel_state
  477. checkpoint_path = (
  478. Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
  479. )
  480. num_attention_heads = world_size + 1
  481. model_name = f"teeny-{num_attention_heads}-heads"
  482. if not torch.distributed.is_initialized():
  483. torch.distributed.init_process_group(backend="nccl", init_method="env://")
  484. device = f"cuda:{torch.distributed.get_rank()}"
  485. assert world_size <= torch.distributed.get_world_size()
  486. parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
  487. rank = parallel_state.get_tensor_model_parallel_rank()
  488. process_group = parallel_state.get_tensor_model_parallel_group()
  489. dtype = torch.float16
  490. llama_config = LlamaConfig(
  491. hidden_size=256
  492. * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
  493. intermediate_size=256 * num_attention_heads * 4,
  494. num_hidden_layers=4,
  495. num_attention_heads=num_attention_heads,
  496. initializer_range=0.5, # Set crazy init range so we don't have near zero weights implying a vacuous test.
  497. )
  498. config = llama_config_to_gpt2_config(llama_config)
  499. config.use_flash_attn = True
  500. config.fused_bias_fc = True
  501. config.fused_mlp = False # We don't have fused GatedMLP yet
  502. config.fused_dropout_add_ln = True
  503. config.residual_in_fp32 = True
  504. torch.manual_seed(0)
  505. batch_size = 2
  506. max_seqlen = 256
  507. seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
  508. input_ids = torch.randint(
  509. 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
  510. )
  511. # Create a shared test model.
  512. if rank == 0:
  513. LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
  514. torch.distributed.barrier()
  515. # Run the standard forward pass test.
  516. pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
  517. checkpoint_path, model_name, config, checkpoint_format="hf"
  518. )
  519. model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
  520. model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
  521. model.eval()
  522. # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
  523. out = model.transformer(input_ids)
  524. out, _ = all_gather_raw(out, process_group=process_group)
  525. out = rearrange(out, "(b s) d -> b s d", b=batch_size)
  526. logits = model(input_ids).logits
  527. logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
  528. logits, _ = all_gather_raw(logits, process_group)
  529. logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
  530. if rank == 0:
  531. model_ref = LlamaForCausalLM.from_pretrained(
  532. Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
  533. )
  534. model_ref = model_ref.to(device=device)
  535. model_ref.eval()
  536. out_ref = model_ref.model(input_ids).last_hidden_state
  537. logits_ref = model_ref(input_ids).logits
  538. del model_ref
  539. model_hf = LlamaForCausalLM.from_pretrained(
  540. Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
  541. )
  542. model_hf.eval()
  543. out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
  544. logits_hf = model_hf(input_ids).logits.to(device=device)
  545. del model_hf
  546. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  547. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  548. print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
  549. print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
  550. assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
  551. print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
  552. print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
  553. print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
  554. print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
  555. assert (logits - logits_ref).abs().max().item() < 2 * (
  556. logits_hf - logits_ref
  557. ).abs().max().item()
  558. if os.path.exists(checkpoint_path / f"{model_name}-hf"):
  559. shutil.rmtree(checkpoint_path / f"{model_name}-hf")