#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
cifar10.py: contains code for loading the `CIFAR10` dataset using ``torchvision``.
"""
__author__ = "Younes Bouhadjar & Vincent Marois"

import os
import torch
import numpy as np
from torchvision import datasets, transforms


from miprometheus.utils.data_dict import DataDict
from miprometheus.problems.image_to_class.image_to_class_problem import ImageToClassProblem


class CIFAR10(ImageToClassProblem):
    """
    Image classification problem using the CIFAR-10 dataset.

    Please see reference here: https://www.cs.toronto.edu/~kriz/cifar.html

    .. warning::

        The original dataset is only split into a training set and a test set; there is no \
        validation set. It is recommended to use a validation set, e.g. carved out of the training \
        set with ``torch.utils.data.SubsetRandomSampler`` (see the sketch after this class definition).

    """
    def __init__(self, params):
        """
        Initializes the CIFAR-10 problem:

            - Calls ``problems.problem.ImageToClassProblem`` class constructor,
            - Sets the following attributes using the provided ``params``:

                - ``self.data_folder`` (`string`) : Root directory of dataset where the directory \
                ``cifar-10-batches-py`` will be saved,
                - ``self.use_train_data`` (`bool`, `optional`) : If ``True``, creates dataset from training set, \
                otherwise creates from test set,
                - ``self.resize`` : (optional) resize the images to `[h, w]` if set,
                - ``self.default_values`` :

                    >>> self.default_values = {'num_classes': 10,
                    >>>                        'num_channels': 3,
                    >>>                        'width': self.width,   # (DEFAULT: 32)
                    >>>                        'height': self.height} # (DEFAULT: 32)

                - ``self.data_definitions`` :

                    >>> self.data_definitions = {'images': {'size': [-1, 3, self.height, self.width], 'type': [torch.Tensor]},
                    >>>                          'targets': {'size': [-1], 'type': [torch.Tensor]},
                    >>>                          'targets_label': {'size': [-1, 1], 'type': [list, str]}
                    >>>                          }

        .. warning::

            Resizing images might cause a significant slow down in batch generation.

        .. note::

            The following is set by default:

            >>> params = {'data_folder': '~/data/cifar10',
            >>>           'use_train_data': True}

        :param params: Dictionary of parameters (read from configuration ``.yaml`` file).
        :type params: miprometheus.utils.ParamInterface

        """
        # Call base class constructor.
        super(CIFAR10, self).__init__(params, 'CIFAR10')

        # Set default parameters.
        params.add_default_params({'data_folder': '~/data/cifar10',
                                   'use_train_data': True})

        # Get absolute path.
        data_folder = os.path.expanduser(params['data_folder'])

        # Retrieve parameters from the dictionary.
        self.use_train_data = params['use_train_data']

        # Add transformations depending on the resizing option.
        if 'resize' in self.params:
            # Check the desired size.
            if len(self.params['resize']) != 2:
                self.logger.error("'resize' field must contain 2 values: the desired height and width")
                exit(-1)

            # Output image dimensions.
            self.height = self.params['resize'][0]
            self.width = self.params['resize'][1]
            self.num_channels = 3

            # Resize and transform to tensors.
            transform = transforms.Compose([transforms.Resize((self.height, self.width)),
                                            transforms.ToTensor()])

            self.logger.warning('Resizing the images to [{}, {}]. Slows down batch generation.'.format(
                self.height, self.width))

        else:
            # Default CIFAR-10 settings.
            self.width = 32
            self.height = 32
            self.num_channels = 3

            # Simply turn to tensor.
            transform = transforms.Compose([transforms.ToTensor()])

        # Define the default_values dict: holds parameter values that a model may need.
        self.default_values = {'num_classes': 10,
                               'num_channels': self.num_channels,
                               'width': self.width,
                               'height': self.height}

        self.data_definitions = {'images': {'size': [-1, self.num_channels, self.height, self.width], 'type': [torch.Tensor]},
                                 'targets': {'size': [-1], 'type': [torch.Tensor]},
                                 'targets_label': {'size': [-1, 1], 'type': [list, str]}
                                 }

        # Load the dataset.
        self.dataset = datasets.CIFAR10(root=data_folder, train=self.use_train_data,
                                        download=True, transform=transform)

        # type(self.dataset) = <class 'torchvision.datasets.cifar.CIFAR10'>
        # -> inherits from torch.utils.data.Dataset

        self.length = len(self.dataset)

        # Class names.
        self.labels = 'Airplane Automobile Bird Cat Deer Dog Frog Horse Ship Truck'.split(' ')
    def __getitem__(self, index):
        """
        Getter method to access the dataset and return a sample.

        :param index: index of the sample to return.
        :type index: int

        :return: ``DataDict({'images','targets', 'targets_label'})``, with:

            - images: Image, resized if indicated in ``params``,
            - targets: Index of the target class,
            - targets_label: Label of the target class (cf ``self.labels``).

        """
        img, target = self.dataset.__getitem__(index)
        target = torch.tensor(target)
        label = self.labels[target.data]

        data_dict = DataDict({key: None for key in self.data_definitions.keys()})
        data_dict['images'] = img
        data_dict['targets'] = target
        data_dict['targets_label'] = label

        return data_dict
    def collate_fn(self, batch):
        """
        Combines a list of ``DataDict`` (retrieved with ``__getitem__``) into a batch.

        .. note::

            This function wraps a call to ``default_collate`` and simply returns the batch as a \
            ``DataDict`` instead of a dict. Multi-processing is supported as the data sources are \
            small enough to be kept in memory (each `self.root-dir/cifar-10-batches-py/data_batch_i` \
            file has a size of 31.0 MB).

        :param batch: list of individual ``DataDict`` samples to combine.

        :return: ``DataDict({'images','targets', 'targets_label'})`` containing the batch.

        """
        return DataDict({key: value for key, value in zip(self.data_definitions.keys(),
                                                          super(CIFAR10, self).collate_fn(batch).values())})
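

# The class docstring above recommends ``torch.utils.data.SubsetRandomSampler`` for building
# a validation set, since CIFAR-10 only ships with a training and a test split. The helper
# below is a minimal sketch of that recipe and is NOT part of the original module: the name
# `make_train_valid_loaders`, the `valid_fraction` argument and the 90/10 default split are
# illustrative assumptions.
def make_train_valid_loaders(problem, batch_size=64, valid_fraction=0.1, seed=0):
    """Wraps `problem` into a pair of train/validation DataLoaders using SubsetRandomSampler."""
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import SubsetRandomSampler

    # Deterministically shuffle the indices of the underlying (training) set.
    indices = np.random.RandomState(seed).permutation(problem.length).tolist()

    # The first `valid_fraction` of the shuffled indices form the validation subset.
    split = int(valid_fraction * problem.length)
    valid_sampler = SubsetRandomSampler(indices[:split])
    train_sampler = SubsetRandomSampler(indices[split:])

    # Note: passing a sampler is mutually exclusive with shuffle=True.
    train_loader = DataLoader(problem, batch_size=batch_size, sampler=train_sampler,
                              collate_fn=problem.collate_fn)
    valid_loader = DataLoader(problem, batch_size=batch_size, sampler=valid_sampler,
                              collate_fn=problem.collate_fn)
    return train_loader, valid_loader
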
if __name__ == "__main__":
    """ Unit test: generates and displays a random CIFAR-10 sample. """

    # Set the seeds.
    np.random.seed(0)
    torch.manual_seed(0)

    # Load parameters.
    from miprometheus.utils.param_interface import ParamInterface
    params = ParamInterface()  # using the default values

    batch_size = 64

    # Create problem.
    cifar10 = CIFAR10(params)

    # Get a sample.
    sample = cifar10[0]
    print('__getitem__ works.\n')

    # Wrap a DataLoader on top of this Dataset subclass.
    from torch.utils.data import DataLoader
    dataloader = DataLoader(dataset=cifar10, collate_fn=cifar10.collate_fn,
                            batch_size=batch_size, shuffle=True, num_workers=0)

    # Check whether generating batches with multiple workers yields a speed up.
    import time
    s = time.time()
    for i, batch in enumerate(dataloader):
        print('Batch # {} - {}'.format(i, type(batch)))

    print('Number of workers: {}'.format(dataloader.num_workers))
    print('Time taken to exhaust the dataset for a batch size of {}: {}s'.format(
        batch_size, time.time() - s))

    # Display single sample (0) from batch.
    batch = next(iter(dataloader))
    cifar10.show_sample(batch, 0)

    print('Unit test completed.')
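

# A minimal sketch of the optional 'resize' transformation described in the constructor
# docstring; not part of the original unit test. It assumes ParamInterface exposes
# `add_config_params` (as used elsewhere in miprometheus) and follows the documented
# `[height, width]` order. Remember that resizing slows down batch generation.
if __name__ == "__main__":
    from miprometheus.utils.param_interface import ParamInterface

    resize_params = ParamInterface()
    resize_params.add_config_params({'data_folder': '~/data/cifar10',
                                     'use_train_data': True,
                                     'resize': [64, 64]})

    # Images now come out as [3, 64, 64] tensors instead of [3, 32, 32].
    resized_cifar10 = CIFAR10(resize_params)
    print('Resized sample shape: {}'.format(resized_cifar10[0]['images'].shape))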