Source code for miprometheus.utils.loss.masked_bce_with_logits_loss

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
masked_bce_with_logits_loss.py: contains the masked binary cross entropy loss function.
"""
__author__ = "Ryan L. McAvoy"

import torch
from torch.nn import Module
from miprometheus.utils.app_state import AppState


class MaskedBCEWithLogitsLoss(Module):
    """
    Calculates the binary cross entropy for batches with different numbers of outputs for the samples.
    """
    def __init__(self, weight=None):
        """
        Constructor for the ``MaskedBCEWithLogitsLoss``. Defines the inner loss as ``BCEWithLogitsLoss``.

        :param weight: a manual rescaling weight given to each class. \
        If given, has to be a Tensor of size `C`.
        :type weight: Tensor, optional

        """
        super(MaskedBCEWithLogitsLoss, self).__init__()
        self.loss_function = torch.nn.BCEWithLogitsLoss(weight=weight, reduce=False)
        # For PyTorch >= 0.4.1:
        # self.loss_function = torch.nn.BCEWithLogitsLoss(weight=weight, reduction='none')
    def forward(self, logits, targets, mask):
        """
        Calculates the loss accounting for different numbers of outputs per sample.

        :param logits: Logits being output by the model. [batch, classes, sequence]
        :type logits: torch.tensor

        :param targets: Targets [batch, sequence]. Must be a floating point type for ``BCEWithLogitsLoss``.
        :type targets: torch.FloatTensor

        :param mask: Mask [batch, sequence].
        :type mask: torch.ByteTensor

        :return: loss value.

        """
        # Calculate the loss per element in the sequence.
        loss_per_element = self.loss_function(logits, targets)

        # If the loss has one extra dimension, give the mask a matching unit
        # dimension so the two can be multiplied element by element.
        mask_float = mask.type(AppState().dtype)
        if len(mask.shape) < len(loss_per_element.shape):
            mask_float = mask_float.unsqueeze(-1)

        # Zero out the loss of the masked (unneeded) outputs.
        masked_loss_per = mask_float * loss_per_element

        # Obtain the number of non-zero elements in the mask.
        # nonzero() returns a [num_nonzero, ndim] tensor of indices, so its
        # numel() has to be divided by the number of dimensions.
        # The mask lacks the last dimension of the targets, so the count is
        # scaled up by that dimension's size.
        size = mask.nonzero().numel() / len(mask.shape) * logits.shape[-1]

        loss = torch.sum(masked_loss_per) / size

        return loss
    def masked_accuracy(self, logits, targets, mask):
        """
        Calculates the accuracy, i.e. the mean number of correct predictions in a \
        given batch.

        .. warning::

            Applies ``mask`` to both ``logits`` and ``targets``.

        :param logits: Logits being output by the model. [batch, classes, sequence]
        :type logits: torch.tensor

        :param targets: Targets [batch, sequence].
        :type targets: torch.FloatTensor

        :param mask: Mask [batch, sequence].
        :type mask: torch.ByteTensor

        :return: accuracy value.

        """
        # Calculate the accuracy per bit in the sequences: 1 where the rounded
        # prediction matches the target, 0 elsewhere.
        acc_per = 1 - torch.abs(torch.round(torch.sigmoid(logits)) - targets)

        mask_float = mask.type(AppState().dtype)
        if len(mask.shape) < len(logits.shape):
            mask_float = mask_float.unsqueeze(-1)

        # The mask lacks the last dimension of the targets, so the count of
        # unmasked elements is scaled up by that dimension's size.
        size = mask.nonzero().numel() / len(mask.shape) * logits.shape[-1]

        masked_acc_per = mask_float * acc_per
        accuracy = masked_acc_per.sum().item() / size

        return accuracy
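
A minimal usage sketch (not part of the original module), assuming ``miprometheus`` is importable; the shapes and mask values below are hypothetical:

import torch
from miprometheus.utils.loss.masked_bce_with_logits_loss import MaskedBCEWithLogitsLoss

# Hypothetical shapes: batch of 2 samples, 4 time steps, 8 output bits per step.
logits = torch.randn(2, 4, 8)
targets = torch.empty(2, 4, 8).random_(2)  # binary targets, stored as floats
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.uint8)  # valid steps per sample

criterion = MaskedBCEWithLogitsLoss()
loss = criterion(logits, targets, mask)  # averaged over unmasked elements only
accuracy = criterion.masked_accuracy(logits, targets, mask)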
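
Both methods count the unmasked positions with the same index-counting trick; a small self-contained sketch of why ``mask.nonzero().numel() / len(mask.shape)`` equals the number of non-zero mask entries:

import torch

mask = torch.tensor([[1, 1, 0],
                     [1, 0, 0]], dtype=torch.uint8)

# nonzero() returns a [num_nonzero, ndim] tensor of indices, so its numel()
# is num_nonzero * ndim; dividing by ndim recovers the plain count.
count = mask.nonzero().numel() / len(mask.shape)  # -> 3.0
assert count == mask.sum().item()  # equivalent for a 0/1 mask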