Source code for miprometheus.models.encoder_solver.mae_cell

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ntm_cell.py: pytorch module implementing single (recurrent) cell of Neural Turing Machine"""
__author__ = "Tomasz Kornuta"

import torch
import collections
from torch.nn import Module


# Set logging level.
import logging
logger = logging.getLogger('MAE-Cell')
# logging.basicConfig(level=logging.DEBUG)

import os

from miprometheus.models.controllers.controller_factory import ControllerFactory


from miprometheus.models.encoder_solver.mae_interface import MAEInterface

# Helper collection type.
# Plain namedtuple that supplies the three state fields; the public class
# below derives from it.
_MAECellStateTuple = collections.namedtuple(
    'MAECellStateTuple',
    [
        'ctrl_state',
        'interface_state',
        'memory_state',
    ],
)


class MAECellStateTuple(_MAECellStateTuple):
    """
    State container passed between consecutive steps of an MAE cell.

    Fields, in order: ``ctrl_state``, ``interface_state``, ``memory_state``.
    """

    # Empty slots: instances carry no per-instance ``__dict__`` on top of
    # the tuple fields.
    __slots__ = ()
class MAECell(Module):
    """
    Class representing a single Memory-Augmented Encoder (MAE) cell.

    The cell owns three sub-modules, all created in the constructor: a
    controller (built by ``ControllerFactory``), a memory interface
    (``MAEInterface``) and a linear layer that maps the controller output
    to output logits.
    """

    def __init__(self, params):
        """
        Cell constructor. Creates the controller, the interface and the
        output layer.

        :param params: Dictionary of parameters. The keys read here are \
        ``input_item_size``, ``output_item_size`` and ``controller`` (with \
        sub-keys ``name``, ``hidden_state_size``, ``non_linearity`` and \
        ``num_layers``). The whole dictionary is also passed to \
        ``MAEInterface``.

        """
        # Call constructor of base class.
        super().__init__()

        # Set input and output sizes.
        self.input_size = params['input_item_size']
        self.output_size = params['output_item_size']

        # Get controller hidden state size.
        ctrl_params = params['controller']
        self.controller_hidden_state_size = ctrl_params['hidden_state_size']

        # Controller - entity that processes the input and produces the
        # hidden state of the cell. It is fed with the raw input only.
        controller_params = {
            "name": ctrl_params['name'],
            "input_size": self.input_size,
            "output_size": self.controller_hidden_state_size,
            "non_linearity": ctrl_params['non_linearity'],
            "num_layers": ctrl_params['num_layers'],
        }
        self.controller = ControllerFactory.build(controller_params)

        # Interface - entity responsible for accessing the memory.
        self.interface = MAEInterface(params)

        # Layer that produces the output on the basis of the controller
        # output.
        self.hidden2output = torch.nn.Linear(
            self.controller_hidden_state_size, self.output_size)

    def load(self, filename):
        """
        Loads the controller and interface weights from a checkpoint file.

        Only the entries ``ctrl_dict`` and ``interface_dict`` are restored;
        the output layer (``hidden2output``) is not touched. When the file
        does not exist, an error is logged and nothing is loaded.

        :param filename: Path to the checkpoint file. The checkpoint is \
        expected to contain the keys ``ctrl_dict``, ``interface_dict``, \
        ``name`` and ``stats`` (the latter with ``episode`` and ``loss``).

        """
        # Check filename.
        if not os.path.isfile(filename):
            logger.error("Encoder checkpoint not found at {}".format(filename))
            return

        # NOTE(security): torch.load relies on pickle - only load
        # checkpoints that come from a trusted source.
        # The map_location callable returns the storage unchanged, i.e. the
        # tensors are not moved to the device they were saved from.
        chkpt = torch.load(
            filename, map_location=lambda storage, loc: storage)

        # Load controller and interface.
        self.controller.load_state_dict(chkpt['ctrl_dict'])
        self.interface.load_state_dict(chkpt['interface_dict'])

        logger.info(
            "Imported {} parameters from checkpoint (episode {}, loss {}) from file {}".format(
                chkpt['name'],
                chkpt['stats']['episode'],
                chkpt['stats']['loss'],
                filename))

    def save(self, model_dir, stat_obj, is_best_model, save_intermediate):
        """
        Saves the encoder (controller and interface) to checkpoint file(s).

        The weights of the output layer (``hidden2output``) are not part of
        the checkpoint.

        :param model_dir: Directory where the checkpoint(s) will be saved. \
        A trailing path separator is optional.

        :param stat_obj: Statistics object (collector or aggregator) that \
        contains the current loss and episode number (and other statistics). \
        It must support ``stat_obj['episode']`` and \
        ``stat_obj.export_to_checkpoint()``.

        :param is_best_model: Flag indicating whether it is the best model \
        or not; when set, ``encoder_best.pt`` is written.

        :param save_intermediate: Flag indicating whether intermediate \
        models should be saved or not; when set, \
        ``encoder_episode_<episode>.pt`` is written.

        """
        episode = stat_obj['episode']

        # Checkpoint to be saved.
        chkpt = {
            'name': 'MAE controller and interface',
            'ctrl_dict': self.controller.state_dict(),
            'interface_dict': self.interface.state_dict(),
            'stats': stat_obj.export_to_checkpoint()
        }

        # os.path.join (instead of plain string concatenation) keeps the
        # files inside model_dir even when it has no trailing separator.

        # Save the intermediate checkpoint.
        if save_intermediate:
            self._export_checkpoint(
                chkpt,
                os.path.join(
                    model_dir,
                    'encoder_episode_{:05d}.pt'.format(episode)))

        # Save the best model.
        if is_best_model:
            self._export_checkpoint(
                chkpt, os.path.join(model_dir, 'encoder_best.pt'))

    @staticmethod
    def _export_checkpoint(chkpt, filename):
        """
        Writes the checkpoint dictionary to a file and logs the fact.

        :param chkpt: Checkpoint dictionary to be saved.
        :param filename: Path of the target file.

        """
        torch.save(chkpt, filename)
        logger.info(
            "Encoder and statistics exported to checkpoint {}".format(
                filename))

    def freeze(self):
        """
        Freezes the trainable weights of the controller and the interface.
        """
        # Freeze controller.
        for param in self.controller.parameters():
            param.requires_grad = False
        logger.info("Encoder controller is frozen")

        # Freeze interface.
        self.interface.freeze()
        logger.info("Encoder interface is frozen")

        # NOTE: the output layer (hidden2output) is not frozen here, so its
        # parameters stay trainable.

    def init_state(self, init_memory_BxAxC):
        """
        Initializes state of MAE cell. Recursively initializes: controller,
        interface.

        :param init_memory_BxAxC: Initial memory state \
        [BATCH_SIZE x MEMORY_ADDRESSES x MEMORY_CONTENT].

        :returns: Initial state tuple - object of MAECellStateTuple class. \
        The passed memory is stored in it as is.

        """
        # Get batch size and number of memory addresses.
        batch_size = init_memory_BxAxC.size(0)
        num_memory_addresses = init_memory_BxAxC.size(1)

        # Initialize controller state.
        ctrl_init_state = self.controller.init_state(batch_size)

        # Initialize interface state.
        interface_init_state = self.interface.init_state(
            batch_size, num_memory_addresses)

        # Pack and return a tuple.
        return MAECellStateTuple(
            ctrl_init_state, interface_init_state, init_memory_BxAxC)

    def forward(self, inputs_BxI, prev_cell_state):
        """
        Forward function of the MAE cell.

        :param inputs_BxI: a Tensor of input data of size \
        [BATCH_SIZE x INPUT_SIZE].

        :param prev_cell_state: a MAECellStateTuple tuple, containing the \
        previous state of the cell.

        :returns: A pair: output logits (result of the output layer applied \
        to the controller output) and a MAECellStateTuple tuple containing \
        the current cell state.

        """
        # Unpack previous cell state.
        (prev_ctrl_state_tuple, prev_interface_state_tuple,
         prev_memory_BxAxC) = prev_cell_state

        # Execute controller forward step - the controller sees the raw
        # input only.
        ctrl_output_BxH, ctrl_state_tuple = self.controller(
            inputs_BxI, prev_ctrl_state_tuple)

        # Execute interface forward step.
        memory_BxAxC, interface_state_tuple = self.interface(
            ctrl_output_BxH, prev_memory_BxAxC, prev_interface_state_tuple)

        # Output layer - takes the controller output.
        logits_BxO = self.hidden2output(ctrl_output_BxH)

        # Pack current cell state.
        cell_state_tuple = MAECellStateTuple(
            ctrl_state_tuple, interface_state_tuple, memory_BxAxC)

        # Return logits and current cell state.
        return logits_BxO, cell_state_tuple