Source code for miprometheus.utils.statistics_collector

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
statistics_collector.py: contains the class used for the collection and export of statistics during training,\
 validation and testing.

 """
__author__ = "Tomasz Kornuta & Vincent Marois"

import os

try:
    # The ABCs live in ``collections.abc``; the ``collections.Mapping`` alias was removed in Python 3.10.
    from collections.abc import Mapping
except ImportError:  # Fallback for very old interpreters.
    from collections import Mapping


class StatisticsCollector(Mapping):
    """
    Specialized class used for the collection and export of statistics during
    training, validation and testing.

    Inherits :py:class:`collections.abc.Mapping`, therefore it offers functionality close to a ``dict``.
    Every key maps to a ``list`` of collected values; the export methods always use the most recent
    (last) value of each list.

    """

    def __init__(self):
        """
        Initialization - creates dictionaries for statistics and formatting.

        """
        super().__init__()

        # No "output streams" are attached by default.
        self.tb_writer = None
        self.csv_file = None

        # key -> list of collected values.
        self.statistics = dict()
        # key -> format string used when exporting the values of that key.
        self.formatting = dict()

    def add_statistic(self, key, formatting):
        """
        Add a statistic to the collector. The value associated with the key is of type ``list``.

        Calling this method for an already existing key resets its list of values.

        :param key: Key of the statistic.
        :type key: str

        :param formatting: Formatting that will be used when logging and exporting to CSV.

        """
        self.formatting[key] = formatting

        # Instantiate the associated value as an (empty) list.
        self.statistics[key] = list()

    def __getitem__(self, key):
        """
        Get the list of values collected for a given key.

        :param key: Key of the statistic.
        :type key: str

        :return: List of values associated with the given key.

        :raises KeyError: If the key was not registered with :py:func:`add_statistic`.

        """
        return self.statistics[key]

    def __setitem__(self, key, value):
        """
        Append a value to the list of the statistic associated with a given key.

        Note that, unlike for a ``dict``, the assignment does NOT overwrite previous values.

        :param key: Key of the statistic.

        :param value: Value to append to the list associated with the given key.

        :raises KeyError: If the key was not registered with :py:func:`add_statistic`.

        """
        self.statistics[key].append(value)

    def __delitem__(self, key):
        """
        Delete the specified key (together with its formatting).

        :param key: Key to be deleted.

        :raises KeyError: If the key is not present.

        """
        del self.statistics[key]
        # Drop the formatting as well, so that no stale entry is left behind.
        self.formatting.pop(key, None)

    def __len__(self):
        """
        Returns "length" of ``self.statistics`` (i.e. number of tracked statistics).

        """
        return len(self.statistics)

    def __iter__(self):
        """
        Iterator over the keys of the tracked statistics.

        """
        return iter(self.statistics)

    def empty(self):
        """
        Empty the lists associated with the keys of the current statistics collector.

        The keys themselves (and their formatting) are kept.

        """
        for values in self.statistics.values():
            # Clear in place, so that references to the lists held elsewhere stay valid.
            del values[:]

    def _format_last(self, key):
        """
        Format the most recent value of a statistic using its associated formatting.

        :param key: Key of the statistic.

        :return: Formatted last value, or an empty string when no value was collected yet.

        """
        values = self.statistics[key]
        if not values:
            return ''

        # Get formatting - using '{}' as default.
        return self.formatting.get(key, '{}').format(values[-1])

    def initialize_csv_file(self, log_dir, filename):
        """
        Method creates a new csv file and initializes it with a header produced on the basis of the
        names of the statistics.

        The opened stream is remembered and used by default by :py:func:`export_to_csv`.
        The caller is responsible for closing the returned stream.

        :param log_dir: Path to the directory in which the file will be created.
        :type log_dir: str

        :param filename: Filename to be created.
        :type filename: str

        :return: File stream opened for writing.

        """
        # Header: comma-separated names of the statistics.
        header_str = ",".join(self.statistics) + '\n'

        # Join the path properly, so that ``log_dir`` works with or without a trailing separator.
        # Open the file for writing, line-buffered (buffering=1).
        self.csv_file = open(os.path.join(log_dir, filename), 'w', 1)
        self.csv_file.write(header_str)

        return self.csv_file

    def export_to_csv(self, csv_file=None):
        """
        Method writes the current statistics to csv using the possessed formatting.

        A statistic without any collected value results in an empty field.

        :param csv_file: File stream opened for writing, optional (defaults to the stream opened by
            :py:func:`initialize_csv_file`).

        """
        # Try to use the remembered one.
        if csv_file is None:
            csv_file = self.csv_file
        # If it is still None - well, we cannot do anything more.
        if csv_file is None:
            return

        # One field per statistic, in the order of the keys.
        fields = [self._format_last(key) for key in self.statistics]
        csv_file.write(",".join(fields) + '\n')

    def export_to_checkpoint(self):
        """
        This method exports the collected data into a dictionary using the associated formatting.

        Statistics without any collected value are skipped.

        :return: Dictionary mapping each key to its formatted most recent value.

        """
        return {key: self._format_last(key) for key, values in self.statistics.items() if values}

    def export_to_string(self, additional_tag=''):
        """
        Method returns the current statistics in the form of a string using the possessed formatting.

        :param additional_tag: An additional tag to append at the end of the created string.
        :type additional_tag: str

        :return: String being the concatenation of the statistics names & values.

        """
        # "<key> <formatted value>" for each statistic (value left empty if nothing was collected).
        parts = [key + ' ' + self._format_last(key) for key in self.statistics]

        return "; ".join(parts) + " " + additional_tag

    def initialize_tensorboard(self, tb_writer):
        """
        Memorizes the writer that will be used with this collector.

        :param tb_writer: Writer used by default by :py:func:`export_to_tensorboard`.

        """
        self.tb_writer = tb_writer

    def export_to_tensorboard(self, tb_writer=None):
        """
        Method exports the current statistics to tensorboard, using the last value of the ``episode``
        statistic as the step.

        Nothing is exported when there is no writer or when no episode was collected yet.

        :param tb_writer: TensorBoard writer, optional (defaults to the memorized one).
        :type tb_writer: :py:class:`tensorboardX.SummaryWriter`

        :raises KeyError: If a writer is available but the ``episode`` statistic was never added.

        """
        # Try to use the remembered one.
        if tb_writer is None:
            tb_writer = self.tb_writer
        # If it is still None - well, we cannot do anything more.
        if tb_writer is None:
            return

        # Get episode number - only once we know that there is a writer to write to.
        episodes = self.statistics['episode']
        if not episodes:
            return
        episode = episodes[-1]

        for key, values in self.statistics.items():
            # Skip the episode itself and statistics without any value.
            if key == 'episode' or not values:
                continue
            tb_writer.add_scalar(key, values[-1], episode)
if __name__ == "__main__":
    # Small usage demo / smoke test of the collector. Writes ``collector_test.csv`` to the
    # current working directory and prints the collected statistics.
    stat_col = StatisticsCollector()
    stat_col.add_statistic('loss', '{:12.10f}')
    stat_col.add_statistic('episode', '{:06d}')
    stat_col.add_statistic('acc', '{:2.3f}')

    stat_col['episode'] = 0
    stat_col['loss'] = 0.7
    stat_col['acc'] = 100

    csv_file = stat_col.initialize_csv_file('./', 'collector_test.csv')
    try:
        stat_col.export_to_csv(csv_file)
        print(stat_col.export_to_string())

        stat_col['episode'] = 1
        stat_col['loss'] = 0.7
        stat_col['acc'] = 99.3

        # NOTE: this statistic is registered after the CSV header was written, so the second
        # row exported below has one more column than the header.
        stat_col.add_statistic('seq_length', '{:2.0f}')
        stat_col['seq_length'] = 5

        stat_col.export_to_csv(csv_file)
        print(stat_col.export_to_string('[Validation]'))
    finally:
        # Make sure the file is flushed and released even if one of the exports fails.
        csv_file.close()

    # Clearing the collector keeps the keys, but drops all collected values.
    stat_col.empty()

    for k in stat_col:
        print('key: {} - value {}:'.format(k, stat_col[k]))