Source code for miprometheus.problems.video_to_class.seq_mnist_to_class.permuted_sequential_row_mnist

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""permuted_sequential_row_mnist.py: loads the `MNIST` dataset using ``torchvision`` and\
 applies a random permutation over the rows of each image."""
__author__ = "Younes Bouhadjar & Vincent Marois"

import torch
from torchvision import datasets, transforms

from miprometheus.utils.data_dict import DataDict
from miprometheus.problems.video_to_class.video_to_class_problem import VideoToClassProblem


class PermutedSequentialRowMnist(VideoToClassProblem):
    """
    The Permuted MNIST is a sequence of classification tasks in which the rows\
    of the input images are swapped with a random permutation.

    .. warning::

        The dataset is not originally split into a training set, validation set and test set; only\
        training and test set. It is recommended to use a validation set.

        ``torch.utils.data.SubsetRandomSampler`` is recommended.

    """

    def __init__(self, params):
        """
        Initializes PermutedSequentialRowMnist problem:

            - Calls ``problems.problem.VideoToClassProblem`` class constructor,
            - Sets following attributes using the provided ``params``:

                - ``self.root_dir`` (`string`) : Root directory of dataset where ``processed/training.pt``\
                    and ``processed/test.pt`` will be saved,
                - ``self.use_train_data`` (`bool`, `optional`) : If True, creates dataset from ``training.pt``,\
                    otherwise from ``test.pt``
                - ``self.default_values`` :

                    >>> self.default_values = {'nb_classes': 10,
                    >>>                        'num_channels': 1,
                    >>>                        'width': 28,
                    >>>                        'height': 28}

                - ``self.data_definitions`` :

                    >>> self.data_definitions = {'images': {'size': [-1, 28, 1, 1, 28], 'type': [torch.Tensor]},
                    >>>                          'mask': {'size': [-1, 28, 1], 'type': [torch.Tensor]},
                    >>>                          'targets': {'size': [-1, 28, 1], 'type': [torch.Tensor]},
                    >>>                          'targets_label': {'size': [-1, 1], 'type': [list, str]}
                    >>>                          }

        :param params: Dictionary of parameters (read from configuration ``.yaml`` file).
        """
        # Call base class constructor.
        super(PermutedSequentialRowMnist, self).__init__(params)

        # Retrieve parameters from the dictionary.
        self.use_train_data = params['use_train_data']
        self.root_dir = params['root_dir']

        # Each MNIST image is 28 x 28; a sample is treated as a 28-step sequence of rows.
        self.num_rows = 28
        self.num_columns = 28

        # Define the default_values dict: holds parameters values that a model may need.
        self.default_values = {'nb_classes': 10,
                               'num_channels': 1,
                               'width': self.num_columns,
                               'height': self.num_rows,
                               }

        self.data_definitions = {'images': {'size': [-1, 28, 1, 1, 28], 'type': [torch.Tensor]},
                                 'mask': {'size': [-1, 28, 1], 'type': [torch.Tensor]},
                                 'targets': {'size': [-1, 28, 1], 'type': [torch.Tensor]},
                                 'targets_label': {'size': [-1, 1], 'type': [list, str]}
                                 }

        self.name = 'PermutedSequentialRowMNIST'

        # Class names.
        self.labels = 'Zero One Two Three Four Five Six Seven Eight Nine'.split(' ')

        # Define transforms: after ToTensor() an image is (1, 28, 28), so indexing
        # dim 1 with a fixed random permutation swaps the rows. The permutation is
        # drawn once per problem instance and shared by every sample.
        row_permutation = torch.randperm(self.num_rows)
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Lambda(lambda x: x[:, row_permutation])])

        # Load the dataset.
        self.dataset = datasets.MNIST(self.root_dir, train=self.use_train_data,
                                      download=True, transform=transform)

        # Set problem length.
        self.length = len(self.dataset)

    def __getitem__(self, index):
        """
        Getter method to access the dataset and return a sample.

        :param index: index of the sample to return.
        :type index: int

        :return: ``DataDict({'images', 'mask', 'targets', 'targets_label'})``, with:

            - images: image reshaped as a sequence of rows, of shape (28, 1, 1, 28),
            - mask: (28, 1) int tensor, set to 1 only on the last element of the sequence,
            - targets: index of the target class, repeated for every sequence step,
            - targets_label: label of the target class (cf ``self.labels``)

        """
        # Get sample.
        img, target = self.dataset[index]

        # Older torchvision returned the class index as a scalar LongTensor while
        # newer versions return a plain int: normalize to a tensor so that
        # indexing and .expand() below work in both cases.
        if not torch.is_tensor(target):
            target = torch.tensor(target)

        # Get label.
        label = self.labels[int(target)]

        # Create mask: the prediction is only read out at the last row of the sequence.
        mask = torch.IntTensor(self.num_rows, 1).zero_()
        mask[-1, 0] = 1

        data_dict = DataDict({key: None for key in self.data_definitions.keys()})
        data_dict['images'] = img.view(28, 1, 1, 28)
        data_dict['mask'] = mask
        data_dict['targets'] = target.expand((28, 1))
        data_dict['targets_label'] = label

        return data_dict

    def collate_fn(self, batch):
        """
        Combines a list of ``DataDict`` (retrieved with ``__getitem__``) into a batch.

        .. note::

            This function wraps a call to ``default_collate`` and simply returns the batch as a ``DataDict``\
            instead of a dict.
            Multi-processing is supported as the data sources are small enough to be kept in memory\
            (`training.pt` has a size of 47.5 MB).

        :param batch: list of individual ``DataDict`` samples to combine.

        :return: ``DataDict({'images', 'mask', 'targets', 'targets_label'})`` containing the batch.
        """
        # Re-wrap the plain dict returned by the base collate as a DataDict,
        # keeping the key order of self.data_definitions.
        collated = super(PermutedSequentialRowMnist, self).collate_fn(batch)
        return DataDict({key: value for key, value in zip(self.data_definitions.keys(),
                                                          collated.values())})
if __name__ == "__main__":
    """ Tests sequence generator - generates and displays a random sample"""

    # Load parameters.
    from miprometheus.utils.param_interface import ParamInterface
    params = ParamInterface()
    params.add_default_params({'use_train_data': True,
                               'root_dir': '~/data/mnist'})

    batch_size = 64

    # Create problem.
    problem = PermutedSequentialRowMnist(params)

    # Get a sample.
    sample = problem[0]
    print(repr(sample))

    # Test whether data structures match expected definitions.
    # Images should be (batch size x sequence x channel x height x width);
    # as this is a single sample, we should have (sequence x channel x height x width) == (28, 1, 1, 28).
    assert sample['images'].shape == torch.Size([28, 1, 1, 28]), \
        "Unit test failed! Expected images shape {} but got {}".format(
            torch.Size([28, 1, 1, 28]), sample['images'].shape)

    # Mask should be (sequence x class) == (28, 1).
    assert sample['mask'].shape == torch.Size([28, 1]), \
        "Unit test failed! Expected mask shape {} but got {}".format(
            torch.Size([28, 1]), sample['mask'].shape)

    # Targets should be (sequence x class) == (28, 1).
    assert sample['targets'].shape == torch.Size([28, 1]), \
        "Unit test failed! Expected targets shape {} but got {}".format(
            torch.Size([28, 1]), sample['targets'].shape)

    # targets_label should be a single string label.
    assert isinstance(sample['targets_label'], str), \
        "Unit test failed! Expected target_labels to be str but got {}".format(
            type(sample['targets_label']))

    print('__getitem__ works.')

    # Wrap DataLoader on top of this Dataset subclass.
    from torch.utils.data import DataLoader
    dataloader = DataLoader(dataset=problem, collate_fn=problem.collate_fn,
                            batch_size=batch_size, shuffle=True, num_workers=8)

    # Try to see if there is a speed up when generating batches w/ multiple workers.
    import time
    s = time.time()
    for i, batch in enumerate(dataloader):
        print('Batch # {} - {}'.format(i, type(batch)))

    print('Number of workers: {}'.format(dataloader.num_workers))
    print('time taken to exhaust the dataset for a batch size of {}: {}s'
          .format(batch_size, time.time() - s))

    # Get a single batch from data loader.
    batch = next(iter(dataloader))

    # Reshape image for display. In permuted sequential row mnist, each sequence
    # has 28 entries of a row of pixels (a (1, 28) image). Go from a 28-long
    # sequence of (1, 28) images to a 1-long sequence of (28, 28) images for display.
    batch['images'] = batch['images'].view(batch_size, 1, 1,
                                           problem.num_columns, problem.num_rows)
    problem.show_sample(batch, 0)

    print('Unit test completed')