Source code for miprometheus.workers.tester

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) IBM Corporation 2018
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.


"""
    - This file hosts a function which adds the specific arguments a tester will need.
    - Also defines the ``Tester()`` class.
"""

__author__ = "Vincent Marois, Tomasz Kornuta, Younes Bouhadjar"

import os
import torch
from time import sleep
from datetime import datetime

from miprometheus.workers.worker import Worker
from miprometheus.models.model_factory import ModelFactory
from miprometheus.problems.problem_factory import ProblemFactory
from miprometheus.utils.statistics_collector import StatisticsCollector
from miprometheus.utils.statistics_aggregator import StatisticsAggregator

class Tester(Worker):
    """
    Defines the basic ``Tester``.

    If defining another type of tester, it should subclass it.
    """

    def __init__(self, name="Tester"):
        """
        Calls the ``Worker`` constructor, adds some additional params to parser.

        :param name: Name of the worker (DEFAULT: "Tester").
        :type name: str
        """
        # Call base constructor to set up app state, registry and add default params.
        super(Tester, self).__init__(name)

        # Add the arguments that are specific to the basic ``Tester``.
        self.parser.add_argument('--visualize',
                                 action='store_true',
                                 dest='visualize',
                                 help='Activate dynamic visualization')

    def setup_global_experiment(self):
        """
        Sets up the global test experiment for the ``Tester``:

            - Parses the command line arguments (through the base ``setup_experiment()``),
            - Checks that a model file was indicated and exists on disk,
            - Selects the configuration file (the one passed by the user, or \
              ``training_configuration.yaml`` located in the experiment directory) and \
              checks that it exists,
            - Checks that a CUDA device is present when the GPU was requested,
            - Loads the configuration file(s) into ``self.params``.

        The process is terminated with a negative exit code (-1 to -4) when one of the checks fails.

        The rest of the experiment setup is done in :py:func:`setup_individual_experiment()` \
        to allow for multiple tests support.
        """
        # Call base method to parse all command line arguments and add default sections.
        super(Tester, self).setup_experiment()

        # Check if model is present.
        if self.flags.model == '':
            print('Please pass path to and name of the file containing model to be loaded as --m parameter')
            exit(-1)

        # Check if file with model exists.
        if not os.path.isfile(self.flags.model):
            print('Model file {} does not exist'.format(self.flags.model))
            exit(-2)

        # Extract path: the experiment directory is the PARENT of the directory holding the model file.
        self.abs_path, _ = os.path.split(os.path.dirname(os.path.abspath(self.flags.model)))

        # Check if config file was indicated by the user.
        if self.flags.config != '':
            config_file = self.flags.config
        else:
            # Use the "default one".
            config_file = self.abs_path + '/training_configuration.yaml'

        # Check if configuration file exists.
        if not os.path.isfile(config_file):
            print('Config file {} does not exist'.format(config_file))
            exit(-3)

        # Check the presence of the CUDA-compatible devices.
        if self.flags.use_gpu and (torch.cuda.device_count() == 0):
            self.logger.error("Cannot use GPU as there are no CUDA-compatible devices present in the system!")
            exit(-4)

        # Get the list of configurations which need to be loaded.
        configs_to_load = self.recurrent_config_parse(config_file, [])

        # Read the YAML files one by one - but in reverse order -> overwrite the first indicated config(s)
        self.recurrent_config_load(configs_to_load)

        # -> At this point, the Param Registry contains the configuration loaded (and overwritten) from several files.

    def setup_individual_experiment(self):
        """
        Setup individual test experiment in the case of multiple tests, or the main experiment in the case of \
        one test experiment.

            - Checks that the problem and model names are present in the loaded configuration,
            - Sets up the log directory path:

                >>> os.makedirs(self.log_dir, exist_ok=False)

            - Adds a file handler to the logger:

                >>> self.add_file_handler_to_logger(self.log_file)

            - Sets random seeds:

                >>> self.set_random_seeds(self.params['testing'], 'testing')

            - Creates the problem, sampler and dataloader:

                >>> self.build_problem_sampler_loader(self.params['testing'], 'testing')

            - Clamps ``max_test_episodes`` to the number of batches in the dataloader,
            - Creates the model, loads its parameters from the checkpoint and switches it to evaluation mode,
            - Exports the resulting configuration to ``testing_configuration.yaml``.

        The process is terminated with exit code -5 or -6 when the configuration is incomplete or the \
        checkpoint cannot be loaded.
        """
        # Get testing problem name.
        try:
            _ = self.params['testing']['problem']['name']
        except KeyError:
            print("Error: Couldn't retrieve the problem name from the 'testing' section in the loaded configuration")
            exit(-5)

        # Get model name.
        try:
            _ = self.params['model']['name']
        except KeyError:
            print("Error: Couldn't retrieve the model name from the loaded configuration")
            exit(-6)

        # Prepare output paths for logging
        while True:
            # Dirty fix: if log_dir already exists, wait for 1 second and try again
            # (the directory name has a one-second resolution).
            try:
                time_str = 'test_{0:%Y%m%d_%H%M%S}'.format(datetime.now())
                if self.flags.savetag != '':
                    time_str = time_str + "_" + self.flags.savetag
                # The trailing separator is kept: file names are appended to log_dir by plain concatenation.
                self.log_dir = self.abs_path + '/' + time_str + '/'
                os.makedirs(self.log_dir, exist_ok=False)
            except FileExistsError:
                sleep(1)
            else:
                break

        # Set log dir and add the handler for the logfile to the logger.
        self.log_file = self.log_dir + 'tester.log'
        self.add_file_handler_to_logger(self.log_file)

        # Set random seeds in the testing section.
        self.set_random_seeds(self.params['testing'], 'testing')

        # Check if CUDA is available, if yes turn it on.
        self.check_and_set_cuda(self.flags.use_gpu)

        ################# TESTING PROBLEM #################

        # Build test problem and dataloader.
        self.problem, self.sampler, self.dataloader = \
            self.build_problem_sampler_loader(self.params['testing'], 'testing')

        # check if the maximum number of episodes is specified, if not put a
        # default equal to the size of the dataset (divided by the batch size)
        # So that by default, we loop over the test set once.
        max_test_episodes = len(self.dataloader)

        self.params['testing']['problem'].add_default_params({'max_test_episodes': max_test_episodes})
        if self.params["testing"]["problem"]["max_test_episodes"] == -1:
            # Overwrite the config value!
            self.params['testing']['problem'].add_config_params({'max_test_episodes': max_test_episodes})

        # Warn if indicated number of episodes is larger than an epoch size:
        if self.params["testing"]["problem"]["max_test_episodes"] > max_test_episodes:
            self.logger.warning('Indicated maximum number of episodes is larger than one epoch, reducing it.')
            self.params['testing']['problem'].add_config_params({'max_test_episodes': max_test_episodes})

        self.logger.info("Setting the max number of episodes to: {}".format(
            self.params["testing"]["problem"]["max_test_episodes"]))

        ################# MODEL #################

        # Create model object.
        # NOTE(review): the factory call was truncated in the copy this was restored from;
        # ``ModelFactory.build`` is assumed - confirm the method name against model_factory.py.
        self.model = ModelFactory.build(self.params['model'], self.problem.default_values)

        # Load the pretrained model from checkpoint.
        model_name = self.flags.model
        try:
            # Load parameters from checkpoint.
            self.model.load(model_name)

        except KeyError:
            self.logger.error("File {} indicated in the command line (--m) seems not to be a valid model checkpoint".format(model_name))
            # NOTE(review): exit codes -5/-6 are also used above for the missing problem/model names.
            exit(-5)
        except Exception as e:
            self.logger.error(e)
            # Exit by following the logic: if user wanted to load the model but failed, then continuing the experiment makes no sense.
            exit(-6)

        # Turn on evaluation mode.
        self.model.eval()

        # Move the model to CUDA if applicable.
        if self.app_state.use_CUDA:
            self.model.cuda()

        # Log the model summary.
        # NOTE(review): the statement that logged the model summary is missing from the copy
        # this block was restored from - restore it from the upstream file.

        # Export and log configuration, optionally asking the user for confirmation.
        self.export_experiment_configuration(self.log_dir, "testing_configuration.yaml", self.flags.confirm)

    def initialize_statistics_collection(self):
        """
        Function initializes all statistics collectors and aggregators used by a given worker,
        creates output files etc.

        Creates ``testing_statistics.csv`` (per-episode statistics) and
        ``testing_set_agg_statistics.csv`` (aggregated statistics) in ``self.log_dir``.
        """
        # Create statistics collector for testing.
        self.testing_stat_col = StatisticsCollector()
        self.add_statistics(self.testing_stat_col)
        self.problem.add_statistics(self.testing_stat_col)
        self.model.add_statistics(self.testing_stat_col)
        # Create the csv file to store the testing statistics.
        self.testing_batch_stats_file = self.testing_stat_col.initialize_csv_file(
            self.log_dir, 'testing_statistics.csv')

        # Create statistics aggregator for testing.
        self.testing_stat_agg = StatisticsAggregator()
        self.add_aggregators(self.testing_stat_agg)
        self.problem.add_aggregators(self.testing_stat_agg)
        self.model.add_aggregators(self.testing_stat_agg)
        # Create the csv file to store the testing statistic aggregations.
        # Will contain a single row with aggregated statistics.
        self.testing_set_stats_file = self.testing_stat_agg.initialize_csv_file(
            self.log_dir, 'testing_set_agg_statistics.csv')

    def finalize_statistics_collection(self):
        """
        Finalizes statistics collection, closes all files etc.
        """
        # Close all files.
        self.testing_batch_stats_file.close()
        self.testing_set_stats_file.close()

    def run_experiment(self):
        """
        Main function of the ``Tester``: Test the loaded model over the test set.

        Iterates over the ``DataLoader`` for at most ``max_test_episodes`` episodes.

        The function does the following for each episode:

            - Forwards pass of the model,
            - Logs statistics & accumulates loss,
            - Activate visualization if set.

        ``SystemExit`` and ``KeyboardInterrupt`` are caught and logged; the statistics files are
        closed in every case.
        """
        # Initialize tensorboard and statistics collection.
        self.initialize_statistics_collection()

        # Everything below runs under try/finally so that the statistics files opened above
        # are closed even when the setup preceding the test loop fails.
        try:
            # Set visualization.
            self.app_state.visualize = self.flags.visualize

            # Get number of samples - depending whether using sampler or not.
            if self.params['testing']['dataloader']['drop_last']:
                # if we are supposed to drop the last (incomplete) batch.
                num_samples = len(self.dataloader) * \
                    self.params['testing']['problem']['batch_size']
            elif self.sampler is not None:
                num_samples = len(self.sampler)
            else:
                num_samples = len(self.problem)

            self.logger.info('Testing over the entire test set ({} samples in {} episodes)'.format(
                num_samples, len(self.dataloader)))

            # Run test
            with torch.no_grad():

                episode = 0
                for test_dict in self.dataloader:

                    if episode == self.params["testing"]["problem"]["max_test_episodes"]:
                        break

                    # Evaluate model on a given batch.
                    logits, _ = self.predict_evaluate_collect(self.model, self.problem,
                                                              test_dict, self.testing_stat_col, episode)

                    # Export to csv - at every step.
                    self.testing_stat_col.export_to_csv()

                    # Log to logger - at logging frequency.
                    if episode % self.flags.logging_interval == 0:
                        # NOTE(review): call restored from a truncated statement;
                        # ``export_to_string`` is assumed - confirm against StatisticsCollector.
                        self.logger.info(self.testing_stat_col.export_to_string('[Partial Test]'))

                    if self.app_state.visualize:
                        # Allow for preprocessing
                        test_dict, logits = self.problem.plot_preprocessing(test_dict, logits)

                        # Show plot.
                        # NOTE(review): the original comment mentions breaking when the user presses
                        # Quit, but the value returned by plot() is not used - confirm intent.
                        self.model.plot(test_dict, logits)

                    # move to next episode.
                    episode += 1

                self.logger.info('\n' + '='*80)
                self.logger.info('Test finished')

                # Export aggregated statistics.
                self.aggregate_and_export_statistics(self.model, self.problem,
                                                     self.testing_stat_col, self.testing_stat_agg,
                                                     episode, '[Full Test]')

        except SystemExit as e:
            # the testing did not end properly
            self.logger.error('Experiment interrupted because {}'.format(e))
        except KeyboardInterrupt:
            # the testing did not end properly
            self.logger.error('Experiment interrupted!')
        finally:
            # Finalize statistics collection.
            self.finalize_statistics_collection()

    def check_multi_tests(self):
        """
        Checks if multiple tests are indicated in the testing configuration section.

        .. note::

            If the user would like to run multiple tests, he can use the ``multi_tests`` key in the
            ``testing`` section to indicate the keys which associated values will be different for
            each test config.

            E.g.

            >>> # Problem parameters:
            >>> testing:
            >>>     problem:
            >>>         name: SortOfCLEVR
            >>>         batch_size: 64
            >>>         data_folder: '~/data/sort-of-clevr/'
            >>>         dataset_size: 10000
            >>>         split: 'test'
            >>>         img_size: 128
            >>>         regenerate: False
            >>>
            >>>     multi_tests: {batch_size: [64, 128], img_size: [128, 256]}

        .. warning::

            The following constraints apply:

                - Assume that the indicated varying values are **leafs** of the `testing` section
                - The number of indicated varying values per key is the same for all keys
                - The indicated order of the varying values will be respected, i.e.

                    >>> multi_tests: {batch_size: [64, 128], img_size: [128, 256]}

                  and

                    >>> multi_tests: {batch_size: [64, 128], img_size: [256, 128]}

                  will lead to different test configs.

                - At least one key has varying values

        On success, stores the number of tests in ``self.number_tests`` and the varying values in
        ``self.multi_tests_params``, and removes ``multi_tests`` from the ``testing`` section.

        :return: True if the constraints above are respected, else False \
        (also False when no ``multi_tests`` key is present).
        """
        # check first if the user wants multi-tests.
        # Only this lookup is guarded, so that a KeyError raised further down is not
        # silently reported as "no multiple tests requested".
        try:
            multi_tests_section = self.params['testing']['multi_tests']
        except KeyError:
            return False

        self.logger.info("Checking validity of the indicated values for the multiple tests")
        multi_tests_values = multi_tests_section.to_dict()

        # an empty section gives nothing to iterate over (and no list length to compare against).
        if not multi_tests_values:
            self.logger.error("The 'multi_tests' section does not indicate any key.")
            return False

        # check that every key is a leaf of the testing config section
        testing_leafs = list(self.params['testing'].leafs())
        for key in multi_tests_values:
            if key not in testing_leafs:
                self.logger.error("Did not find the indicated key '{}' in the leafs of the 'testing' "
                                  "config section.".format(key))
                return False

        # check that all indicated list of values have same length
        n_tests = len(next(iter(multi_tests_values.values())))
        if not all(len(x) == n_tests for x in multi_tests_values.values()):
            self.logger.error("Got varying number of elements for the indicated multiple tests values.")
            return False

        # store the number of tests to execute
        self.number_tests = n_tests

        # store the params (and the indicated values) to update
        self.multi_tests_params = multi_tests_values

        # delete them from the param registry
        self.params['testing'].del_config_params(key='multi_tests')
        self.logger.info('Found the following indicated values for multiple tests: {}.'.format(multi_tests_values))

        return True

    def update_config(self, test_index):
        """
        Update ``self.params['testing']`` using the list of values to change for the multiple tests.

        Requires ``self.multi_tests_params`` to be set, i.e. :py:func:`check_multi_tests` must have
        returned True beforehand.

        :param test_index: Current test experiment index.
        :type test_index: int

        :return: True.
        """
        # If this method is used, then self.number_tests & self.multi_tests_params should be instantiated
        new_params = {k: v[test_index] for k, v in self.multi_tests_params.items()}
        self.logger.warning("Updating the testing config with: {}".format(new_params))

        for leaf_key, new_value in new_params.items():
            self.params['testing'].set_leaf(leaf_key, new_value)

        self.logger.warning("Updated the testing configuration.")
        self.logger.info('\n' + '=' * 80 + '\n')

        return True
def main():
    """
    Entry point function for the ``Tester``.

    Runs a single test experiment or, when the ``testing`` section of the configuration
    indicates ``multi_tests``, one experiment per indicated set of values.
    """
    import logging

    tester = Tester()
    # parse args, load configuration and create all required objects.
    tester.setup_global_experiment()

    if tester.check_multi_tests():

        for test_index in range(tester.number_tests):
            tester.logger.info('\n' + '=' * 80 + '\n')
            tester.logger.info("Starting test #{}.".format(test_index+1))

            # update the testing problem config based on the current test index.
            tester.update_config(test_index)

            # finalize the experiment setup
            tester.setup_individual_experiment()

            # run the current experiment
            tester.run_experiment()

            # Remove (and close) the file handler(s), as one will be set again in the next
            # individual test. Selecting by type rather than by position avoids dropping an
            # unrelated handler and does not fail when the handler list is empty.
            # NOTE(review): assumes add_file_handler_to_logger() attaches a logging.FileHandler
            # (or a subclass) to tester.logger - confirm against the Worker base class.
            for handler in list(tester.logger.handlers):
                if isinstance(handler, logging.FileHandler):
                    tester.logger.removeHandler(handler)
                    handler.close()

    else:
        # finalize the experiment setup
        tester.setup_individual_experiment()

        # run the experiment
        tester.run_experiment()


if __name__ == '__main__':
    main()