# Source code for miprometheus.models.encoder_solver.es_lstm_model

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""es_lstm_model.py: Neural Network implementing Encoder-Decoder/Solver architecture"""
__author__ = "Tomasz Kornuta"

from enum import Enum
import torch
from torch import nn

from miprometheus.models.sequential_model import SequentialModel


class EncoderSolverLSTM(SequentialModel):
    """
    Class representing the Encoder-Solver architecture using LSTM cells as both
    encoder and solver modules.

    The input sequence is expected to carry two control bits per item: one that
    switches the model into encoding mode and one that switches it into
    solving mode (indices taken from the problem's ``store_bit`` /
    ``recall_bit`` values).
    """

    def __init__(self, params, problem_default_values_=None):
        """
        Constructor. Initializes parameters on the basis of dictionary passed
        as argument.

        :param params: Local view to the Parameter Registry ''model'' section.
        :param problem_default_values_: Dictionary containing key-values
            received from problem (optional; treated as empty when ``None``).
        """
        # Guard against the mutable-default-argument pitfall.
        problem_default_values_ = problem_default_values_ or {}

        # Call base constructor. Sets up default values etc.
        super(EncoderSolverLSTM, self).__init__(params, problem_default_values_)

        # Model name.
        self.name = 'EncoderSolverLSTM'

        # Parse default values received from problem.
        self.params.add_default_params({
            'input_item_size': problem_default_values_['input_item_size'],
            'output_item_size': problem_default_values_['output_item_size'],
            'encoding_bit': problem_default_values_['store_bit'],
            'solving_bit': problem_default_values_['recall_bit']
        })

        self.input_item_size = params["input_item_size"]
        self.output_item_size = params["output_item_size"]

        # Indices of control bits triggering encoding/decoding.
        self.encoding_bit = params['encoding_bit']  # Def: 0
        self.solving_bit = params['solving_bit']  # Def: 1

        self.hidden_state_size = params["hidden_state_size"]

        # Create the Encoder.
        self.encoder = nn.LSTMCell(self.input_item_size, self.hidden_state_size)

        # Create the Decoder/Solver.
        self.solver = nn.LSTMCell(self.input_item_size, self.hidden_state_size)

        # Output linear layer projecting the hidden state onto output items.
        self.output = nn.Linear(self.hidden_state_size, self.output_item_size)

        # Enumeration of the two operating modes of the model.
        self.modes = Enum('Modes', ['Encode', 'Solve'])

    def init_state(self, batch_size):
        """
        Returns 'zero' (initial) state.

        :param batch_size: Size of the batch in given iteration/epoch.
        :returns: Initial state tuple (hidden, memory cell).
        """
        dtype = self.app_state.dtype

        # Initialize the hidden state.
        h_init = torch.zeros(batch_size, self.hidden_state_size,
                             requires_grad=False).type(dtype)

        # Initialize the memory cell state.
        c_init = torch.zeros(batch_size, self.hidden_state_size,
                             requires_grad=False).type(dtype)

        # Pack and return a tuple.
        return (h_init, c_init)

    def forward(self, data_dict):
        """
        Forward function requires that the data_dict will contain at least
        "sequences".

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size
              [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]

        :returns: Predictions (logits) being a tensor of size
            [BATCH_SIZE x LENGTH_SIZE x OUTPUT_SIZE].

        :raises ValueError: if an item has both control bits set, or if no
            control bit was seen before the first item is processed.
        """
        # Unpack dict.
        inputs_BxSxI = data_dict['sequences']

        # Get batch size.
        batch_size = inputs_BxSxI.size(0)

        # Initialize state variables.
        (h, c) = self.init_state(batch_size)

        # Logits container.
        logits = []

        # BUG FIX: 'mode' was previously unbound until the first marker bit
        # appeared, causing an UnboundLocalError on unmarked first items.
        mode = None

        for x in inputs_BxSxI.chunk(inputs_BxSxI.size(1), dim=1):
            # Squeeze x.
            x = x.squeeze(1)

            # Both markers set at once is an invalid input; raise instead of
            # print()+exit(-1) so callers can handle the error.
            if x[0, self.encoding_bit] and x[0, self.solving_bit]:
                raise ValueError(
                    'Error: both encoding and decoding bits were true')

            # Switch between the encoder and decoder modes. It will stay in
            # this mode till it hits the opposite kind of marker.
            if x[0, self.solving_bit] and not x[0, self.encoding_bit]:
                mode = self.modes.Solve
            elif x[0, self.encoding_bit] and not x[0, self.solving_bit]:
                mode = self.modes.Encode

            if mode is None:
                raise ValueError(
                    'Neither the encoding nor the solving bit was set '
                    'before the first input item')

            if mode == self.modes.Solve:
                h, c = self.solver(x, (h, c))
            else:
                h, c = self.encoder(x, (h, c))

            # Collect logits.
            logit = self.output(h)
            logits += [logit]

        # Stack logits along the temporal (sequence) axis.
        logits = torch.stack(logits, 1)
        return logits