#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""param_gen.py: The class that converts the hidden vector of the controller into the parameters of the interface
"""
__author__ = " Ryan L. McAvoy"
import torch
from torch.nn import Module
class Param_Generator(Module):
    """
    Converts the hidden state vector of the controller into the set of
    interface parameters (keys, strengths, gates, shift and sharpening
    vectors) that drive the read and write heads of the memory.
    """

    def __init__(self,
                 param_in_dim,
                 # memory_size=128,
                 word_size=20,
                 num_reads=1,
                 num_writes=1,
                 shift_size=3):
"""
Initialize all the parameters of the interface.
:param param_in_dim: input size. (typically the size of the hidden state)
:param word_size: size of the word in memory
:param num_reads: number of read heads
:param num_writes: number of write heads
:param shift_size: size of the shift vector (3 means it can go forward, backward and remain in place)
"""
        super(Param_Generator, self).__init__()

        # self._memory_size = memory_size
        self._word_size = word_size
        self._num_reads = num_reads
        self._num_writes = num_writes
        self._num_shifts = shift_size
        # v_t^i - The vectors to write to memory, for each write head `i`.
        self.write_vect_ = torch.nn.Linear(
            param_in_dim, self._num_writes * self._word_size)

        # e_t^i - Amount to erase the memory by before writing, for each
        # write head.
        self.erase_vect_ = torch.nn.Linear(
            param_in_dim, self._num_writes * self._word_size)

        # f_t^j - Amount that the memory at the locations read from at the
        # previous time step can be declared unused, for each read head `j`.
        self.free_gate_ = torch.nn.Linear(param_in_dim, self._num_reads)

        # g_t^{a, i} - Interpolation between writing to unallocated memory and
        # content-based lookup, for each write head `i`. Note: `a` is simply
        # used to identify this gate with allocation vs writing (as defined
        # below).
        self.allocate_gate_ = torch.nn.Linear(param_in_dim, self._num_writes)

        # g_t^{w, i} - Overall gating of write amount for each write head.
        self.write_gate_ = torch.nn.Linear(param_in_dim, self._num_writes)

        # \pi_t^j - Mixing between "backwards" and "forwards" positions (for
        # each write head), and content-based lookup, for each read head.
        num_read_modes = 1 + 2 * self._num_writes
        self.read_mode_ = torch.nn.Linear(
            param_in_dim, self._num_reads * num_read_modes)

        # Parameters for the (read / write) "weights by content matching"
        # modules.
        self.write_keys_ = torch.nn.Linear(
            param_in_dim, self._num_writes * self._word_size)
        self.write_strengths_ = torch.nn.Linear(param_in_dim, self._num_writes)
        self.read_keys_ = torch.nn.Linear(
            param_in_dim, self._num_reads * self._word_size)
        self.read_strengths_ = torch.nn.Linear(param_in_dim, self._num_reads)

        # s_t^i - The shift vectors that define the circular convolution of
        # the write-head weights.
        self.shifts_ = torch.nn.Linear(
            param_in_dim, self._num_shifts * self._num_writes)

        # \gamma_t - Sharpening parameters for the write and read weights.
        self.sharpening_ = torch.nn.Linear(param_in_dim, self._num_writes)
        self.sharpening_r_ = torch.nn.Linear(param_in_dim, self._num_reads)

        # s_t^j - The shift vectors for the read-head weights.
        self.shifts_r_ = torch.nn.Linear(
            param_in_dim, self._num_shifts * self._num_reads)

    def forward(self, vals):
"""
Calculates the controller parameters.
:param vals: data from the controller (from time t). Typically, the hidden state. [BATCH_SIZE x INPUT_SIZE]
:return update_data: dictionary (update_data contains all of the controller parameters)
"""
        update_data = {}

        # v_t^i - The vectors to write to memory, for each write head `i`.
        update_data['write_vectors'] = self.write_vect_(
            vals).view(-1, self._num_writes, self._word_size)

        # e_t^i - Amount to erase the memory by before writing, for each
        # write head. [batch, num_writes * word_size]
        erase_vec = torch.sigmoid(self.erase_vect_(vals))
        update_data['erase_vectors'] = erase_vec.view(
            -1, self._num_writes, self._word_size)

        # f_t^j - Amount that the memory at the locations read from at the
        # previous time step can be declared unused, for each read head `j`.
        update_data['free_gate'] = torch.sigmoid(
            self.free_gate_(vals)).view(-1, self._num_reads, 1)

        # g_t^{a, i} - Interpolation between writing to unallocated memory and
        # content-based lookup, for each write head `i`. Note: `a` is simply
        # used to identify this gate with allocation vs writing.
        update_data['allocation_gate'] = torch.sigmoid(
            self.allocate_gate_(vals)).view(-1, self._num_writes, 1)

        # g_t^{w, i} - Overall gating of write amount for each write head.
        update_data['write_gate'] = torch.sigmoid(
            self.write_gate_(vals)).view(-1, self._num_writes, 1)
        # \pi_t^j - Mixing between "backwards" and "forwards" positions (for
        # each write head), and content-based lookup, for each read head.
        # Reshape before the softmax so that the modes of each read head are
        # normalized separately; a softmax over the flattened vector would
        # couple the heads whenever num_reads > 1.
        num_read_modes = 1 + 2 * self._num_writes
        update_data['read_mode'] = torch.nn.functional.softmax(
            self.read_mode_(vals).view(-1, self._num_reads, num_read_modes),
            dim=-1)
        # Parameters for the (read / write) "weights by content matching"
        # modules. The strengths are kept >= 1 via 1 + softplus.
        update_data['write_content_keys'] = self.write_keys_(
            vals).view(-1, self._num_writes, self._word_size)
        update_data['write_content_strengths'] = 1 + torch.nn.functional.softplus(
            self.write_strengths_(vals)).view(-1, self._num_writes, 1)
        update_data['read_content_keys'] = self.read_keys_(
            vals).view(-1, self._num_reads, self._word_size)
        update_data['read_content_strengths'] = 1 + torch.nn.functional.softplus(
            self.read_strengths_(vals)).view(-1, self._num_reads, 1)
        # s_t^i / s_t^j - The shift vectors that define the circular
        # convolution of the head weights. As with the read modes, reshape
        # first so each head's shift distribution is normalized separately.
        update_data['shifts'] = torch.nn.functional.softmax(
            torch.nn.functional.softplus(self.shifts_(vals)).view(
                -1, self._num_writes, self._num_shifts), dim=-1)
        update_data['shifts_read'] = torch.nn.functional.softmax(
            torch.nn.functional.softplus(self.shifts_r_(vals)).view(
                -1, self._num_reads, self._num_shifts), dim=-1)
        # \gamma_t - Sharpening parameters for the weights, kept >= 1 via
        # 1 + softplus.
        update_data['sharpening'] = 1 + torch.nn.functional.softplus(
            self.sharpening_(vals)).view(-1, self._num_writes, 1)
        update_data['sharpening_read'] = 1 + torch.nn.functional.softplus(
            self.sharpening_r_(vals)).view(-1, self._num_reads, 1)

        # Note: as written, the read-mode shift reuses the free-gate
        # projection rather than a dedicated linear layer.
        update_data['read_mode_shift'] = torch.sigmoid(
            self.free_gate_(vals)).view(-1, self._num_reads, 1)

        return update_data
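

if __name__ == "__main__":
    # A minimal usage sketch. The batch size and hidden size below are
    # arbitrary illustrative values; the remaining constructor arguments are
    # this module's defaults (word_size=20, one read head, one write head,
    # shift_size=3).
    batch_size, hidden_size = 4, 64
    param_gen = Param_Generator(hidden_size)
    hidden_state = torch.randn(batch_size, hidden_size)
    params = param_gen(hidden_state)

    # Every entry is shaped [batch, num_heads, feature]; e.g. with the
    # defaults, 'write_vectors' is [4, 1, 20] and 'read_mode' is
    # [4, 1, 1 + 2 * num_writes] = [4, 1, 3].
    for name, tensor in sorted(params.items()):
        print(name, tuple(tensor.shape))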