# Source code for miprometheus.models.encoder_solver.mas_cell

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""mas_cell.py: pytorch module implementing single (recurrent) cell of Memory-Augmented Solver"""
__author__ = "Tomasz Kornuta"

import torch
import collections
from torch.nn import Module

# Set logging level.
import logging
logger = logging.getLogger('MAS-Cell')
# logging.basicConfig(level=logging.DEBUG)

from miprometheus.models.controllers.controller_factory import ControllerFactory
from miprometheus.models.encoder_solver.mas_interface import MASInterface

# Helper collection type.
# Helper collection type: a plain namedtuple with the four components that
# make up the state of a MAS cell. Private - use MASCellStateTuple instead.
_MASCellStateTuple = collections.namedtuple(
    'MASCellStateTuple',
    ('ctrl_state',
     'interface_state',
     'memory_state',
     'read_vector'))


class MASCellStateTuple(_MASCellStateTuple):
    """
    Tuple used by MAS Cells for storing current/past state information.

    Fields, in positional order:

    - ``ctrl_state``: state of the controller,
    - ``interface_state``: state of the memory interface,
    - ``memory_state``: memory content,
    - ``read_vector``: vector read from the memory.
    """
    # Empty __slots__ keeps instances as lightweight as the underlying tuple
    # (no per-instance __dict__, no extra attributes can be attached).
    __slots__ = ()
class MASCell(Module):
    """
    Class representing a single (recurrent) Memory-Augmented Solver cell,
    i.e. the decoder-side cell of the encoder-solver model.

    The cell owns a controller, a memory interface and a linear output layer.
    It does not create the memory: the memory content is handed over in the
    initial state (see :py:meth:`init_state` and
    :py:meth:`init_state_with_encoder_state`).
    """

    def __init__(self, params):
        """
        Cell constructor. Cell creates controller and interface. Assumes that
        memory will be initialized by the encoder.

        :param params: Dictionary of parameters. Keys read here: \
            ``input_item_size``, ``output_item_size``, \
            ``controller`` (``name``, ``hidden_state_size``, ``non_linearity``, ``num_layers``) \
            and ``memory`` (``num_content_bits``). The whole dictionary is also \
            passed on to :py:class:`MASInterface`.
        """
        # Call constructor of base class.
        super(MASCell, self).__init__()

        # Parse parameters.
        # Set input and output sizes.
        self.input_size = params['input_item_size']
        self.output_size = params['output_item_size']

        # Get controller hidden state size.
        self.controller_hidden_state_size = params['controller'][
            'hidden_state_size']

        # Controller - entity that processes input and produces hidden state of
        # the MAS cell. It receives only the external input.
        ext_controller_inputs_size = self.input_size

        # Create dictionary with controller parameters.
        controller_params = {
            "name": params['controller']['name'],
            "input_size": ext_controller_inputs_size,
            "output_size": self.controller_hidden_state_size,
            "non_linearity": params['controller']['non_linearity'],
            "num_layers": params['controller']['num_layers']
        }
        # Build the controller.
        self.controller = ControllerFactory.build(controller_params)

        # Interface - entity responsible for accessing the memory.
        self.interface = MASInterface(params)

        # Layer that produces output on the basis of hidden state and vector
        # read from the memory (hence the input size is the sum of both).
        ext_hidden_size = self.controller_hidden_state_size + \
            params['memory']['num_content_bits']
        self.hidden2output = torch.nn.Linear(ext_hidden_size, self.output_size)

    def init_state(self, final_enc_memory_BxAxC, final_enc_attention_BxAx1):
        """
        Initializes the solver cell state depending on the last memory state.
        Recursively initializes: controller, interface.

        :param final_enc_memory_BxAxC: Final memory produced by the encoder; \
            dimension 0 is used as the batch size and dimension 1 as the number \
            of memory addresses.
        :param final_enc_attention_BxAx1: Final attention of the encoder, passed \
            to the interface initialization.

        :returns: Initial state tuple - object of MASCellStateTuple class.
        """
        # Get batch size and number of memory addresses.
        batch_size = final_enc_memory_BxAxC.size(0)
        num_memory_addresses = final_enc_memory_BxAxC.size(1)

        # Initialize controller state (fresh state, independent of the encoder).
        ctrl_init_state = self.controller.init_state(batch_size)

        # Initialize interface state.
        interface_init_state = self.interface.init_state(
            batch_size, num_memory_addresses, final_enc_attention_BxAx1)

        # Initialize read vector.
        # Unpack interface state - only the initial read attention is needed.
        (init_read_attention_BxAx1, _, _, _) = interface_init_state

        # Read a vector from memory using the initial attention.
        read_vector_BxC = self.interface.read_from_memory(
            init_read_attention_BxAx1, final_enc_memory_BxAxC)

        # Pack and return a tuple.
        cell_state = MASCellStateTuple(
            ctrl_init_state,
            interface_init_state,
            final_enc_memory_BxAxC,
            read_vector_BxC)
        return cell_state

    def init_state_with_encoder_state(self, final_enc_cell_state):
        """
        Creates 'zero' (initial) state on the basis of the previous cell state.
        "Recursively" calls interface initialization; the controller state is
        taken over from the encoder as is.

        :param final_enc_cell_state: Last state of MAE cell - a 3-element tuple \
            (controller state, interface state, memory), where the interface \
            state is a 2-element tuple whose first element is the attention.

        :returns: Initial state tuple - object of MASCellStateTuple class.
        """
        # Unpack encoder state tuple.
        (enc_ctrl_state, enc_interface_state,
         enc_memory_BxAxC) = final_enc_cell_state
        (enc_attention_BxAx1, _) = enc_interface_state

        # Get batch size and number of memory addresses.
        batch_size = enc_memory_BxAxC.size(0)
        num_memory_addresses = enc_memory_BxAxC.size(1)

        # Initialize interface state.
        interface_init_state = self.interface.init_state(
            batch_size, num_memory_addresses, enc_attention_BxAx1)

        # Initialize read vector.
        # Unpack interface state - only the initial read attention is needed.
        (init_read_attention_BxAx1, _, _, _) = interface_init_state

        # Read a vector from memory using the initial attention.
        read_vector_BxC = self.interface.read_from_memory(
            init_read_attention_BxAx1, enc_memory_BxAxC)

        # Pack and return a tuple - note that the encoder controller state is
        # reused instead of creating a new one.
        cell_state = MASCellStateTuple(
            enc_ctrl_state,
            interface_init_state,
            enc_memory_BxAxC,
            read_vector_BxC)
        return cell_state

    def forward(self, inputs_BxI, prev_cell_state):
        """
        Forward function of MAS cell.

        :param inputs_BxI: a Tensor of input data of size [BATCH_SIZE x INPUT_SIZE]
        :param prev_cell_state: a MASCellStateTuple tuple, containing previous \
            state of the cell.

        :returns: an output Tensor of size [BATCH_SIZE x OUTPUT_SIZE] and \
            MASCellStateTuple tuple containing current cell state.
        """
        # Unpack previous cell state (the previous read vector is not used).
        (prev_ctrl_state_tuple, prev_interface_state_tuple,
         prev_memory_BxAxC, _) = prev_cell_state

        controller_input = inputs_BxI

        # Execute controller forward step.
        ctrl_output_BxH, ctrl_state_tuple = self.controller(
            controller_input, prev_ctrl_state_tuple)

        # Execute interface forward step.
        read_vector_BxC, memory_BxAxC, interface_state_tuple = self.interface(
            ctrl_output_BxH, prev_memory_BxAxC, prev_interface_state_tuple)
        #logger.warning("ctrl_output_BxH {}:\n {}".format(ctrl_output_BxH.size(), ctrl_output_BxH))
        #logger.warning("read_vector_BxC {}:\n {}".format(read_vector_BxC.size(), read_vector_BxC))

        # Output layer - takes controller output concatenated with new read
        # vector (along the feature dimension).
        ext_hidden = torch.cat([ctrl_output_BxH, read_vector_BxC], dim=1)
        logits_BxO = self.hidden2output(ext_hidden)
        #logger.warning("logits_BxO {}:\n {}".format(logits_BxO.size(), logits_BxO))

        # Pack current cell state.
        cell_state_tuple = MASCellStateTuple(
            ctrl_state_tuple,
            interface_state_tuple,
            memory_BxAxC,
            read_vector_BxC)

        # Return logits and current cell state.
        return logits_BxO, cell_state_tuple