# Original Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# https://github.com/NVIDIA/Megatron-LM/blob/52e6368/megatron/core/tensor_parallel/cross_entropy.py
# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.distributed as dist
from torch import nn
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, _):
    """Return the half-open ``[first, last)`` range of vocab ids held by ``rank``.

    Every rank is taken to hold one equally sized, consecutive slice of the
    vocabulary, so the slice of ``rank`` starts at
    ``rank * per_partition_vocab_size``.

    Parameters
    ----------
    per_partition_vocab_size
        number of vocab entries stored on each rank
    rank
        index of the rank whose slice is requested
    _
        unused; callers in this file pass the world size here

    Returns
    -------
    tuple ``(first, last)`` where ``last`` is exclusive
    """
    first = rank * per_partition_vocab_size
    return first, first + per_partition_vocab_size
class _VocabParallelCrossEntropy(torch.autograd.Function):
    """Cross entropy over logits whose last (vocab) dimension is split across ranks.

    Each rank passes its own slice of the logits. The forward pass combines
    the slices with three all-reduces (max, target logit, sum of exponentials),
    plus one more when label smoothing is enabled, so every rank ends up with
    the same loss. The backward pass returns the gradient for the local slice
    of the logits only.
    """

    # pylint: disable=abstract-method, arguments-differ
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0, group=None):
        """Compute the loss for every target position.

        Parameters
        ----------
        ctx
            autograd context used to hand tensors over to ``backward``
        vocab_parallel_logits
            this rank's slice of the logits; the last dimension is the
            per-rank vocab size
        target
            vocab ids (indices into the full vocabulary), one per position
            of the leading logits dimensions
        label_smoothing
            smoothing factor; values ``<= 0`` disable smoothing, values
            ``>= 1`` fail the assertion below
        group
            torch.distributed group the vocabulary is split over

        Returns
        -------
        loss tensor with one value per target position
        """
        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
        # Subtract the maximum value. The subtraction allocates a new tensor,
        # so the in-place operations further down never touch the caller's
        # logits.
        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

        # Get the partition's vocab indices.
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = dist.get_rank(group=group)
        world_size = dist.get_world_size(group=group)
        vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
            partition_vocab_size, rank, world_size
        )

        # Create a mask of vocab ids that live on another rank
        # (True means the target needs to be masked on this rank).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        # Shift targets to partition-local indices; masked entries are pointed
        # at index 0 so the gather below stays in bounds.
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(
            start=0, end=logits_2d.size()[0], device=logits_2d.device
        )
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        # Targets owned by another rank contribute zero here, so the SUM
        # all-reduce leaves exactly the owning rank's logit.
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)

        # Sum of exponential of logits along vocab dimension across all GPUs.
        # exp() is written back into the shifted-logits buffer, so from here
        # on both names refer to the exponentials.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize in place: exp_logits now holds this rank's slice of the
        # softmax over the full vocabulary.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))

        # K in the smoothing formula below is the number of classes, i.e. the
        # size of the full vocabulary and not only of this rank's slice.
        # All partitions are taken to be equally sized, as in the vocab range
        # computation above.
        vocab_size = partition_vocab_size * world_size
        if label_smoothing > 0:
            # We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
            # = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
            # = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            # = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
            # = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
            # = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
            # From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
            assert 1.0 > label_smoothing > 0.0
            smoothing = label_smoothing * vocab_size / (vocab_size - 1)

            # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
            log_probs = torch.log(exp_logits)
            # The mean runs over the whole vocabulary: add up the local slice,
            # combine the partial sums of all ranks, then divide by K. A purely
            # local mean would give every rank a different loss.
            sum_log_probs = log_probs.sum(dim=-1)
            dist.all_reduce(sum_log_probs, op=dist.ReduceOp.SUM, group=group)
            mean_log_probs = sum_log_probs / vocab_size
            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs

        # Store softmax, target-mask and masked-target for backward pass,
        # together with the smoothing settings backward needs.
        ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss w.r.t. this rank's logits slice.

        Parameters
        ----------
        ctx
            autograd context filled by ``forward``
        grad_output
            gradient of the downstream objective w.r.t. the loss

        Returns
        -------
        tuple with the logits gradient followed by ``None`` for ``target``,
        ``label_smoothing`` and ``group``
        """
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors
        label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size

        # All the inputs have softmax as their gradient.
        # NOTE: this is the saved tensor itself, not a copy; the updates below
        # overwrite it in place.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        # 1 where the target lives on this rank, 0 where another rank owns it.
        softmax_update = 1.0 - target_mask.view(-1).float()

        if label_smoothing > 0:
            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
            grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
            # The smoothing term spreads `smoothing` uniformly over all
            # vocab_size classes, so every logit receives -smoothing / K.
            average_grad = 1 / vocab_size
            grad_2d[arange_1d, :] -= smoothing * average_grad
        else:
            grad_2d[arange_1d, masked_target_1d] -= softmax_update

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None, None, None
def vocab_parallel_cross_entropy(
    vocab_parallel_logits, target, label_smoothing=0.0, group=None
):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Parameters
    ----------
    vocab_parallel_logits
        logits split across tensor parallel ranks along the last dimension
        dimension is [sequence_length, batch_size, vocab_size_per_rank]
    target
        correct vocab ids of dimension [sequence_length, micro_batch_size]
    label_smoothing
        smoothing factor, must be in range [0.0, 1.0)
        default is no smoothing (=0.0)
    group
        torch.distributed group

    Returns
    -------
    loss tensor with one value per entry of target (no reduction is applied)
    """
    # pylint: disable=missing-type-doc
    return _VocabParallelCrossEntropy.apply(
        vocab_parallel_logits, target, label_smoothing, group
    )
class ParallelCrossEntropy(nn.Module):
    """Module form of ``vocab_parallel_cross_entropy``.

    Parameters
    ----------
    group
        torch.distributed group the vocabulary is split over
    label_smoothing
        smoothing factor forwarded to ``vocab_parallel_cross_entropy``
        default is no smoothing (=0.0)
    """

    def __init__(self, group=None, label_smoothing=0.0):
        super().__init__()
        self.group = group
        self.label_smoothing = label_smoothing

    def forward(self, outputs, labels):
        """Return the vocab-parallel cross entropy of ``outputs`` against ``labels``.

        Parameters
        ----------
        outputs
            this rank's slice of the logits
        labels
            correct vocab ids
        """
        return vocab_parallel_cross_entropy(
            outputs,
            labels,
            label_smoothing=self.label_smoothing,
            group=self.group,
        )