Source code for miprometheus.helpers.index_splitter

    #!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
index_splitter.py:

    - Contains the definition of a ``Helper`` class, called :py:class:`IndexSplitter`.

"""
__author__ = "Tomasz Kornuta"

import os
import sys

from miprometheus.utils.split_indices import split_indices
from miprometheus.workers import Worker
from miprometheus.problems.problem_factory import ProblemFactory


class IndexSplitter(Worker):
    """
    Defines the :py:class:`IndexSplitter` class.

    This class allows to split the list of indices indexing a dataset into 2, non-overlapping,
    sub-lists of variable lengths. These 2 lists are then saved to files (named `split_a.txt`
    and `split_b.txt`) in the output directory.

    This can be useful to split a training set into a training set & a validation set.
    These files can later be used for training / validation or testing when using
    :py:class:`torch.utils.data.SubsetRandomSampler` (instantiated with the
    :py:class:`miprometheus.utils.SamplerFactory`).

    .. note::

        General usage:

            -- The user provides the output dir where the 2 files containing indices will be
            stored (`--outdir`, default: current directory).

            -- The user provides the problem name (`--problem`) OR the length of the dataset
            (`--length`) - exactly one of them.

            -- The user provides the split (`--split`), which represents how many samples will be
            contained in the first split (value from 1 to length-1; the two border values are the
            cases when one or the other split will contain a single index).

        Additionally, the user might disable random sampling with `--random_sampling_off`
        (by default random sampling is requested):

            -- when random sampling is on, both files will contain (exclusive) random lists of indices,

            -- when it is off, both files will contain ranges, i.e. `[0, split-1]` and
            `[split, length-1]` respectively.

    """

    def __init__(self, name="IndexSplitter"):
        """
        Set parser arguments.

        .. note::

            The base :py:class:`miprometheus.workers.Worker` constructor is called with
            ``add_default_parser_args=False``, so the parser only gets the arguments
            defined below (this worker does not follow the usual training/testing workflow).

        :param name: Name of the worker (Default: "IndexSplitter").
        :type name: str

        """
        # Call the base constructor without its default parser arguments.
        # NOTE(review): ``self.parser`` (used below) is presumably created by the base constructor.
        super(IndexSplitter, self).__init__(name=name, add_default_parser_args=False)

        # Add the arguments specific to this worker.
        self.parser.add_argument('--outdir',
                                 dest='outdir',
                                 type=str,
                                 default=".",
                                 help='Path to the output directory where the files with indices will be stored.'
                                      ' (DEFAULT: .)')

        self.parser.add_argument('--problem',
                                 dest='problem_name',
                                 type=str,
                                 default='',
                                 help='Name of the problem to be splitted. (WARNING: exclusive with --l)')

        self.parser.add_argument('--length',
                                 dest='length',
                                 type=int,
                                 default=-1,
                                 help='Length (size) of the dataset (WARNING: exclusive with --p)')

        self.parser.add_argument('--split',
                                 dest='split',
                                 type=int,
                                 default=-1,
                                 help='Value indicating number of indices/samples in the first set. '
                                      'Values from 1 to length-1 are accepted. The two border cases mean that '
                                      'one or the other split will contain a single index.')

        self.parser.add_argument('--random_sampling_off',
                                 dest='random_sampling_off',
                                 default=False,
                                 action='store_true',
                                 help='When set, random sampling is disabled and both files will contain ranges, '
                                      'i.e. [0, split-1] and [split, length-1] respectively. '
                                      'Otherwise both files will contain (exclusive) random lists of indices.')

    def run(self):
        """
        Creates two files with splits.

            - Parses command line arguments.
            - Loads the problem class (if required).
            - Generates two lists (or ranges) of exclusive indices.
            - Writes those lists to two separate files (`split_a.txt` and `split_b.txt`)
              located in the output directory.

        Terminates the process via ``sys.exit`` with a negative status when the arguments
        are missing, conflicting or out of range.

        """
        # Parse arguments.
        self.flags, self.unparsed = self.parser.parse_known_args()

        # Display results of parsing.
        self.display_parsing_results()

        # Expand '~' once, so that the created directory and the written files use the same path.
        self.out_dir = os.path.expanduser(self.flags.outdir)
        # Backward-compatible alias of the attribute name used by earlier versions.
        self.out_dit = self.out_dir

        # Check if we can estimate length.
        if self.flags.problem_name == '' and self.flags.length == -1:
            self.logger.error('Index splitter operates on length (size) of the problem, '
                              'please set problem (--p) or its length (--l).')
            sys.exit(-1)

        # Check if user pointed only one of them.
        if self.flags.problem_name != '' and self.flags.length != -1:
            self.logger.error('Flags problem (--p) and length (--l) are exclusive, please use only one of them.')
            sys.exit(-2)

        # Check if user set the split.
        if self.flags.split == -1:
            self.logger.error('Please set the split (--s).')
            sys.exit(-3)
        split = self.flags.split

        # Get the dataset length - from the built problem or directly from the flag.
        if self.flags.problem_name != '':
            self.params.add_default_params({'name': self.flags.problem_name})
            problem = ProblemFactory.build(self.params)
            length = len(problem)
        else:
            length = self.flags.length

        # The first split gets `split` indices and the second `length - split`, so both are
        # non-empty exactly when 1 <= split <= length - 1.
        if split < 1 or split > length - 1:
            self.logger.error("Split must lie within 1 to {} range, which are border cases "
                              "when one or the other split will contain a single index.".format(length - 1))
            sys.exit(-4)

        self.logger.info("Splitting dataset of length {} into splits of size {} and {}.".format(
            length, split, length - split))

        # Split the indices; random sampling is requested unless --random_sampling_off was given.
        split_a, split_b = split_indices(length, split, self.logger, not self.flags.random_sampling_off)

        # Create the output directory only once the arguments are known to be valid.
        os.makedirs(self.out_dir, exist_ok=True)

        # os.path.join inserts the path separator (e.g. '.' -> './split_a.txt').
        # NOTE(review): ``tofile`` assumes split_indices returns numpy arrays - confirm.
        name_a = os.path.join(self.out_dir, 'split_a.txt')
        split_a.tofile(name_a, sep=",", format="%s")

        name_b = os.path.join(self.out_dir, 'split_b.txt')
        split_b.tofile(name_b, sep=",", format="%s")

        self.logger.info("Splits written to {} ({} indices) and {} ({} indices).".format(
            name_a, len(split_a), name_b, len(split_b)))
def main():
    """
    Entry point function for the :py:class:`IndexSplitter`.

    Creates the worker with its default name and runs it, which parses the
    command line arguments and writes the two index splits.
    """
    IndexSplitter().run()


if __name__ == '__main__':
    main()