Source code for miprometheus.models.dwm.controller

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""controller.py: Defines the controller module used by the DWM model."""
__author__ = "Younes Bouhadjar"

import torch
from torch.nn import Module

from miprometheus.models.controllers.controller_factory import ControllerFactory


class Controller(Module):
    """
    Implementation of the DWM controller.

    The controller is made of three layers:

    - ``i2s``: the state layer, built by ``ControllerFactory`` as a
      single-layer ``'RNNController'`` with a ``'sigmoid'`` non-linearity.
      It consumes the external input concatenated with the data read
      from memory.
    - ``i2u``: a linear layer producing the update (interface) parameters.
    - ``i2o``: a linear layer producing the output logits.

    Both linear layers consume the external input, the data read from
    memory and the *previous* hidden state, concatenated together.
    """

    def __init__(self, in_dim, output_units, state_units, read_size, update_size):
        """
        Constructor for the Controller.

        :param in_dim: input size.
        :param output_units: output size.
        :param state_units: state size.
        :param read_size: size of the data read from memory.
        :param update_size: total number of parameters for updating \
        attention and memory.
        """
        super(Controller, self).__init__()

        self.read_size = read_size
        self.update_size = update_size
        self.state_units = state_units

        # Width of the state-layer input: external input + data read from memory.
        self.ctrl_in_dim = in_dim + self.read_size
        # Width of the linear-layer input: same as above + previous hidden state.
        self.ctrl_in_state_dim = in_dim + state_units + self.read_size

        # Output layer size
        self.output_units = output_units

        # State layer configuration
        self.controller_type = 'RNNController'
        self.non_linearity = 'sigmoid'
        controller_params = {
            "name": self.controller_type,
            "input_size": self.ctrl_in_dim,
            "output_size": self.state_units,
            "num_layers": 1,
            "non_linearity": self.non_linearity,
        }

        # NOTE: the layers are created in the order state -> update -> output;
        # keep it, so that parameter initialisation stays reproducible for a
        # given random seed.

        # State layer
        self.i2s = ControllerFactory.build(controller_params)

        # Update layer
        self.i2u = torch.nn.Linear(self.ctrl_in_state_dim, self.update_size)

        # Output layer
        self.i2o = torch.nn.Linear(self.ctrl_in_state_dim, self.output_units)

    def init_state(self, batch_size):
        """
        Returns the 'zero' (initial) state of the controller.

        The call is delegated to the state layer (``self.i2s``).

        :param batch_size: size of the batch in the given iteration/epoch.

        :returns: Initial state object, as produced by \
        ``self.i2s.init_state(batch_size)``.
        """
        return self.i2s.init_state(batch_size)

    def forward(self, input, tuple_state_prev, read_data):
        """
        Forward pass of the DWM controller: calculates the output, the new
        state and the interface (update) parameters.

        :param input: current input (from time t) [batch_size, in_dim]
        :param tuple_state_prev: previous state (from time t-1); must expose \
        a ``hidden_state`` attribute [batch_size, state_units]
        :param read_data: data read from memory (from time t) \
        [batch_size, read_size]

        :returns: output: logits representing the prediction \
        [batch_size, output_units]
        :returns: tuple_state: new state, as returned by the state layer
        :returns: update_data: interface parameters [batch_size, update_size]
        """
        # ``input`` shadows the builtin, but it is part of the public
        # signature (callers may pass it by keyword), so the name is kept.

        # Input of the state layer: external input + data read from memory.
        combined = torch.cat((input, read_data), dim=-1)

        # Input of the output/update layers: the above + the PREVIOUS hidden
        # state (not the one computed below).
        combined_with_state = torch.cat(
            (combined, tuple_state_prev.hidden_state), dim=-1)

        # Advance the state layer; its first return value is not used here,
        # only the new state tuple is propagated to the caller.
        _, tuple_state = self.i2s(combined, tuple_state_prev)

        # Output logits (plain linear projection, no activation applied).
        output = self.i2o(combined_with_state)

        # Attention and memory update parameters (plain linear projection).
        update_data = self.i2u(combined_with_state)

        return output, tuple_state, update_data