Source code for miprometheus.models.controllers.controller_factory

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
controller_factory.py: Factory building controllers for MANNs.

"""
__author__ = "Ryan L. McAvoy & Vincent Marois"
import logging
import inspect
from torch.nn import Module

import miprometheus.models.controllers


class ControllerFactory(object):
    """
    Factory returning a concrete controller, selected by the class name
    provided in the list of parameters.
    """

    @staticmethod
    def build(params):
        """
        Static method returning a particular controller, depending on the
        name provided in the list of parameters.

        :param params: Parameters used to instantiate the controller. Must
            support the ``in`` operator and ``params['name']`` lookup.
        :type params: ``utils.param_interface.ParamInterface``

        .. note::

            ``params`` must contain the exact (case-sensitive) class name of
            the controller to instantiate, under the key ``'name'``.

        :return: Instance of the requested controller, constructed as
            ``controller_class(params)``.

        :raises SystemExit: With exit code ``-1`` when the ``'name'`` key is
            missing, when no attribute of that name exists in the
            ``miprometheus.models.controllers`` package, or when that
            attribute is not a class derived from ``torch.nn.Module``.
        """
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger('ControllerFactory')

        # Check presence of the name attribute.
        if 'name' not in params:
            logger.error("Controller configuration section does not contain the key 'name'")
            # Raise SystemExit directly: the ``exit`` builtin is injected by the
            # ``site`` module for interactive use and is not always available.
            raise SystemExit(-1)

        # Get the class name.
        name = params['name']

        # Verify that the specified class is in the controller package.
        controllers_package = miprometheus.models.controllers
        if name not in dir(controllers_package):
            logger.error("Could not find the specified class '{}' in the controllers package.".format(name))
            raise SystemExit(-1)

        # Get the actual object exported under that name.
        controller_class = getattr(controllers_package, name)

        # Check that it is a class derived (even indirectly) from nn.Module.
        # ``inspect.isclass`` guards against non-class attributes (submodules,
        # functions), and ``issubclass`` compares actual types rather than
        # class names, so an unrelated class called ``Module`` is rejected.
        if not (inspect.isclass(controller_class) and issubclass(controller_class, Module)):
            logger.error("The specified class '{}' is not derived from the nn.Module class".format(name))
            raise SystemExit(-1)

        # Ok, proceed.
        logger.info('Loading the {} controller from {}'.format(name, controller_class.__module__))

        # Return the instantiated controller class.
        return controller_class(params)
if __name__ == "__main__":
    # Smoke test for ControllerFactory: build an 'RNNController' from a set of
    # default parameters and print the type of the resulting object.
    from miprometheus.utils.param_interface import ParamInterface

    controller_params = ParamInterface()
    controller_params.add_default_params({'name': 'RNNController',
                                          'input_size': 11,
                                          'output_size': 11,
                                          'hidden_state_size': 20,
                                          'num_layers': 1,
                                          'non_linearity': 'sigmoid'})

    # The factory exposes a single static method named ``build``; the previous
    # call to ``build_controller`` referenced a method that does not exist.
    controller = ControllerFactory.build(controller_params)
    print(type(controller))