Source code for miprometheus.models.dwm.interface

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""interface.py: Controlls the reading and writing from memory with the various DWM attention mechanisms"""
__author__ = " Younes Bouhadjar, T.S Jayram"

import torch
import numpy as np
import logging
import collections

from miprometheus.models.dwm.tensor_utils import circular_conv, normalize
from miprometheus.models.dwm.memory import Memory
from miprometheus.utils.app_state import AppState

# Helper collection type.
_InterfaceStateTuple = collections.namedtuple(
    'InterfaceStateTuple', ('head_weight', 'snapshot_weight'))



class InterfaceStateTuple(_InterfaceStateTuple):
    """
    Tuple used by the interface for storing current/past interface information:
    head_weight and snapshot_weight.
    """
    __slots__ = ()


logger = logging.getLogger('DWM_interface')
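

# Illustrative sketch (not part of the original module): Interface.update() below
# combines jump interpolation, circular shift convolution and sharpening to move
# the head attention. The helper name _example_shift_and_sharpen is hypothetical;
# it only restates the shift-and-sharpen step, assuming `wt` has shape
# [batch_size, num_heads, memory_size], `s` is a softmax shift distribution and
# γ >= 1, and reusing circular_conv and normalize exactly as update() calls them.
def _example_shift_and_sharpen(wt, s, γ, eps=1e-12):
    """Rotate the head weights by circular convolution, then sharpen and renormalize."""
    wt_s = circular_conv(wt, s)      # rotate attention by the shift distribution
    wt_sharp = (wt_s + eps) ** γ     # sharpening: exponentiate by γ >= 1
    return normalize(wt_sharp)       # renormalize so the weights sum to 1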


class Interface:
    """
    Implementation of the interface of the DWM.
    """

    def __init__(self, num_heads, is_cam, num_shift, M):
        """
        Initialize Interface.

        :param num_heads: number of heads.
        :param is_cam: (boolean) are the heads allowed to use content addressing.
        :param num_shift: number of shifts of heads.
        :param M: number of slots per address in the memory bank.

        """
        self.num_heads = num_heads
        self.M = M

        # Define a dictionary for attentional parameters
        self.is_cam = is_cam
        self.param_dict = {'s': num_shift, 'jd': 1, 'j': 3, 'γ': 1,
                           'erase': M, 'add': M}
        if self.is_cam:
            self.param_dict.update({'k': M, 'β': 1, 'g': 1})

        # Create the parameter lengths and store their cumulative sum
        lengths = np.fromiter(self.param_dict.values(), dtype=int)
        self.cum_lengths = np.cumsum(
            np.insert(lengths, 0, 0), dtype=int).tolist()

    def init_state(self, memory_addresses_size, batch_size):
        """
        Returns the 'zero' (initial) state of the Interface tuple.

        :param memory_addresses_size: size of the memory.
        :param batch_size: size of the batch in a given iteration/epoch.

        :returns: Initial state tuple - object of InterfaceStateTuple class:
            (head_weight_init, snapshot_weight_init)

        """
        dtype = AppState().dtype

        # Initial attention vector
        head_weight_init = torch.zeros(
            (batch_size, self.num_heads, memory_addresses_size)).type(dtype)
        head_weight_init[:, 0:self.num_heads, 0] = 1.0

        # Bookmark
        snapshot_weight_init = head_weight_init

        return InterfaceStateTuple(head_weight_init, snapshot_weight_init)

    @property
    def read_size(self):
        """
        Returns the size of the data read by all heads.

        :return: (num_heads * content_size)

        """
        return self.num_heads * self.M

    @property
    def update_size(self):
        """
        Returns the total number of parameters output by the controller.

        :return: (num_heads * parameters_per_head)

        """
        return self.num_heads * self.cum_lengths[-1]

    def read(self, wt, mem):
        """
        Returns the data read from memory.

        :param wt: head weights [batch_size, num_heads, memory_addresses_size]
        :param mem: the memory content [batch_size, memory_content_size, memory_addresses_size]

        :return: the read data [batch_size, num_heads, memory_content_size]

        """
        memory = Memory(mem)
        read_data = memory.attention_read(wt)

        # Flatten the read data in the last 2 dimensions
        sz = read_data.size()[:-2]
        return read_data.view(*sz, self.read_size)

    def update(self, update_data, tuple_interface_prev, mem):
        """
        Erases from memory, writes to memory, and updates the weights using the
        various attention mechanisms.

        :param update_data: the parameters from the controller
        :param tuple_interface_prev: contains (head_weight, snapshot_weight)
        :param tuple_interface_prev.head_weight: head attention [batch_size, num_heads, memory_size]
        :param tuple_interface_prev.snapshot_weight: snapshot (bookmark) attention [batch_size, num_heads, memory_size]
        :param mem: the memory [batch_size, content_size, memory_size]

        :returns: InterfaceStateTuple (head_weight, snapshot_weight): the updated weights of head and snapshot
        :returns: mem: the new memory content

        """
        wt_head_prev, wt_att_snapshot_prev = tuple_interface_prev

        assert update_data.size()[-1] == self.update_size, \
            "Mismatch in update sizes"

        # Reshape update data by heads and total parameter size
        sz = update_data.size()[:-1]
        update_data = update_data.view(
            *sz, self.num_heads, self.cum_lengths[-1])

        # Split the data according to the different parameters
        data_splits = [
            update_data[..., self.cum_lengths[i]:self.cum_lengths[i + 1]]
            for i in range(len(self.cum_lengths) - 1)]

        # Obtain update parameters
        if self.is_cam:
            s, jd, j, γ, erase, add, k, β, g = data_splits

            # Apply activations
            # key vector used for content-based addressing
            k = torch.nn.functional.tanh(k)
            # key strength used for content-based addressing
            β = torch.nn.functional.softplus(β)
            g = torch.nn.functional.sigmoid(g)  # interpolation gate
        else:
            s, jd, j, γ, erase, add = data_splits

        # Shift weighting (determines how the weight is rotated)
        s = torch.nn.functional.softmax(
            torch.nn.functional.softplus(s), dim=-1)
        γ = 1 + torch.nn.functional.softplus(γ)  # used for weight sharpening
        erase = torch.nn.functional.sigmoid(erase)  # erase memory content

        # Write to memory
        memory = Memory(mem)
        memory.erase_weighted(erase, wt_head_prev)
        memory.add_weighted(add, wt_head_prev)

        # Update attention
        # Set jumping mechanisms

        # Fixed attention to address 0
        wt_address_0 = torch.zeros_like(wt_head_prev)
        wt_address_0[:, :, 0] = 1

        # Interpolation between wt and wt_d
        jd = torch.nn.functional.sigmoid(jd)
        wt_att_snapshot = (1 - jd) * wt_head_prev + jd * wt_att_snapshot_prev

        # Interpolation between wt_0, wt_d and wt
        j = torch.nn.functional.softmax(j, dim=-1)
        j = j[:, :, None, :]

        wt_head = j[..., 0] * wt_head_prev \
            + j[..., 1] * wt_att_snapshot \
            + j[..., 2] * wt_address_0

        # Move head according to content-based addressing and shifting
        if self.is_cam:
            # content addressing ...
            wt_k = memory.content_similarity(k)
            # ... modulated by β
            wt_β = torch.nn.functional.softmax(β * wt_k, dim=-1)
            # scalar interpolation
            wt_head = g * wt_β + (1 - g) * wt_head

        # Convolution with shift
        wt_s = circular_conv(wt_head, s)

        # Sharpening with normalization
        eps = 1e-12
        wt_head = (wt_s + eps) ** γ
        wt_head = normalize(wt_head)

        # Check that the attention of head 0 still sums to 1 after sharpening
        check_wt = torch.max(
            torch.abs(torch.sum(wt_head[:, 0, :], dim=-1) - 1.0))
        if check_wt > 1.0e-5:
            logger.warning("Warning: gamma very high, normalization problem")

        mem = memory.content

        return InterfaceStateTuple(wt_head, wt_att_snapshot), mem
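

# Usage sketch (not part of the original module): a minimal, hedged example of the
# read/update cycle documented above, assuming the tensor shapes given in the
# docstrings. The sizes below (batch_size=2, 8 memory addresses, etc.) are
# arbitrary and chosen purely for illustration; in the DWM model the controller
# produces update_data, whereas here it is random.
if __name__ == '__main__':
    batch_size, num_heads, num_shift, M, memory_size = 2, 1, 3, 6, 8

    interface = Interface(num_heads=num_heads, is_cam=False,
                          num_shift=num_shift, M=M)

    # Initial head/snapshot attention and an empty memory bank.
    state = interface.init_state(memory_size, batch_size)
    mem = torch.zeros((batch_size, M, memory_size)).type(AppState().dtype)

    # Read with the current attention: [batch_size, num_heads * M].
    read_vector = interface.read(state.head_weight, mem)

    # One update step driven by (random) controller parameters.
    update_data = torch.randn(
        (batch_size, interface.update_size)).type(AppState().dtype)
    state, mem = interface.update(update_data, state, mem)

    logger.info("read vector shape: %s", read_vector.size())
    logger.info("updated head weights shape: %s", state.head_weight.size())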