# Source code for miprometheus.problems.seq_to_seq.seq_to_seq_problem

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
seq_to_seq_problem.py: contains base class for all sequence to sequence problems.

"""

__author__ = "Tomasz Kornuta & Vincent Marois"

from miprometheus.problems.problem import Problem
import torch


class SeqToSeqProblem(Problem):
    """
    Class representing base class for all sequential problems.

    Sets the default ``data_definitions`` shared by sequence-to-sequence
    problems and the ``use_mask`` flag that controls whether the loss is
    computed with masks.
    """

    def __init__(self, params):
        """
        Initializes problem object. Calls base constructor.

        :param params: Dictionary of parameters (read from configuration ``.yaml`` file).
            The ``use_mask`` key is defaulted to ``True`` if absent.

        """
        super(SeqToSeqProblem, self).__init__(params)

        # "Default" problem name.
        self.name = 'SeqToSeqProblem'

        # Set default data_definitions dict for all Seq2Seq problems.
        # A size entry of -1 marks a dimension whose length is not fixed here.
        self.data_definitions = {'sequences': {'size': [-1, -1, -1], 'type': [torch.Tensor]},
                                 'targets': {'size': [-1, -1, -1], 'type': [torch.Tensor]},
                                 'masks': {'size': [-1, -1, 1], 'type': [torch.Tensor]},
                                 'sequences_length': {'size': [-1, 1], 'type': [torch.Tensor]}
                                 }

        # Check if predictions/targets should be masked (DEFAULT: True).
        params.add_default_params({'use_mask': True})
        self.use_mask = params["use_mask"]

    def evaluate_loss(self, data_dict, logits):
        """
        Calculates the loss between the model predictions and the targets.

        When ``self.use_mask`` is set, ``data_dict['masks']`` is passed to
        ``self.loss_function`` as a third argument so the loss function can
        apply the mask. Otherwise the loss function is called with the
        logits and targets only.

        NOTE(review): ``self.loss_function`` is not defined in this class;
        it is presumably set by the base ``Problem`` or by a subclass - confirm.

        :param data_dict: DataDict({'sequences', 'sequences_length', 'targets', 'masks'}).
        :param logits: Predictions being output of the model.
        :return: The value returned by ``self.loss_function``.
        """
        # Check if mask should be used - if so, pass it to the loss function.
        if self.use_mask:
            loss = self.loss_function(
                logits, data_dict['targets'], data_dict['masks'])
        else:
            loss = self.loss_function(logits, data_dict['targets'])
        return loss
if __name__ == '__main__':
    from miprometheus.utils.param_interface import ParamInterface

    # Smoke test: build the problem with default (empty) parameters and
    # fetch the first sample. Indexing uses the __getitem__ inherited by
    # SeqToSeqProblem (it is not defined in this file).
    sample = SeqToSeqProblem(ParamInterface())[0]
    # equivalent to SeqToSeqProblem(ParamInterface()).__getitem__(0)
    print(repr(sample))