# Source code for miprometheus.models.dnc.control_and_params

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""control_and_params.py: Calls controller and parameter generators """
__author__ = " Ryan L. McAvoy"

import torch
from torch.nn import Module
from miprometheus.models.dnc.param_gen import Param_Generator

from miprometheus.models.controllers.controller_factory import ControllerFactory


class ControlParams(Module):
    """
    Wraps the recurrent controller of the DNC.

    Runs the controller on the concatenation of the external input and the
    data previously read from memory, and produces three things per step:
    the model output, the new controller state, and the interface parameters
    used to update/address the DNC memory.
    """

    def __init__(self, output_size, read_size, params):
        """
        Initialize the controller and the parameter generators.

        :param output_size: output size of the model.
        :param read_size: size of the data read from memory (appended to the input).
        :param params: dictionary of input parameters (sizes, controller type, etc.).
        """
        super(ControlParams, self).__init__()

        self.read_size = read_size

        # Parse parameters.
        # Set input and hidden dimensions; the controller consumes the
        # external input concatenated with the previous memory read.
        self.input_size = params["input_item_size"]
        ctrl_in_dim = self.input_size + self.read_size
        self.hidden_state_size = params['hidden_state_size']

        # Get memory parameters.
        self.num_memory_bits = params['memory_content_size']
        self.controller_type = params['controller_type']
        self.shift_size = params['shift_size']
        self.num_reads = params['num_reads']
        self.num_writes = params['num_writes']
        self.non_linearity = params['non_linearity']

        # TODO: Make multilayered LSTM controller in the vein of the DNC paper.

        # State layer: recurrent controller built by the factory from its config.
        controller_params = {
            "name": self.controller_type,
            "input_size": ctrl_in_dim,
            "output_size": self.hidden_state_size,
            "num_layers": 1,
            "non_linearity": self.non_linearity
        }
        self.state_gen = ControllerFactory.build(controller_params)

        # Output layer: linear projection of the hidden state to the model output.
        self.output_gen = torch.nn.Linear(self.hidden_state_size, output_size)

        # Update layer: generates the memory interface parameters
        # (read/write keys, gates, shifts, ...) from the hidden state.
        self.param_gen = Param_Generator(
            self.hidden_state_size,
            word_size=self.num_memory_bits,
            num_reads=self.num_reads,
            num_writes=self.num_writes,
            shift_size=self.shift_size)

    def init_state(self, batch_size):
        """
        Returns the 'zero' (initial) state tuple of the wrapped controller.

        :param batch_size: Size of the batch in given iteration/epoch.
        :returns: Initial state tuple - object of LSTMStateTuple class.
        """
        # Delegate to the controller, which knows its own state layout.
        return self.state_gen.init_state(batch_size)

    def forward(self, inputs, prev_ctrl_state_tuple, read_data):
        """
        Calculates the output, the hidden state and the controller parameters.

        :param inputs: Current input (from time t) [BATCH_SIZE x INPUT_SIZE]
        :param prev_ctrl_state_tuple: Tuple of states of controller (from time t-1)
        :param read_data: data read from memory (from time t-1) [BATCH_SIZE x num_data_bits]
        :return: Tuple [output, hidden_state, update_data]
                 (update_data contains all of the controller parameters)
        """
        # Concatenate the 2 inputs to controller along the feature dimension.
        combined = torch.cat((inputs, read_data), dim=-1)

        # One controller step: new hidden state plus updated state tuple.
        hidden_state, ctrl_state_tuple = self.state_gen(
            combined, prev_ctrl_state_tuple)

        # Model output and memory-update parameters are both derived
        # from the same hidden state.
        output = self.output_gen(hidden_state)
        update_data = self.param_gen(hidden_state)

        return output, ctrl_state_tuple, update_data