#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# The MIT License (MIT)
#
# Copyright (c) 2017 Sean Robertson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ------------------------------------------------------------------------------
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
text_to_text_problem.py: abstract base class for text to text sequential problems, e.g. machine translation.
"""
__author__ = "Vincent Marois"
import unicodedata
import re
import torch
import torch.nn as nn
from miprometheus.problems.seq_to_seq.seq_to_seq_problem import SeqToSeqProblem
# global tokens
PAD_token = 0
SOS_token = 1
EOS_token = 2
class TextToTextProblem(SeqToSeqProblem):
"""
Base class for text to text sequential problems.
    Provides some basic features useful in all problems of this
    type.
"""
    def __init__(self, params):
"""
Initializes problem object. Calls base ``SeqToSeqProblem`` constructor.
Sets ``nn.NLLLoss()`` as default loss function.
:param params: Dictionary of parameters (read from configuration ``.yaml`` file).
"""
super(TextToTextProblem, self).__init__(params)
        # Set default loss function: negative log-likelihood, averaged over
        # the batch and ignoring padded elements.
        self.loss_function = nn.NLLLoss(reduction='mean', ignore_index=PAD_token)
# set default data_definitions dict
self.data_definitions = {'inputs': {'size': [-1, -1, -1], 'type': [torch.Tensor]},
'inputs_length': {'size': [-1, 1], 'type': [list, int]},
'inputs_text': {'size': [-1, 1], 'type': [list, str]},
'targets': {'size': [-1, -1, -1], 'type': [torch.Tensor]},
'targets_length': {'size': [-1, 1], 'type': [list, int]},
'outputs_text': {'size': [-1, 1], 'type': [list, str]},
}
# default values likely to be useful to a model.
# setting the fields for the vocabulary sets sizes to None for now.
# TODO: other fields to consider?
self.default_values = {'input_voc_size': None,
'output_voc_size': None,
'embedding_dim': None}
self.input_lang = None
self.output_lang = None
    def compute_BLEU_score(self, data_dict, logits):
"""
Compute the BLEU score in order to evaluate the translation quality
(equivalent of accuracy).
.. note::
Reference paper: http://www.aclweb.org/anthology/P02-1040.pdf
Implementation inspired from https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
        To handle all samples within a batch, we accumulate the individual BLEU score for each pair
        of sentences and average over the batch size.
:param data_dict: DataDict({'inputs', 'inputs_length', 'inputs_text', 'targets', 'targets_length', 'outputs_text'}).
:param logits: Predictions of the model.
        :return: Average BLEU score for the batch (0 <= BLEU <= 1).
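
        A minimal sketch of the underlying per-sentence call (assuming ``nltk``
        is installed; an identical reference & hypothesis scores 1.0)::

            from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
            score = sentence_bleu([['the', 'cat', 'sat', 'down']],
                                  ['the', 'cat', 'sat', 'down'],
                                  smoothing_function=SmoothingFunction().method1)
            # score == 1.0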
"""
# get most probable words indexes for the batch
_, top_indexes = logits.topk(k=1, dim=-1)
        # drop only the singleton top-k dimension (a bare squeeze() would also
        # drop the batch dimension when batch_size == 1)
        logits = top_indexes.squeeze(dim=-1)
batch_size = logits.shape[0]
        # deferred import, so that nltk stays an optional dependency
        from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
        # retrieve the target sentences from the DataDict
targets_text = []
for sentence in data_dict['targets_text']:
targets_text.append(sentence.split())
# retrieve text sentences from the logits (which should be tensors of
# indexes)
logits_text = []
for logit in logits:
logits_text.append(
[self.output_lang.index2word[index.item()] for index in logit])
bleu_score = 0
for i in range(batch_size):
# compute bleu score and use a smoothing function
bleu_score += sentence_bleu([targets_text[i]],
logits_text[i],
smoothing_function=SmoothingFunction().method1)
return round(bleu_score / batch_size, 4)
    def evaluate_loss(self, data_dict, logits):
"""
Computes loss.
By default, the loss function is the Negative Log Likelihood function.
The input given through a forward call is expected to contain log-probabilities (LogSoftmax) of each class.
        The input has to be a Tensor of size either (batch_size, C) or (batch_size, C, d1, d2, ..., dK) with K ≥ 1
        for the K-dimensional case.
The target that this loss expects is a class index (0 to C-1, where C = number of classes).
:param data_dict: DataDict({'inputs', 'inputs_length', 'inputs_text', 'targets', 'targets_length', 'outputs_text'}).
:param logits: Predictions of the model.
:return: loss
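        Shape sketch (assuming the model emits log-probabilities over the vocabulary):
        for ``logits`` of shape ``(batch_size, seq_length, voc_size)`` and integer
        ``targets`` of shape ``(batch_size, seq_length)``, the ``transpose(1, 2)`` below
        rearranges the logits to ``(batch_size, voc_size, seq_length)``, which is the
        layout that ``nn.NLLLoss`` expects.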
"""
loss = self.loss_function(logits.transpose(1, 2), data_dict['targets'])
return loss
    def add_statistics(self, stat_col):
"""
Add BLEU score to a ``StatisticsCollector``.
:param stat_col: Statistics collector.
:type stat_col: ``StatisticsCollector``
"""
stat_col.add_statistic('bleu_score', '{:4.5f}')
    def collect_statistics(self, stat_col, data_dict, logits):
"""
Collects BLEU score.
:param stat_col: ``StatisticsCollector``
:param data_dict: DataDict({'inputs', 'inputs_length', 'inputs_text', 'targets', 'targets_length', 'outputs_text'}).
:param logits: Predictions of the model.
"""
stat_col['bleu_score'] = self.compute_BLEU_score(data_dict, logits)
    def show_sample(self, data_dict, sample=0):
"""
Shows the sample (both input and target sequences) using matplotlib.
Elementary visualization.
:param data_dict: DataDict({'inputs', 'inputs_length', 'inputs_text', 'targets', 'targets_length', 'outputs_text'}).
:param sample: Number of sample in a batch (default: 0)
.. note::
TODO
"""
pass
# ----------------------
# The following are helper functions for data pre-processing in the case
# of a translation task
    def unicode_to_ascii(self, s):
"""
Turn a Unicode string to plain ASCII.
See: http://stackoverflow.com/a/518232/2809427.
:param s: Unicode string.
:return: plain ASCII string.
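        e.g. ``unicode_to_ascii('Français, déjà')`` returns ``'Francais, deja'``
        (the accents are decomposed by NFD normalization and the combining
        marks, category 'Mn', are dropped).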
"""
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
)
    def normalize_string(self, s):
"""
Lowercase, trim, and remove non-letter characters in string s.
:param s: string.
:return: normalized string.
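        e.g. ``normalize_string("Ça marche!")`` returns ``"ca marche !"``
        (lowercased, accents stripped, punctuation split off as its own token).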
"""
s = self.unicode_to_ascii(s.lower().strip())
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
return s
    def indexes_from_sentence(self, lang, sentence):
"""
        Constructs a list of word indexes for the specified sentence, using the
        ``word2index`` vocabulary of the given ``Lang`` instance (see the ``Lang`` class below).
:param lang: instance of the ``Lang`` class, having a ``word2index`` dict.
:type lang: Lang
:param sentence: string to convert word for word to indexes, e.g. "The black cat is eating."
:type sentence: str
:return: list of indexes.
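        e.g. with a hypothetical vocabulary mapping 'hi' -> 3 and 'there' -> 4,
        ``indexes_from_sentence(lang, 'hi there')`` returns ``[3, 4, 2]``,
        where 2 is the EOS token.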
"""
        # note: raises a KeyError for out-of-vocabulary words
        seq = [lang.word2index[word] for word in sentence.split(' ')] + [EOS_token]
return seq
    def tensor_from_sentence(self, lang, sentence):
"""
Uses ``indexes_from_sentence()`` to create a tensor of indexes with the
EOS token.
:param lang: instance of the ``Lang`` class, having a ``word2index`` dict.
:type lang: Lang
:param sentence: string to convert word for word to indexes, e.g. "The black cat is eating."
:type sentence: str
:return: tensor of indexes, terminated by the EOS token.
"""
indexes = self.indexes_from_sentence(lang, sentence)
return torch.tensor(indexes).type(self.app_state.LongTensor)
    def tensors_from_pair(self, pair, input_lang, output_lang):
"""
        Creates a pair of tensors of indexes from a pair of sentences.
        :param pair: pair of sentences (input language sentence, output language sentence).
        :type pair: tuple
        :param input_lang: instance of the ``Lang`` class, having a ``word2index`` dict, representing the input language.
        :type input_lang: Lang
        :param output_lang: instance of the ``Lang`` class, having a ``word2index`` dict, representing the output language.
        :type output_lang: Lang
        :return: list of two tensors of indexes: [input tensor, target tensor].
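        e.g. ``tensors_from_pair(('le chat noir', 'the black cat'), input_lang, output_lang)``
        returns, assuming both sentences are in-vocabulary, a list of two 1-D tensors of
        word indexes, each terminated by the EOS token.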
"""
input_tensor = self.tensor_from_sentence(input_lang, pair[0])
target_tensor = self.tensor_from_sentence(output_lang, pair[1])
return [input_tensor, target_tensor]
    def tensors_from_pairs(self, pairs, input_lang, output_lang):
"""
        Returns a list of pairs of tensors of indexes from a list of pairs of
        sentences. Uses ``tensors_from_pair()``.
        :param pairs: list of sentence pairs.
        :type pairs: list
        :param input_lang: instance of the ``Lang`` class, having a ``word2index`` dict, representing the input language.
        :type input_lang: Lang
        :param output_lang: instance of the ``Lang`` class, having a ``word2index`` dict, representing the output language.
        :type output_lang: Lang
        :return: list of pairs of tensors of indexes.
"""
return [self.tensors_from_pair(pair, input_lang, output_lang) for pair in pairs]
class Lang(object):
"""
    Simple helper class used to represent a language in a translation task.
    It contains a vocabulary index (``word2index`` dict) and keeps track of
    the number of words in the language.
    This class is useful as each word in a language will be represented as a one-hot vector: a giant vector of zeros
    except for a single one (at the index of the word). The dimension of this vector is potentially very high, hence it
    is generally useful to trim the data to only use a few thousand words per language.
    The inputs and targets of the associated sequence-to-sequence networks will be sequences of indexes, each item
    representing a word. The attributes of this class (``word2index``, ``index2word``, ``word2count``) are useful to
    keep track of this.
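
    A short usage sketch (the indexes shown assume a freshly-constructed instance)::

        >>> lang = Lang('english')
        >>> lang.add_sentence('the cat sat')
        >>> lang.n_words    # 3 special tokens (PAD, SOS, EOS) + 3 words
        6
        >>> lang.word2index['cat']
        4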
"""
    def __init__(self, name):
"""
Constructor.
:param name: string to name the language (e.g. french, english)
"""
self.name = name
self.word2index = {"PAD": 0, "SOS": 1, "EOS": 2} # dict 'word': index
# keep track of the occurrence of each word in the language. Can be
# used to replace rare words.
self.word2count = {}
# dict 'index': 'word', initializes with PAD, EOS, SOS tokens
self.index2word = {0: "PAD", 1: "SOS", 2: "EOS"}
# Number of words in the language. Start by counting PAD, EOS, SOS
# tokens.
self.n_words = 3
    def add_sentence(self, sentence):
"""
Process a sentence using ``add_word()``.
:param sentence: sentence to be added to the language.
:type sentence: str
"""
for word in sentence.split(' '):
self.add_word(word)
    def add_word(self, word):
"""
        Add a word to the vocabulary set: updates ``word2index``, ``word2count``,
        ``index2word`` & ``n_words``.
:param word: word to be added.
:type word: str
"""
if word not in self.word2index: # if the current word has not been seen before
# create a new entry in word2index
self.word2index[word] = self.n_words
self.word2count[word] = 1 # count first occurrence of this word
# create a new entry in index2word
self.index2word[self.n_words] = word
self.n_words += 1 # increment total number of words in the language
else: # this word has been seen before, simply update its occurrence
self.word2count[word] += 1
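

if __name__ == '__main__':
    # Minimal, self-contained sanity check of the ``Lang`` helper (a sketch:
    # ``TextToTextProblem`` itself requires a full miprometheus parameter
    # dict to instantiate, so only the vocabulary bookkeeping is exercised).
    lang = Lang('english')
    lang.add_sentence('the black cat is eating')
    lang.add_sentence('the black dog is eating')
    print(lang.n_words)            # 9: 3 special tokens + 6 distinct words
    print(lang.word2index['cat'])  # 5: PAD/SOS/EOS take 0-2, then the=3, black=4
    print(lang.word2count['the'])  # 2: appears once in each sentence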