distributed.py

from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent version
# of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
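
# Example (illustrative sketch, not part of this module; `x`, `group`, `other_work`,
# and `consume` are placeholder names): with async_op=True the returned handle lets
# the all-gather overlap with unrelated computation.
#
#   gathered, handle = all_gather_raw(x, group, async_op=True)
#   y = other_work()    # compute that does not need `gathered`
#   handle.wait()       # block until the all-gather has completed
#   out = consume(gathered, y)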


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
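
# Note (informal): reduce_scatter_raw is the counterpart of all_gather_raw along
# dim 0. Each rank receives a summed slice of size input_.shape[0] // world_size;
# e.g. (hypothetical shapes) with world_size=4 and input_ of shape (8192, 1024),
# every rank gets an output of shape (2048, 1024).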


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from the sequence-parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
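
# Example (illustrative sketch; `local_hidden`, `tp_group`, and `full_seq_op` are
# placeholder names): in sequence parallelism each rank holds a shard of the
# sequence, the shards are gathered before an op that needs the full sequence,
# and AllGatherFunc.backward reduce-scatters the gradients back to the shards.
#
#   full_hidden = all_gather(local_hidden, tp_group)  # dim 0 grows by world_size
#   out = full_seq_op(full_hidden)
#   out.sum().backward()                              # grad arrives reduce-scattered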


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input from the sequence-parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input from the sequence-parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different numbers of parameters (e.g., only rank 0 has bias).
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )
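
# Example (illustrative sketch; `model` and `tp_group` are placeholder names):
# parameters tagged with `p._shared_params = True` (e.g., tied embedding weights)
# can be re-synchronized from rank 0 of the group, typically right after model
# initialization:
#
#   sync_shared_params(model, tp_group)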


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different numbers of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
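
# Example (illustrative sketch; `loss`, `model`, `tp_group`, and `optimizer` are
# placeholder names): gradients of parameters tagged with
# `p._sequence_parallel = True` (e.g., LayerNorm weights) are partial sums on each
# rank, so they are all-reduced after backward and before the optimizer step:
#
#   loss.backward()
#   allreduce_sequence_parallel_grad(model, tp_group)
#   optimizer.step()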


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
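
# Worked example (hypothetical numbers): dim=18, world_size=4, multiple_of=2
# gives multiple=9, div=2, mod=1, so rank 0 gets 3 * 2 = 6 and ranks 1-3 get
# 2 * 2 = 4 each; the per-rank dims sum back to 18.
#
#   [get_dim_for_local_rank(18, 4, r, multiple_of=2) for r in range(4)]  # -> [6, 4, 4, 4]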