#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
thalnet_module.py: defines a module in the ThalNet architecture"""
__author__ = "Younes Bouhadjar & Vincent Marois"
import logging

import torch
from torch.nn import Module

from miprometheus.utils.app_state import AppState
from miprometheus.models.controllers.controller_factory import ControllerFactory
class ThalnetModule(Module):
    """
    Implements a ``ThalNet`` module.

    A module reads a context vector from the previous center state through a
    weight-normalised linear layer, concatenates it with its own inputs (when
    ``input_size`` is non-zero), feeds the result to a controller and splits
    the controller output into the module output and the features written
    back to the center.
    """

    # Class-level logger so that ``self.logger`` used in ``forward`` resolves
    # even when the instance was created without any logger being attached.
    logger = logging.getLogger(__name__)

    def __init__(self,
                 center_size,
                 context_size,
                 center_size_per_module,
                 input_size,
                 output_size):
        """
        Constructor of the ``ThalnetModule``.

        :param center_size: Size of the center of the model.
        :type center_size: int

        :param context_size: Size of the context vector produced by the \
        reading mechanism (output size of the linear layer applied to the \
        center state).
        :type context_size: int

        :param center_size_per_module: Size of the center slot allocated to \
        each module.
        :type center_size_per_module: int

        :param input_size: size of the input sequences
        :type input_size: int

        :param output_size: size of the produced output sequences
        :type output_size: int

        """
        # call base constructor
        super().__init__()

        self.center_size = center_size
        self.context_size = context_size
        self.center_size_per_module = center_size_per_module
        self.output_size = output_size
        self.input_size = input_size

        # Reading mechanism: maps the full center state to the context vector.
        self.fc_context = torch.nn.utils.weight_norm(
            torch.nn.Linear(self.center_size, self.context_size),
            name='weight')

        # Parameters needed for the controller.
        # The controller consumes [inputs ; context] ...
        self.input_context_size = self.input_size + self.context_size
        # ... and emits [output ; center features], which ``forward`` splits.
        self.controller_hidden_size = (
            self.output_size + self.center_size_per_module)
        self.controller_type = 'FFGRUController'
        self.non_linearity = ''

        # Set module
        controller_params = {
            "name": self.controller_type,
            "input_size": self.input_context_size,
            "output_size": self.controller_hidden_size,
            "num_layers": 1,
            "non_linearity": self.non_linearity,
            "ff_output_size": center_size_per_module,
        }
        self.controller = ControllerFactory.build(controller_params)

    def init_state(self, batch_size):
        """
        Initialize the state of a ``ThalNet`` module.

        :param batch_size: batch size
        :type batch_size: int

        :return: center_state_per_module, tuple_controller_states: a tensor \
        of shape ``[batch_size, center_size_per_module]`` drawn from a \
        standard normal distribution and cast to ``AppState().dtype``, and \
        the initial state returned by the controller.

        """
        dtype = AppState().dtype

        # module state initialisation
        tuple_controller_states = self.controller.init_state(batch_size)

        # center state initialisation
        center_state_per_module = torch.randn(
            (batch_size, self.center_size_per_module)).type(dtype)

        return center_state_per_module, tuple_controller_states

    def forward(self, inputs, prev_center_state, prev_tuple_controller_state):
        """
        Forward pass of a ``ThalnetModule``.

        :param inputs: input sequences, either 2-D or 3-D (in which case the \
        first entry along dimension 1 is selected), or ``None``.
        :type inputs: torch.tensor

        :param prev_center_state: previous center state
        :type prev_center_state: torch.tensor

        :param prev_tuple_controller_state: previous tuple controller state
        :type prev_tuple_controller_state: tuple

        :return: output, center_feature_output, tuple_ctrl_state. ``output`` \
        is ``None`` when ``output_size`` is zero.

        :raises SystemExit: if ``inputs`` is given and is neither 2-D nor 3-D.

        """
        if inputs is not None:
            num_dims = inputs.dim()

            if num_dims not in (2, 3):
                message = 'The input size is not the one expected.'
                self.logger.error(message)
                # SystemExit is kept so existing callers see the same
                # exception type as before.
                raise SystemExit(message)

            if num_dims == 3:
                # inputs_size : [batch_size, num_channel, input_size]
                # select channel
                inputs = inputs[:, 0, :]

        # get the context_input and the inputs of the module
        context_input = self.fc_context(prev_center_state)

        # A module without own inputs (input_size == 0) only sees the context.
        if self.input_size:
            inputs = torch.cat((inputs, context_input), dim=1)
        else:
            inputs = context_input

        # Apply the controller
        module_state, tuple_ctrl_state = self.controller(
            inputs, prev_tuple_controller_state)

        # Split the controller output into the module output and the part
        # written to the center; with no output the whole state goes to the
        # center.
        if self.output_size:
            output, center_feature_output = torch.split(
                module_state,
                [self.output_size, self.center_size_per_module],
                dim=1)
        else:
            output, center_feature_output = None, module_state

        return output, center_feature_output, tuple_ctrl_state