
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ntm_cell.py: pytorch module implementing single (recurrent) cell of Neural Turing Machine"""
__author__ = "Tomasz Kornuta"

import torch
import collections
from torch.nn import Module

from miprometheus.models.controllers.controller_factory import ControllerFactory
from miprometheus.models.ntm.ntm_interface import NTMInterface

# Helper collection type.
_NTMCellStateTuple = collections.namedtuple(
    'NTMCellStateTuple',
    ('ctrl_state',
     'interface_state',
     'memory_state',
     'read_vectors'))


class NTMCellStateTuple(_NTMCellStateTuple):
    """
    Tuple used by NTM cells for storing current/past state information.
    """
    __slots__ = ()
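
# Illustration only (not part of the original module): because the state is a
# namedtuple, its fields can be accessed by name instead of positional
# unpacking, e.g. ``state = NTMCellStateTuple(ctrl, iface, mem_BxAxC, reads)``
# followed by ``state.memory_state`` or ``state.read_vectors``.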

class NTMCell(Module):
    """
    Class representing a single NTM cell.
    """
    def __init__(self, params):
        """
        Cell constructor. Creates the controller and the interface. It also
        initializes the memory "block" that will be passed between states.

        :param params: Dictionary of parameters.
        """
        # Call constructor of base class.
        super(NTMCell, self).__init__()

        # Parse parameters.
        # Set input and output sizes.
        self.input_size = params['input_item_size']
        self.output_size = params['output_item_size']

        # Get controller hidden state size.
        self.controller_hidden_state_size = params['controller'][
            'hidden_state_size']

        # Get memory parameters - required for calculating the extended
        # controller input size.
        self.num_memory_content_bits = params['memory']['num_content_bits']

        # Get number of read heads.
        self.interface_num_read_heads = params['interface']['num_read_heads']

        # Controller - entity that processes the input and produces the hidden
        # state of the NTM cell.
        # controller_input_size = input_size + read_vector_size * num_read_heads
        ext_controller_inputs_size = self.input_size + \
            self.num_memory_content_bits * self.interface_num_read_heads

        # Create dictionary with controller parameters.
        controller_params = params['controller']
        controller_params.add_default_params({
            "input_size": ext_controller_inputs_size,
            "output_size": self.controller_hidden_state_size
        })

        # Build the controller.
        self.controller = ControllerFactory.build(controller_params)

        # Interface - entity responsible for accessing the memory.
        self.interface = NTMInterface(params)

        # Output layer - produces the output on the basis of the controller
        # hidden state concatenated with the read vectors.
        ext_hidden_size = self.controller_hidden_state_size + \
            self.num_memory_content_bits * self.interface_num_read_heads
        self.hidden2output = torch.nn.Linear(
            ext_hidden_size, self.output_size)
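
    # Illustrative shape of the expected ``params`` dictionary, inferred from
    # the keys accessed above (values are placeholders; the controller and
    # interface may require additional keys):
    #
    #     {'input_item_size': 8, 'output_item_size': 8,
    #      'controller': {'hidden_state_size': 20, ...},
    #      'memory': {'num_content_bits': 10, ...},
    #      'interface': {'num_read_heads': 1, ...}}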
    def init_state(self, init_memory_BxAxC):
        """
        Returns the 'zero' (initial) state. "Recursively" calls the controller
        and interface initialization.

        :param init_memory_BxAxC: Initial memory.

        :returns: Initial state tuple - object of the NTMCellStateTuple class.
        """
        batch_size = init_memory_BxAxC.size(0)
        num_memory_addresses = init_memory_BxAxC.size(1)

        # Initialize controller state.
        ctrl_init_state = self.controller.init_state(batch_size)

        # Initialize interface state.
        interface_init_state = self.interface.init_state(
            batch_size, num_memory_addresses)

        # Initialize read vectors - one for every head.
        # Unpack the interface state.
        (init_read_state_tuples, _) = interface_init_state
        (init_read_attentions_BxAx1_H, _, _, _) = zip(*init_read_state_tuples)

        read_vectors_BxC_H = []
        for h in range(self.interface_num_read_heads):
            # Read vector of size [BATCH_SIZE x CONTENT_BITS], retrieved from
            # memory using the initial attention of the h-th head.
            read_vectors_BxC_H.append(self.interface.read_from_memory(
                init_read_attentions_BxAx1_H[h], init_memory_BxAxC))

        # Pack and return the state tuple.
        ntm_state = NTMCellStateTuple(
            ctrl_init_state, interface_init_state,
            init_memory_BxAxC, read_vectors_BxC_H)
        return ntm_state
    # def init_state_from_prev_state(self, prev_cell_state):
    #     """
    #     Creates the 'zero' (initial) state on the basis of the previous cell
    #     state. "Recursively" calls the controller and interface
    #     initialization.
    #
    #     :param prev_cell_state: Previous cell state.
    #     :returns: Initial state tuple - object of the NTMCellStateTuple class.
    #     """
    #     # Unpack previous cell state.
    #     (prev_ctrl_state, prev_interface_state, prev_memory_BxAxC,
    #      prev_read_vectors_BxC_H) = prev_cell_state
    #
    #     # Initialize controller state.
    #     ctrl_init_state = self.controller.init_state_from_state(
    #         prev_ctrl_state)
    #
    #     # Initialize interface state.
    #     interface_init_state = self.interface.init_state_from_state(
    #         prev_interface_state)
    #
    #     # Pack and return the state tuple.
    #     ntm_state = NTMCellStateTuple(
    #         ctrl_init_state, interface_init_state,
    #         prev_memory_BxAxC, prev_read_vectors_BxC_H)
    #     return ntm_state
    def forward(self, inputs_BxI, prev_cell_state):
        """
        Forward function of the NTM cell.

        :param inputs_BxI: a Tensor of input data of size [BATCH_SIZE x INPUT_SIZE]
        :param prev_cell_state: an NTMCellStateTuple containing the previous state of the cell.

        :returns: an output Tensor of size [BATCH_SIZE x OUTPUT_SIZE] and an
            NTMCellStateTuple containing the current cell state.
        """
        # Unpack previous cell state.
        (prev_ctrl_state_tuple, prev_interface_state_tuple,
         prev_memory_BxAxC, prev_read_vectors_BxC_H) = prev_cell_state

        # Concatenate inputs with previous read vectors
        # [BATCH_SIZE x (INPUT_SIZE + NUM_HEADS * MEMORY_CONTENT_BITS)].
        prev_read_vectors = torch.cat(prev_read_vectors_BxC_H, dim=1)
        controller_input = torch.cat((inputs_BxI, prev_read_vectors), dim=1)

        # Execute controller forward step.
        ctrl_output_BxH, ctrl_state_tuple = self.controller(
            controller_input, prev_ctrl_state_tuple)

        # Execute interface forward step.
        read_vectors_BxC_H, memory_BxAxC, interface_state_tuple = self.interface(
            ctrl_output_BxH, prev_memory_BxAxC, prev_interface_state_tuple)

        # Output layer - takes the controller output concatenated with the new
        # read vectors.
        read_vectors = torch.cat(read_vectors_BxC_H, dim=1)
        ext_hidden = torch.cat((ctrl_output_BxH, read_vectors), dim=1)
        logits_BxO = self.hidden2output(ext_hidden)

        # Pack current cell state.
        cell_state_tuple = NTMCellStateTuple(
            ctrl_state_tuple, interface_state_tuple,
            memory_BxAxC, read_vectors_BxC_H)

        # Return logits and current cell state.
        return logits_BxO, cell_state_tuple
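

# ----------------------------------------------------------------------------
# A minimal usage sketch (illustration only, not part of the original module).
# It assumes a ParamInterface-style configuration object exposing
# add_default_params(); the concrete controller/interface/memory keys and
# values below ('name', 'num_addresses', 'shift_size', etc.) are assumptions
# chosen for illustration and may differ from what a real miprometheus
# configuration requires.
if __name__ == '__main__':
    from miprometheus.utils.param_interface import ParamInterface

    params = ParamInterface()
    params.add_default_params({
        'input_item_size': 8,
        'output_item_size': 8,
        'controller': {'name': 'rnn', 'hidden_state_size': 20},
        'memory': {'num_content_bits': 10, 'num_addresses': 30},
        'interface': {'num_read_heads': 1, 'shift_size': 3}})

    cell = NTMCell(params)

    # Initial memory block: [BATCH_SIZE x MEMORY_ADDRESSES x CONTENT_BITS].
    batch_size, seq_len = 2, 5
    init_memory_BxAxC = torch.zeros((batch_size, 30, 10))
    state = cell.init_state(init_memory_BxAxC)

    # Process a random sequence item by item, threading the state through.
    inputs_BxSxI = torch.randn((batch_size, seq_len, 8))
    for t in range(seq_len):
        logits_BxO, state = cell(inputs_BxSxI[:, t, :], state)
    print(logits_BxO.size())  # Expected: [2, 8]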