# Source code for miprometheus.models.thalnet.thalnet_cell

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
thalnet_cell: A ThalNetCell, composed of ThalNet modules. It processes a single sequence element (e.g. a word) per call.
"""
__author__ = "Younes Bouhadjar & Vincent Marois"

import torch
from torch.nn import Module
from miprometheus.models.thalnet.thalnet_module import ThalnetModule


class ThalNetCell(Module):
    """
    Implementation of the ``ThalNetCell``, iterating over one sequence element at a time.

    It is composed of several ``ThalnetModule`` instances which all read from a
    shared center, built by concatenating the center features of every module.
    """

    def __init__(self, input_size: int, output_size: int, context_input_size: int,
                 center_size_per_module: int, num_modules: int):
        """
        Constructor of the ``ThalNetCell`` class.

        :param input_size: size of the input sequences (fed to the first module only).
        :type input_size: int

        :param output_size: size of the produced output sequences (emitted by the last module).
        :type output_size: int

        :param context_input_size: context input size.
        :type context_input_size: int

        :param center_size_per_module: Size of the center slot allocated to each module.
        :type center_size_per_module: int

        :param num_modules: number of modules to constitute the cell (must be >= 1).
        :type num_modules: int

        :raises ValueError: if ``num_modules`` is smaller than 1.
        """
        # Call base class inits here.
        super(ThalNetCell, self).__init__()

        if num_modules < 1:
            # With no modules, init_state would return no states and forward
            # would have no center to concatenate.
            raise ValueError(
                "ThalNetCell requires at least one module, got num_modules={}".format(num_modules))

        self.context_input_size = context_input_size
        self.input_size = input_size
        self.output_size = output_size
        # The shared center holds one slot of center_size_per_module per module.
        self.center_size = num_modules * center_size_per_module
        self.center_size_per_module = center_size_per_module
        self.num_modules = num_modules

        # init module-center cell: module 0 receives the external input,
        # module num_modules - 1 (when distinct from module 0) produces the output.
        self.modules_thalnet = torch.nn.ModuleList()
        last_index = self.num_modules - 1
        for index in range(self.num_modules):
            self.modules_thalnet.append(
                ThalnetModule(
                    center_size=self.center_size,
                    context_size=self.context_input_size,
                    center_size_per_module=self.center_size_per_module,
                    input_size=self.input_size if index == 0 else 0,
                    # NOTE(review): with num_modules == 1 the single module gets
                    # output_size=0 (original layout preserved) -- confirm whether
                    # that configuration is meant to be supported.
                    output_size=self.output_size if (index > 0 and index == last_index) else 0))

    def init_state(self, batch_size):
        """
        Initialize the state of ``ThalNet``.

        :param batch_size: batch size
        :type batch_size: int

        :return: List of initial states, one per module, as returned by each
            module's ``init_state``.
        """
        # module and center state initialisation
        return [module.init_state(batch_size) for module in self.modules_thalnet]

    def forward(self, inputs, prev_state):
        """
        forward run of the ``ThalNetCell``.

        The center states of all modules are concatenated along dim 1, and every
        module is run on the same concatenated center together with its own
        previous controller state.

        :param inputs: inputs at time t, [batch_size, input_size]
        :type inputs: torch.tensor

        :param prev_state: previous states, one entry per module, where entry[0]
            is the module's center state and entry[1] its controller state
            (the format returned by ``init_state`` and by this method).
        :type prev_state: list

        :raises ValueError: if ``prev_state`` does not contain exactly one entry per module.

        :return: tuple ``(output, states)``:

            - output: output of the last module (expected [batch_size, output_size])
            - states: list of ``(center_feature, module_state)`` tuples, one per module
        """
        if len(prev_state) != self.num_modules:
            raise ValueError(
                "Expected {} module states, got {}".format(self.num_modules, len(prev_state)))

        # Concatenate all the centers: every module reads the same shared center.
        prev_center_states = torch.cat([module_state[0] for module_state in prev_state], dim=1)

        states = []
        # run the different modules, they share all the same center
        for module, module_state in zip(self.modules_thalnet, prev_state):
            output, center_feature, controller_state = module(
                inputs, prev_center_states, module_state[1])
            states.append((center_feature, controller_state))

        # Only the last module's output is returned as the cell's prediction.
        return output, states