#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
sampler_factory.py: Factory building samplers used by PyTorch's DataLoaders.

"""
__author__ = "Tomasz Kornuta"

import os
import logging
import torch.utils.data.sampler


class SamplerFactory(object):
    """
    Class returning sampler depending on the name provided in the \
    list of parameters.
    """

    @staticmethod
    def build(problem, params):
        """
        Static method returning a particular sampler, depending on the name \
        provided in the list of parameters & the specified problem class.

        :param problem: Instance of an object derived from the Problem class.
        :type problem: ``problems.Problem``

        :param params: Parameters used to instantiate the sampler.
        :type params: ``utils.param_interface.ParamInterface``

        .. note::

            ``params`` should contain the exact (case-sensitive) class name of the sampler to instantiate.

        .. warning::

            ``torch.utils.data.sampler.WeightedRandomSampler``, ``torch.utils.data.sampler.BatchSampler`` and \
            ``torch.utils.data.sampler.DistributedSampler`` are not yet supported.

        .. note::

            ``torch.utils.data.sampler.SubsetRandomSampler`` expects indices to index a subset of the dataset. \
            Currently, the user can specify these indices using one of the following options:

            - Option 1: range.

                >>> indices = range(20)

            - Option 2: range as str.

                >>> range_str = '0, 20'

            - Option 3: list of indices.

                >>> yaml_list = yaml.safe_load('[0, 2, 5, 10]')

            - Option 4: name of the file containing indices.

                >>> filename = "~/data/mnist/training_indices.txt"

        :return: Instance of a given sampler or ``None`` if the section is not present or the sampler could not be built.

        """
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger('SamplerFactory')

        # Check if a sampler is required, i.e. whether the 'sampler' section is empty.
        if not params:
            logger.info('The sampler configuration section is not present.')
            return None

        try:
            # Check the presence of the name attribute.
            if 'name' not in params:
                raise Exception("The sampler configuration section does not contain the key 'name'.")

            # Get the class name.
            name = params['name']

            # Verify that the specified class is in the samplers package.
            if name not in dir(torch.utils.data.sampler):
                raise Exception("Could not find the specified class '{}' in the samplers package.".format(name))

            # Get the actual class.
            sampler_class = getattr(torch.utils.data.sampler, name)

            # Ok, proceed.
            logger.info('Loading the {} sampler from {}'.format(name, sampler_class.__module__))

            # Handle the "special" case.
            if sampler_class.__name__ == 'SubsetRandomSampler':

                # Check the presence of the indices attribute.
                if 'indices' not in params:
                    raise Exception("The sampler configuration section does not contain the key 'indices' "
                                    "required by SubsetRandomSampler.")

                indices = params['indices']

                # Analyze the type.
                if isinstance(indices, str):
                    # Try to open the file.
                    try:
                        # From expanduser()'s doc: if the expansion fails or if the path does not begin
                        # with a tilde, the path is returned unchanged -> so the operation below is safe.
                        file = open(os.path.expanduser(indices), "r")
                        # Read the file.
                        indices = file.readline()
                        file.close()

                    except Exception:
                        # Ok, this is not a file.
                        pass

                    finally:
                        # Try to process it as a string:
                        # get the digits and turn them into a list of ints.
                        digits = indices.split(',')
                        indices = [int(x) for x in digits]
                else:
                    # Assume that indices already is a list of ints.
                    digits = indices

                # Finally, we got the list of digits.
                if len(digits) == 2:
                    # Create a range.
                    indices = range(int(digits[0]), int(digits[1]))
                # Else: use them as they are.

                # Check if the indices are within range.
                if max(indices) >= len(problem):
                    logger.error("SubsetRandomSampler cannot work properly when indices are out of range ({}) "
                                 "considering that there are {} samples in the problem!".format(max(indices),
                                                                                                len(problem)))
                    exit(-1)

                # Create the sampler object.
                sampler = sampler_class(indices)

            elif sampler_class.__name__ in ['WeightedRandomSampler', 'BatchSampler', 'DistributedSampler']:
                # Sorry, those are not supported. Yet ;)
                logger.error("Sampler Factory currently does not support the {} sampler. Please pick one of the "
                             "others or use the default random sampling.".format(sampler_class.__name__))
                exit(-2)

            else:
                # Create a "regular" sampler.
                sampler = sampler_class(problem)

            # Return the sampler.
            return sampler

        except Exception as e:
            logger.error(e)
            logger.warning("Using default sampling without a sampler.")
            return None
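
# For orientation, a minimal sketch of the configuration section this factory
# consumes, written here as YAML for illustration (the values below are
# illustrative only; the key names mirror the checks in build() above):
#
#   sampler:
#     name: SubsetRandomSampler
#     indices: '0, 20'
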
if __name__ == "__main__":
    """ Tests the factory. """
    from miprometheus.utils.param_interface import ParamInterface
    import yaml

    # Problem.
    class TestProblem(object):
        def __len__(self):
            return 50

    # All samplers operate on TestProblem only,
    # whereas SubsetRandomSampler additionally accepts 'indices' with the following options:

    # Option 1: range.
    indices = range(20)

    # Option 2: range as str.
    range_str = '0, 20'

    # Option 3: list of indices.
    yaml_list = yaml.safe_load('[0, 2, 5, 10]')

    # Option 4: name of the file containing indices.
    filename = "~/data/mnist/training_indices.txt"

    params = ParamInterface()
    params.add_default_params({'name': 'SubsetRandomSampler',
                               'indices': yaml_list})

    sampler = SamplerFactory.build(TestProblem(), params)
    print(type(sampler))

    for i, index in enumerate(sampler):
        print('{}: index {}'.format(i, index))
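
    # A small extension of the test above (a hedged sketch, reusing the same
    # ParamInterface calls): exercise the range-as-str option and a "regular"
    # sampler that needs no 'indices' key.

    # Option 2 in action: SubsetRandomSampler built from a range given as a string.
    str_params = ParamInterface()
    str_params.add_default_params({'name': 'SubsetRandomSampler',
                                   'indices': range_str})
    str_sampler = SamplerFactory.build(TestProblem(), str_params)
    # Should visit exactly the indices 0..19, in random order.
    print(sorted(str_sampler))

    # A "regular" sampler: RandomSampler only needs the problem (its __len__).
    random_params = ParamInterface()
    random_params.add_default_params({'name': 'RandomSampler'})
    random_sampler = SamplerFactory.build(TestProblem(), random_params)
    # Should yield a random permutation of all 50 sample indices.
    print(len(list(random_sampler)))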