@@ -137,10 +137,12 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward_native(
-        self,
+    @staticmethod
+    def forward_static(
+        weight: torch.Tensor,
+        variance_epsilon: float,
         x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
@@ -150,17 +152,31 @@
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * torch.rsqrt(variance + variance_epsilon)
         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
+        x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
         return x if residual is None else (x, residual)
 
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        return self.forward_static(self.weight.data, self.variance_epsilon, x,
+                                   residual)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # TODO: Implement an optimized kernel for GemmaRMSNorm.
+        if torch.compiler.is_compiling():
+            return self.forward_native(x, residual)
+        if not getattr(self, "_is_compiled", False):
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static)
+            self._is_compiled = True
         return self.forward_native(x, residual)
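
For reference, the new `forward_cuda` compiles the shared static function lazily: the first eager call wraps `forward_static` with `torch.compile` and caches the result as an instance attribute (shadowing the class staticmethod), while `torch.compiler.is_compiling()` short-circuits to the plain implementation whenever an outer compile is already tracing, so the norm is never compiled twice. A minimal, self-contained sketch of the same idiom follows; the `LazyRMSNorm` name is hypothetical, the residual-handling lines sit between the two hunks and are reconstructed from context, and the shapes are illustrative:

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn


class LazyRMSNorm(nn.Module):
    """Hypothetical stand-in illustrating the lazy torch.compile pattern."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # A pure function of tensors and a float: no `self`, so torch.compile
        # traces plain inputs rather than module state.
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = (x * (1.0 + weight.float())).to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if torch.compiler.is_compiling():
            # An outer torch.compile is already tracing: run the plain
            # version instead of nesting a second compilation.
            return self.forward_static(self.weight.data,
                                       self.variance_epsilon, x, residual)
        if not getattr(self, "_is_compiled", False):
            # Compile once on first eager use; the instance attribute then
            # shadows the class staticmethod for every later call.
            self.forward_static = torch.compile(self.forward_static)
            self._is_compiled = True
        return self.forward_static(self.weight.data, self.variance_epsilon,
                                   x, residual)


# Illustrative usage (sizes are arbitrary):
norm = LazyRMSNorm(hidden_size=64)
x = torch.randn(2, 64)
out = norm(x)                                       # first call compiles
out2, res = norm(x, residual=torch.zeros_like(x))   # reuses the compiled fn
```

Because `forward_static` receives `weight` and `variance_epsilon` as explicit arguments instead of reading them off `self`, the compiled function is a plain function of tensors and a float, and the `_is_compiled` flag keeps compilation a one-time cost per module instance.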