# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Fuse bias with the subsequent ops, such as activation function or dropout."""
# pylint: disable=abstract-method
from __future__ import annotations
import math
import torch
from torch.nn import functional as F
class BiasGeLUFunction(torch.autograd.Function):
    """Bias+GeLU. Copied from Megatron-LM."""

    # pylint: disable=no-self-argument, arguments-differ
    @torch.jit.script
    def bias_gelu(bias, y):
        # Tanh approximation of GeLU applied to (y + bias); 0.79788456 ~= sqrt(2/pi).
        x = bias + y
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    # gradient of tanh approximation of gelu
    # gradient of actual gelu is:
    # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
    @torch.jit.script
    def bias_gelu_back(g, bias, y):
        x = bias + y
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
        ff = 0.5 * x * (
            (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
        ) + 0.5 * (1 + tanh_out)
        return ff * g

    @staticmethod
    def forward(ctx, inp, bias):
        ctx.save_for_backward(inp, bias)
        return BiasGeLUFunction.bias_gelu(bias, inp)

    @staticmethod
    def backward(ctx, grad_output):
        inp, bias = ctx.saved_tensors
        tmp = BiasGeLUFunction.bias_gelu_back(grad_output, bias, inp)
        # d(out)/d(inp) and d(out)/d(bias) are elementwise identical, so the
        # same tensor is returned for both; autograd sums the returned gradient
        # down to the bias shape when accumulating.
        return tmp, tmp
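
# Usage sketch (illustrative note, not part of the original module): the fused
# op is invoked through its autograd entry point, e.g.
#   out = BiasGeLUFunction.apply(hidden_states, bias)
# and should match new_gelu(hidden_states + bias) in the forward pass up to
# floating-point error in the hard-coded constants.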


def new_gelu(inp):
    """New GELU activation function copied from HuggingFace transformers."""
    return (
        0.5
        * inp
        * (
            1.0
            + torch.tanh(
                math.sqrt(2.0 / math.pi) * (inp + 0.044715 * torch.pow(inp, 3.0))
            )
        )
    )


def bias_new_gelu(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Add the bias and apply the new GELU activation."""
    return new_gelu(inp + bias)
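
# Note (added for context): in recent PyTorch releases, F.gelu(x, approximate="tanh")
# implements the same tanh approximation, so bias_new_gelu(inp, bias) should agree
# with F.gelu(inp + bias, approximate="tanh") up to floating-point error.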


def bias_dropout(
    x: torch.Tensor,
    bias: torch.Tensor,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
) -> torch.Tensor:
    """Add the bias and apply dropout."""
    return F.dropout(x + bias, p=p, training=training, inplace=inplace)
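

# Minimal sanity check (an illustrative sketch, not part of the original
# module): it compares the fused autograd function against the eager
# bias + new_gelu reference in both forward and backward, then exercises
# bias_dropout. Shapes and tolerances are arbitrary choices.
if __name__ == "__main__":
    torch.manual_seed(0)
    inp = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
    bias = torch.randn(16, dtype=torch.float64, requires_grad=True)

    # Forward: fused op vs. eager reference.
    fused = BiasGeLUFunction.apply(inp, bias)
    ref = bias_new_gelu(inp, bias)
    assert torch.allclose(fused, ref, atol=1e-6)

    # Backward: gradients of the fused op vs. the eager reference.
    fused.sum().backward()
    g_inp, g_bias = inp.grad.clone(), bias.grad.clone()
    inp.grad = None
    bias.grad = None
    ref.sum().backward()
    assert torch.allclose(g_inp, inp.grad, atol=1e-6)
    assert torch.allclose(g_bias, bias.grad, atol=1e-6)

    # bias_dropout with training=False is just the biased input.
    out = bias_dropout(inp, bias, p=0.1, training=False)
    assert torch.equal(out, inp + bias)
    print("fused_bias sanity checks passed")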