# Source code for miprometheus.workers.offline_trainer

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) IBM Corporation 2018
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
    - This file contains the implementation of the ``OfflineTrainer``, which inherits from ``Trainer``. \
    The ``OfflineTrainer`` is based on epochs.
"""

__author__ = "Vincent Marois, Tomasz Kornuta"

import sys

import numpy as np
import torch

from miprometheus.workers.trainer import Trainer

class OfflineTrainer(Trainer):
    """
    Implementation for the epoch-based ``OfflineTrainer``.

    .. note::

        The default ``OfflineTrainer`` is based on epochs. \
        An epoch is defined as passing through all samples of a finite-size dataset. \
        The ``OfflineTrainer`` allows to loop over all samples from the training set many times i.e. in many epochs. \
        When an epoch finishes, it performs a similar step for the validation set and collects the statistics.

    """

    def __init__(self, name="OfflineTrainer"):
        """
        Only calls the ``Trainer`` constructor as the initialization phase is identical to the ``Trainer``.

        :param name: Name of the worker (DEFAULT: "OfflineTrainer").
        :type name: str

        """
        # Call base constructor to set up app state, registry and add default params.
        super(OfflineTrainer, self).__init__(name)

    def setup_experiment(self):
        """
        Sets up an experiment for the ``OfflineTrainer``:

            - Calls base class setup_experiment to parse the command line arguments,
            - Sets up the terminal conditions (loss threshold, episodes (optional) & epochs limits).

        Calls ``sys.exit(-5)`` (raises ``SystemExit``) when the epoch limit is not positive.

        """
        # Call base method to parse all command line arguments, load configuration, create problems and model etc.
        super(OfflineTrainer, self).setup_experiment()

        ################# TERMINAL CONDITIONS #################
        self.logger.info('Terminal conditions:\n' + '='*80)

        # Terminal condition I: loss.
        self.params['training']['terminal_conditions'].add_default_params({'loss_stop': 1e-5})
        self.loss_stop = self.params['training']['terminal_conditions']['loss_stop']
        self.logger.info("Setting Loss Stop threshold to {}".format(self.loss_stop))

        # In this trainer, Partial Validation is optional.
        self.params['validation'].add_default_params({'partial_validation_interval': -1})
        self.partial_validation_interval = self.params['validation']['partial_validation_interval']
        if self.partial_validation_interval <= 0:
            self.logger.info("Partial Validation deactivated")
        else:
            self.logger.info("Partial Validation activated with interval equal to {} episodes".format(
                self.partial_validation_interval))

        # Terminal condition II: max epochs. Mandatory.
        self.params["training"]["terminal_conditions"].add_default_params({'epoch_limit': 10})
        self.epoch_limit = self.params["training"]["terminal_conditions"]["epoch_limit"]
        if self.epoch_limit <= 0:
            self.logger.error("OffLine Trainer relies on epochs, thus Epoch Limit must be a positive number!")
            # sys.exit instead of the bare exit() builtin, which is only injected by the `site` module.
            sys.exit(-5)
        else:
            self.logger.info("Setting the Epoch Limit to: {}".format(self.epoch_limit))

        # Calculate the epoch size in terms of episodes.
        self.epoch_size = len(self.training_dataloader)
        self.logger.info('Epoch size in terms of training episodes: {}'.format(self.epoch_size))

        # Terminal condition III: max episodes. Optional.
        self.params["training"]["terminal_conditions"].add_default_params({'episode_limit': -1})
        self.episode_limit = self.params['training']['terminal_conditions']['episode_limit']
        if self.episode_limit < 0:
            self.logger.info("Termination based on Episode Limit is disabled")
            # Set to infinity (np.inf: the np.Inf alias was removed in NumPy 2.0).
            self.episode_limit = np.inf
        else:
            self.logger.info("Setting the Episode Limit to: {}".format(self.episode_limit))

        self.logger.info('\n' + '='*80)

        # Export and log configuration, optionally asking the user for confirmation.
        self.export_experiment_configuration(self.log_dir, "training_configuration.yaml", self.flags.confirm)

    def _set_visualization(self, low, high):
        """
        Turns visualization on when the ``visualize`` flag lies within [``low``, ``high``], off otherwise.

        :param low: Lowest visualization level (inclusive) enabling visualization.
        :param high: Highest visualization level (inclusive) enabling visualization.

        """
        self.app_state.visualize = low <= self.flags.visualize <= high

    def _export_parameter_histograms(self, episode):
        """
        Exports histograms of the model parameters (``tensorboard`` flag >= 1) and of their gradients \
        (``tensorboard`` flag >= 2) to the training TensorBoard writer.

        Export failures are logged and do not interrupt training.

        :param episode: Index of the current episode (used as the TensorBoard step).
        :type episode: int

        """
        if self.flags.tensorboard >= 1:
            for name, param in self.model.named_parameters():
                try:
                    self.training_batch_writer.add_histogram(name, param.data.cpu().numpy(), episode, bins='doane')
                except Exception as e:
                    self.logger.error(" {} :: data :: {}".format(name, e))

        if self.flags.tensorboard >= 2:
            for name, param in self.model.named_parameters():
                if param.grad is None:
                    # No gradient was computed for this parameter (e.g. it is frozen or unused): nothing to export.
                    continue
                try:
                    self.training_batch_writer.add_histogram(name + '/grad', param.grad.data.cpu().numpy(), episode,
                                                             bins='doane')
                except Exception as e:
                    self.logger.error(" {} :: grad :: {}".format(name, e))

    def _save_model(self, training_status):
        """
        Asks the model to save itself to ``self.model_dir``, passing the current training status and the \
        aggregated training & validation statistics (per the original comments, saving is driven by the \
        average validation loss).

        :param training_status: Human-readable training status stored with the checkpoint.
        :type training_status: str

        """
        self.model.save(self.model_dir, training_status, self.training_stat_agg, self.validation_stat_agg)

    def run_experiment(self):
        """
        Main function of the ``Trainer``.

        Iterates over the number of epochs of the training set.

        .. note::

            Because of the export of stats, weights and gradients to TensorBoard, we need to \
            keep track of the current episode index from the start of the training, even \
            though the Worker runs on epoch.

        .. warning::

            The test for terminal conditions (e.g. convergence) is done at the end of each epoch. \
            The terminal conditions are as follows:

                - I. The loss is below the specified threshold (using the full validation loss),
                - TODO: II. Early stopping is set and the full validation loss did not change by delta \
                  for the indicated number of epochs,
                - III. The maximum number of epochs has been met,
                - IV. The maximum number of episodes has been met (optional).

            Besides, the user can always stop experiment by pressing 'Stop experiment' during visualization.

        The function does the following for each epoch:

            - Executes the ``initialize_epoch()`` & ``finish_epoch()`` function of the ``Problem`` class,
            - For each episode:

                    - Resets the gradients,
                    - Forwards pass of the model,
                    - Logs statistics and exports to TensorBoard (if set),
                    - Computes gradients and update weights,
                    - Activates visualization if set (vis. level 0),
                    - Validates the model on a batch according to the validation frequency.

            - At the end of epoch:

                    - Handles curriculum learning (if set),
                    - Validates the model on the full validation set, logs the statistics \
                      and visualizes on a random batch if set (vis. level 1 or 2)
                    - Checks the above terminal conditions.

        The last validation on the full set is done additionally at the end of training, \
        with optional visualization of a random batch if set (vis. level 3). The final training \
        status reports "Epoch Limit reached" only when all epochs ran without another terminal \
        condition firing.

        """
        # Initialize TensorBoard and statistics collection.
        self.initialize_statistics_collection()
        self.initialize_tensorboard()

        try:
            '''
            Main training and validation loop.
            '''
            # Reset the counter.
            episode = -1

            # Set initial status.
            training_status = "Not Converged"

            # Becomes True only if the epoch loop runs to completion (no other terminal condition broke out).
            epoch_limit_reached = False

            # Read the optional 'gradient_clipping' parameter once - it is not changed during training.
            try:
                gradient_clipping = self.params['training']['gradient_clipping']
            except KeyError:
                gradient_clipping = None

            # Iterate over epochs.
            for epoch in range(self.epoch_limit):

                self.logger.info('Starting next epoch: {}'.format(epoch))

                # Inform the training problem class that epoch has started.
                self.training_problem.initialize_epoch(epoch)

                # Empty the statistics collector.
                self.training_stat_col.empty()

                # Exhaust training set.
                for training_dict in self.training_dataloader:

                    # "Move on" to the next episode.
                    episode += 1

                    # Reset all gradients.
                    self.optimizer.zero_grad()

                    # Visualization during training & validation episodes (vis. level 0 or 1).
                    self._set_visualization(0, 1)

                    # Turn on training mode for the model.
                    self.model.train()

                    # 1. Perform forward step, get predictions and compute loss.
                    logits, loss = self.predict_evaluate_collect(self.model, self.training_problem, training_dict,
                                                                 self.training_stat_col, episode, epoch)

                    # 2. Backward gradient flow.
                    loss.backward()

                    # If present - clip gradients to a range (-gradient_clipping, gradient_clipping).
                    if gradient_clipping is not None:
                        torch.nn.utils.clip_grad_value_(self.model.parameters(), gradient_clipping)

                    # 3. Perform optimization.
                    self.optimizer.step()

                    # 4. Log collected statistics.

                    # 4.1. Export to csv - at every step.
                    self.training_stat_col.export_to_csv()

                    # 4.2. Export data to tensorboard - at logging frequency.
                    if (self.training_batch_writer is not None) and (episode % self.flags.logging_interval == 0):
                        self.training_stat_col.export_to_tensorboard()
                        self._export_parameter_histograms(episode)

                    # 4.3. Log to logger - at logging frequency.
                    if episode % self.flags.logging_interval == 0:
                        self.logger.info(self.training_stat_col.export_to_string())

                    # 5. Check visualization of training data.
                    if self.app_state.visualize:
                        # Allow for preprocessing.
                        training_dict, logits = self.training_problem.plot_preprocessing(training_dict, logits)
                        # Show plot, if user will press Stop then a SystemExit exception will be thrown.
                        self.model.plot(training_dict, logits)

                    # 6. Partial validation on a single batch.
                    if (self.partial_validation_interval > 0) and (episode % self.partial_validation_interval) == 0:
                        self._set_visualization(1, 2)
                        self.validate_on_batch(self.validation_batch, episode, epoch)
                        # Statistics are not aggregated here, and the model is not saved:
                        # OfflineTrainer uses the full validation set to determine whether to save or not.

                    # III. The episodes number limit has been reached.
                    if episode + 1 >= self.episode_limit:
                        training_status = "Not converged: Episode Limit reached"
                        break  # the inner loop.

                # Epoch just ended! (or episode limit).
                # Inform the problem class that the epoch has ended.
                self.training_problem.finalize_epoch(epoch)

                # Aggregate training statistics for the epoch.
                self.aggregate_and_export_statistics(self.model, self.training_problem, self.training_stat_col,
                                                     self.training_stat_agg, episode, '[Epoch {}]'.format(epoch))

                # Apply curriculum learning - change some of the Problem parameters.
                self.curric_done = self.training_problem.curriculum_learning_update_params(episode)

                # Perform full validation (vis. level 1 or 2).
                self._set_visualization(1, 2)
                self.validate_on_set(episode, epoch)

                # Save the model using the average validation loss.
                self._save_model(training_status)

                # Terminal conditions.
                # I - the loss is < threshold (only when curriculum learning is finished if set.)
                # We check that condition only in validation step!
                if self.curric_done or not self.must_finish_curriculum:
                    # Check the Full Validation loss.
                    if self.validation_stat_agg["loss"] < self.loss_stop:
                        # Change the status...
                        training_status = "Converged (Full Validation Loss went below Loss Stop threshold)"
                        # ... and THEN try to save the model using the average validation loss.
                        self._save_model(training_status)
                        break

                # II. Early stopping is set and loss hasn't improved by delta in n epochs.
                # early_stopping(index=epoch, avg_valid_loss). (TODO: coming in next release)
                # training_status = 'Early Stopping.'

                # III. The episodes number limit has been reached. (2nd check)
                if episode + 1 >= self.episode_limit:
                    break  # the outer loop.
            else:
                # IV. All epochs ran and no other terminal condition broke out of the loop.
                epoch_limit_reached = True

            '''
            End of main training and validation loop. Perform final full validation.
            '''
            if epoch_limit_reached:
                # Only now the status can honestly say the epoch limit ended the training.
                training_status = "Not converged: Epoch Limit reached"

            # Display status.
            self.logger.info('\n' + '='*80)
            self.logger.info('Training finished because {}'.format(training_status))

            # Turn on visualization for the last validation if needed (vis. level 2 or 3).
            self._set_visualization(2, 3)

            # Validate over the entire validation set.
            self.validate_on_set(episode, epoch)

            # Try to save the model only if we hit the epoch limit.
            if epoch_limit_reached:
                self._save_model(training_status)

            self.logger.info('Experiment finished!')

        except SystemExit as e:
            # the training did not end properly
            self.logger.error('Experiment interrupted because {}'.format(e))
        except KeyboardInterrupt:
            # the training did not end properly
            self.logger.error('Experiment interrupted!')
        finally:
            # Finalize statistics collection.
            self.finalize_statistics_collection()
            self.finalize_tensorboard()
def main():
    """
    Command-line entry point: builds an ``OfflineTrainer``, prepares the experiment and runs it.
    """
    offline_trainer = OfflineTrainer()

    # Parse the command line, load the configuration and create problems, model, etc.
    offline_trainer.setup_experiment()

    # Start training.
    offline_trainer.run_experiment()


if __name__ == '__main__':
    main()