moe_pallas.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import torch
  2. import torch.nn.functional as F
  3. from torch_xla.experimental.custom_kernel import _histogram
  4. def fused_moe(
  5. hidden_states: torch.Tensor,
  6. w1: torch.Tensor,
  7. w2: torch.Tensor,
  8. gating_output: torch.Tensor,
  9. topk: int,
  10. renormalize: bool,
  11. ) -> torch.Tensor:
  12. """
  13. Args:
  14. hidden_states: [*, hidden_size]
  15. w1: [num_experts, intermediate_size * 2, hidden_size]
  16. w2: [num_experts, hidden_size, intermediate_size]
  17. gating_output: [*, num_experts]
  18. """
  19. orig_shape = hidden_states.shape
  20. hidden_size = hidden_states.shape[-1]
  21. num_tokens = hidden_states.shape[:-1].numel()
  22. num_experts = w1.shape[0]
  23. intermediate_size = w2.shape[-1]
  24. device = hidden_states.device
  25. dtype = hidden_states.dtype
  26. assert (num_tokens * topk) % 16 == 0, (
  27. "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
  28. f"16 but got {num_tokens * topk}")
  29. hidden_states = hidden_states.view(num_tokens, hidden_size)
  30. gating_output = gating_output.view(num_tokens, num_experts)
  31. topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
  32. topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
  33. if renormalize:
  34. topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
  35. topk_weights = topk_weights.to(dtype)
  36. topk_indices = topk_indices.flatten()
  37. topk_argsort_indices = topk_indices.argsort()
  38. topk_argsort_revert_indices = topk_argsort_indices.argsort()
  39. token_indices = torch.arange(num_tokens,
  40. device=device).repeat_interleave(topk)
  41. token_indices = token_indices[topk_argsort_indices]
  42. group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
  43. # NOTE: The GMM Pallas kernel requires a different weight layout
  44. # from HF Transformers.
  45. w1 = w1.transpose(1, 2)
  46. w2 = w2.transpose(1, 2)
  47. x = hidden_states[token_indices]
  48. x = torch.ops.xla.gmm(x, w1, group_sizes)
  49. x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
  50. x = torch.ops.xla.gmm(x, w2, group_sizes)
  51. x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
  52. x = x * topk_weights.unsqueeze_(dim=-1)
  53. x = x.sum(dim=-2)
  54. x = x.reshape(orig_shape)
  55. return x