#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""interface.py: Controlls the reading and writing from memory with the various DNC attention mechanisms"""
__author__ = " Ryan L. McAvoy"
import torch
import collections
from miprometheus.models.dnc.tensor_utils import circular_conv, normalize
from miprometheus.models.dnc.memory import Memory
from miprometheus.models.dnc.memory_usage import MemoryUsage
from miprometheus.models.dnc.temporal_linkage import TemporalLinkage
from miprometheus.utils.app_state import AppState
# Helper collection type.
_InterfaceStateTuple = collections.namedtuple(
'InterfaceStateTuple', ('read_weights', 'write_weights', 'usage', 'links'))
class InterfaceStateTuple(_InterfaceStateTuple):
"""
Tuple used by interface for storing current/past state information.
"""
__slots__ = ()
class Interface(object):
"""
Interface between the controller and the external memory: maintains the read/write
attention, usage and temporal links, and performs the memory reads and writes,
switching between the DNC and NTM addressing mechanisms.
"""
def __init__(self, params):
"""
Initialize Interface.
:param params: dictionary of input parameters
"""
# Get memory parameters.
self.num_memory_bits = params['memory_content_size']
# Number of read and write heads
self._num_reads = params["num_reads"]
self._num_writes = params["num_writes"]
# parameters that determine whether this acts as a DNC or NTM
self.use_ntm_write = params['use_ntm_write']
self.use_ntm_read = params['use_ntm_read']
self.use_ntm_order = params['use_ntm_order']
self.use_extra_write_gate = params['use_extra_write_gate']
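# Helper modules: MemoryUsage tracks per-address usage for the DNC's
# allocation-based writes; TemporalLinkage tracks the order of writes for the
# DNC's forward/backward reads.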
self.mem_usage = MemoryUsage()
self.temporal_linkage = TemporalLinkage(self._num_writes)
@property
def read_size(self):
"""
Returns the size of the data read by all heads.
:return: num_reads * content_size
"""
return self._num_reads * self.num_memory_bits
def read(self, prev_interface_tuple, mem):
"""
Returns the data read from memory by all read heads.
:param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
:param mem: the memory [batch_size, content_size, memory_size]
:return: the read data, flattened across heads [batch_size, num_reads * content_size]
"""
(wt, _, _, _) = prev_interface_tuple
memory = Memory(mem)
read_data = memory.attention_read(wt)
# flatten the data_gen in the last 2 dimensions
sz = read_data.size()[:-2]
return read_data.view(*sz, self.read_size)
def edit_memory(self, interface_tuple, update_data, mem):
"""
Edits the external memory and then returns it.
:param interface_tuple: Tuple [read attention, write attention, usage, links]
:param update_data: the parameters from the controllers [dictionary]
:param mem: the memory [batch_size, content_size, memory_size]
:return: edited memory [batch_size, content_size, memory_size]
"""
(_, write_attention, _, _) = interface_tuple
# Write to memory
write_gate = update_data['write_gate']
add = update_data['write_vectors']
erase = update_data['erase_vectors']
if self.use_extra_write_gate:
add = add * write_gate
erase = erase * write_gate
memory = Memory(mem)
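# The two calls below apply the standard erase-then-add memory update
# (cf. NTM/DNC):
#   M_t = M_{t-1} * (1 - w_write (x) erase) + w_write (x) add
# where (x) denotes an outer product over (memory address, content bit).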
memory.erase_weighted(erase, write_attention)
memory.add_weighted(add, write_attention)
mem = memory.content
return mem
def init_state(self, memory_address_size, batch_size):
"""
Returns 'zero' (initial) state tuple.
:param memory_address_size: the number of memory addresses
:param batch_size: size of the batch in a given iteration/epoch
:returns: Initial state tuple - object of InterfaceStateTuple class.
"""
dtype = AppState().dtype
# Read attention weights [BATCH_SIZE x NUM_READS x MEMORY_SIZE]
read_attention = torch.ones(
(batch_size, self._num_reads, memory_address_size)).type(dtype) * 1e-6
# Write attention weights [BATCH_SIZE x NUM_WRITES x MEMORY_SIZE]
write_attention = torch.ones(
(batch_size, self._num_writes, memory_address_size)).type(dtype) * 1e-6
# Usage of memory cells [BATCH_SIZE x MEMORY_SIZE]
usage = self.mem_usage.init_state(memory_address_size, batch_size)
# temporal links tuple
link_tuple = self.temporal_linkage.init_state(
memory_address_size, batch_size)
return InterfaceStateTuple(
read_attention, write_attention, usage, link_tuple)
def update_weight(self, prev_attention, memory,
strength, gate, key, shift, sharp):
"""
Update the attention with NTM's mix of content addressing and linear
shifting.
:param prev_attention: tensor of shape `[batch_size, num_heads,
memory_size]` giving the attention at the previous time step.
:param memory: the memory of the previous step (class)
:param strength: The strengthening parameter for the content addressing [batch, num_heads, 1]
:param gate: The interpolation gate between the content addressing and the previous weight [batch, num_heads, 1]
:param key: The comparison key for the content addressing [batch, num_heads, num_memory_bits]
:param shift: The shift vector that defines the circular convolution of the outputs [batch, num_heads, num_shifts]
:param sharp: sharpening parameter for the attention [batch, num_heads, 1]
"""
# Content addressing using weighted cosine similarity
similarity = memory.content_similarity(key)
content_weights = torch.nn.functional.softmax(strength * similarity, dim=-1)
# Gate between the current weight and the content weights
attention = gate * content_weights + (1 - gate) * prev_attention
# Linear shift with convolution
shifted_attention = circular_conv(attention, shift)
# Sharpen weights and then normalize
eps = 1e-12
attention = (shifted_attention + eps) ** sharp
attention = normalize(attention)
return attention
def update_write_weight(self, usage, memory,
allocation_gate, write_gate, key, strength):
"""
Update write attention with DNC's combination of content addressing and
usage based allocation.
:param usage: A tensor of shape `[batch_size, memory_size]` representing
current memory usage.
:param memory: the memory of the previous step (class)
:param strength: The strengthening parameter for the content addressing [batch, num_writes, 1]
:param key: The comparison key for the content addressing [batch, num_writes, num_memory_bits]
:param allocation_gate: Interpolation between writing to unallocated memory and
content-based lookup, for each write head [batch, num_writes, 1]
:param write_gate: Overall gating of write amount for each write head. [batch, num_writes, 1]
"""
# Calculate which memory slots are open for allocation
write_allocation_weights = self.mem_usage.write_allocation_weights(
usage=usage,
write_gates=(allocation_gate * write_gate),
num_writes=self._num_writes)
# Content addressing using weighted cosine similarity
similarity = memory.content_similarity(key)
content_weights = torch.nn.functional.softmax(strength * similarity, dim=-1)
# Gate between the allocatable memory and the content weighted memory
wt = write_gate * (allocation_gate * write_allocation_weights +
(1 - allocation_gate) * content_weights)
return wt
def update_read_weight(
self, link, memory, prev_read_weights, read_mode, key, strength):
"""
Update the read attention with the DNC's combination of content
addressing and temporal link propagation to go forwards or backwards in
time.
:param link: A tensor of shape `[batch_size, num_writes, memory_size,
memory_size]` representing the previous link graphs for each write
head.
:param memory: the memory of the previous step (class)
:param prev_read_weights: tensor of shape `[batch_size, num_reads,
memory_size]` containing the previous read weights w_{t-1}^r.
:param read_mode: Mixing between "backwards" and "forwards" positions (for each write head)
and content-based lookup, for each read head [batch, num_reads, 1 + 2*num_writes]
:param strength: The strengthening parameter for the content addressing [batch, num_reads, 1]
:param key: The comparison key for the content addressing [batch, num_reads, num_memory_bits]
"""
# Content addressing using weighted cosine similarity
similarity = memory.content_similarity(key)
content_weights = torch.nn.functional.softmax(strength * similarity, dim=-1)
# Calculate the weight to go forward and backwards along the links
# matrix
forward_weights = self.temporal_linkage.directional_read_weights(
link.link, prev_read_weights, forward=True)
backward_weights = self.temporal_linkage.directional_read_weights(
link.link, prev_read_weights, forward=False)
# Reshape the read mode matrix
backward_mode = torch.unsqueeze(read_mode[:, :, :self._num_writes], 3)
forward_mode = torch.unsqueeze(
read_mode[:, :, self._num_writes:2 * self._num_writes], 3)
content_mode = torch.unsqueeze(
read_mode[:, :, 2 * self._num_writes], 2)
# Gate between the content similarity, going forwards along the link
# matrix and going backwards
read_weights = (content_mode * content_weights +
torch.sum(forward_mode * forward_weights, 2) +
torch.sum(backward_mode * backward_weights, 2))
return read_weights
def update_read(self, update_data, prev_interface_tuple, mem):
"""
Updates the read attention, switching between the NTM and DNC
mechanisms.
:param update_data: the parameters from the controllers [dictionary]
:param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
:param mem: the memory [batch_size, content_size, memory_size]
:return: The new interface tuple with updated read attention and temporal links
"""
(prev_read_attention, prev_write_attention,
prev_usage, prev_links) = prev_interface_tuple
# Parameters for the content addressing
key = update_data['read_content_keys']
strength = update_data['read_content_strengths']
# retrieve memory Class
memory = Memory(mem)
# update the attention using either the NTM read mechanism (True) or
# the DNC (False)
if self.use_ntm_read:
# Parameters for shift addressing
shift = update_data['shifts_read']
sharp = update_data['sharpening_read']
gate = update_data['read_mode_shift']
read_attention = self.update_weight(
prev_read_attention, memory, strength, gate, key, shift, sharp)
links = prev_links
else:
read_mode = update_data['read_mode']
links = self.temporal_linkage.calc_temporal_links(
prev_write_attention, prev_links)
read_attention = self.update_read_weight(
links, memory, prev_read_attention, read_mode, key, strength)
interface_state_tuple = InterfaceStateTuple(
read_attention, prev_write_attention, prev_usage, links)
return interface_state_tuple
def update_write(self, update_data, prev_interface_tuple, mem):
"""
Updates the write attention, switching between the NTM and DNC
mechanisms.
:param update_data: the parameters from the controllers [dictionary]
:param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
:param mem: the memory [batch_size, content_size, memory_size]
:return: The new interface tuple with an updated usage and write attention
"""
(prev_read_attention, prev_write_attention,
prev_usage, prev_links) = prev_interface_tuple
# Obtain update parameters
key = update_data['write_content_keys']
strength = update_data['write_content_strengths']
gate = update_data['allocation_gate']
# retrieve memory Class
memory = Memory(mem)
free_gate = update_data['free_gate']
usage = self.mem_usage.calculate_usage(
prev_write_attention, free_gate, prev_read_attention, prev_usage)
# update the attention using either the NTM write mechanism (True) or
# the DNC (False)
if self.use_ntm_write:
# Parameters for shift addressing
shift = update_data['shifts']
sharp = update_data['sharpening']
write_attention = self.update_weight(
prev_write_attention, memory, strength, gate, key, shift, sharp)
else:
write_gate = update_data['write_gate']
allocation_gate = gate
write_attention = self.update_write_weight(
usage, memory, allocation_gate, write_gate, key, strength)
interface_state_tuple = InterfaceStateTuple(
prev_read_attention, write_attention, usage, prev_links)
return interface_state_tuple
def update_and_edit(self, update_data,
prev_interface_tuple, prev_memory_BxMxA):
"""
Erases from memory, writes to memory, updates the weights using various
attention mechanisms.
:param update_data: the parameters from the controllers [dictionary]
:param prev_interface_tuple: Tuple [previous read, previous write, prev usage, prev links]
:param prev_memory_BxMxA: the memory of the previous step [batch_size, content_size, memory_size]
:return: the new read vector, the updated memory, the new interface tuple
"""
(prev_read_attention, prev_write_attention,
prev_usage, prev_links) = prev_interface_tuple
# Step 1: update the write weights
interface_tuple = self.update_write(
update_data, prev_interface_tuple, prev_memory_BxMxA)
# Step 2: Write and Erase Data
memory_BxMxA = self.edit_memory(
interface_tuple, update_data, prev_memory_BxMxA)
# Step 3: Update read weights using either the current or previous
# memory
if self.use_ntm_order:
read_memory_BxMxA = prev_memory_BxMxA
else:
read_memory_BxMxA = memory_BxMxA
interface_tuple = self.update_read(
update_data, interface_tuple, read_memory_BxMxA)
# Step 4: Read the data from memory
read_vector_BxM = self.read(interface_tuple, memory_BxMxA)
return read_vector_BxM, memory_BxMxA, interface_tuple
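# Minimal usage sketch (illustrative only; the parameter values and shapes below
# are assumptions for demonstration, not defaults from any configuration file):
#
#   params = {'memory_content_size': 8, 'num_reads': 1, 'num_writes': 1,
#             'use_ntm_write': False, 'use_ntm_read': False,
#             'use_ntm_order': False, 'use_extra_write_gate': False}
#   interface = Interface(params)
#   state = interface.init_state(memory_address_size=16, batch_size=4)
#   mem = torch.zeros(4, 8, 16).type(AppState().dtype)  # [batch, content_size, memory_size]
#   # 'update_data' is the dictionary produced by the controller; it must contain
#   # the keys consumed above (e.g. 'write_content_keys', 'write_content_strengths',
#   # 'write_gate', 'free_gate', 'allocation_gate', 'write_vectors', 'erase_vectors',
#   # 'read_content_keys', 'read_content_strengths', 'read_mode', ...).
#   read_vector, mem, state = interface.update_and_edit(update_data, state, mem)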