Source code for miprometheus.models.dwm.dwm_cell

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""dwm_cell.py: The cell of the DWM. It operates on a single word"""
__author__ = "Younes Bouhadjar"

import torch
import collections
from torch.nn import Module

from miprometheus.models.dwm.controller import Controller
from miprometheus.models.dwm.interface import Interface

from miprometheus.utils.app_state import AppState

# Helper collection type.
_DWMCellStateTuple = collections.namedtuple(
    'DWMStateTuple', ('ctrl_state', 'interface_state', 'memory_state'))


[docs]class DWMCellStateTuple(_DWMCellStateTuple): """ Tuple used by DWM Cells for storing current/past state information: controller state, interface state, memory state. """ __slots__ = ()
[docs]class DWMCell(Module): """ Applies the DWM cell to an element in the input sequence. """
[docs] def __init__(self, in_dim, output_units, state_units, num_heads, is_cam, num_shift, M): """ Builds the DWM cell. :param in_dim: input size. :param output_units: output size. :param state_units: state size. :param num_heads: number of heads. :param is_cam: is it content_address able. :param num_shift: number of shifts of heads. :param M: Number of slots per address in the memory bank. """ super(DWMCell, self).__init__() self.num_heads = num_heads self.M = M # build the interface and controller self.interface = Interface(num_heads, is_cam, num_shift, M) self.controller = Controller( in_dim, output_units, state_units, self.interface.read_size, self.interface.update_size)
def init_state(self, memory_addresses_size, batch_size): dtype = AppState().dtype # Initialize controller state. tuple_ctrl_init_state = self.controller.init_state(batch_size) # Initialize interface state. tuple_interface_init_state = self.interface.init_state( memory_addresses_size, batch_size) # Initialize memory mem_init = (torch.ones( (batch_size, self.M, memory_addresses_size)) * 0.01).type(dtype) return DWMCellStateTuple( tuple_ctrl_init_state, tuple_interface_init_state, mem_init)
[docs] def forward(self, input, tuple_cell_state_prev): """ forward pass of the DWM_Cell. :param input: current input (from time t) [batch_size, inputs_size] :param tuple_cell_state_prev: contains (tuple_ctrl_state_prev, tuple_interface_prev, mem_prev), object of class DWMCellStateTuple :return: output: logits [batch_size, output_size] :return: tuple_cell_state: contains (tuple_ctrl_state, tuple_interface, mem) .. math:: step1: read memory r_t &= M_t \otimes w_t step2: controller h_t &= \sigma(W_h[x_t,h_{t-1},r_{t-1}]) y_t &= W_{y}[x_t,h_{t-1},r_{t-1}] P_t &= W_{P}[x_t,h_{t-1},r_{t-1}] step3: memory update M_t &= M_{t-1}\circ (E-w_t \otimes e_t)+w_t\otimes a_t to be completed ... """ tuple_ctrl_state_prev, tuple_interface_prev, mem_prev = tuple_cell_state_prev # step1: read from memory using attention read_data = self.interface.read( tuple_interface_prev.head_weight, mem_prev) # step2: controller output, tuple_ctrl_state, update_data = self.controller( input, tuple_ctrl_state_prev, read_data) # step3: update memory and attention tuple_interface, mem = self.interface.update( update_data, tuple_interface_prev, mem_prev) tuple_cell_state = DWMCellStateTuple( tuple_ctrl_state, tuple_interface, mem) return output, tuple_cell_state