Source code for miprometheus.models.relational_net.relational_network

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
relational_network.py: contains the implementation of the Relational Network model from DeepMind.
See the reference paper here: https://arxiv.org/pdf/1706.01427.pdf.
"""
__author__ = "Vincent Marois"

import torch

from miprometheus.models.model import Model

from miprometheus.models.relational_net.conv_input_model import ConvInputModel
from miprometheus.models.relational_net.functions import PairwiseRelationNetwork, SumOfPairsAnalysisNetwork


class RelationalNetwork(Model):
    """
    Implementation of the Relational Network (RN) model.

    Questions are processed with an LSTM to produce a question embedding, and images are processed \
    with a CNN to produce a set of objects for the RN. 'Objects' are constructed using feature-map vectors \
    from the convolved image.

    The RN considers relations across all pairs of objects, conditioned on the question embedding, and \
    integrates all these relations to answer the question.

    Reference paper: https://arxiv.org/abs/1706.01427.

    The CNN model used for the image encoding is located in ``conv_input_model.py``.

    The MLPs (g_theta & f_phi) are in ``functions.py``.

    .. warning::

        This implementation has only been tested on the ``SortOfCLEVR`` problem class proposed in this \
        framework and will require modification to work on the CLEVR dataset (also proposed in this framework). \
        This should be addressed in a future release.

    """
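    # At a glance (cf. the reference paper), the model computes
    #
    #     a = f_phi( sum_{i,j} g_theta(o_i, o_j, q) )
    #
    # where the o_i, o_j are the spatially-tagged feature-map vectors built in
    # forward(), q is the question (embedding), g_theta is the
    # PairwiseRelationNetwork and f_phi the SumOfPairsAnalysisNetwork.
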
    def __init__(self, params, problem_default_values_={}):
        """
        Constructor.

        Instantiates the CNN model (4 layers), and the 2 Multi Layer Perceptrons.

        :param params: dictionary of parameters (read from the ``.yaml`` configuration file.)

        :param problem_default_values_: default values coming from the ``Problem`` class.
        :type problem_default_values_: dict

        """
        # call base constructor
        super(RelationalNetwork, self).__init__(params, problem_default_values_)
        self.name = 'RelationalNetwork'

        # instantiate conv input model for image encoding
        self.cnn_model = ConvInputModel()

        try:
            # get image information from the problem class
            self.num_channels = problem_default_values_['num_channels']  # number of channels
            self.height = problem_default_values_['height']
            self.width = problem_default_values_['width']

            self.question_size = problem_default_values_['question_size']

            # number of output nodes
            self.nb_classes = problem_default_values_['num_classes']

        except KeyError:
            self.logger.warning("Couldn't retrieve one or more value(s) from problem_default_values_.")

        # compute the length of the input to the g_theta MLP:
        input_size = (self.cnn_model.conv4.out_channels + 2) * 2 + self.question_size

        # instantiate network to compare regions pairwise
        self.pair_network = PairwiseRelationNetwork(input_size=input_size)

        # instantiate network to analyse the sum of the pairs
        self.sum_network = SumOfPairsAnalysisNetwork(output_size=self.nb_classes)

        self.data_definitions = {'images': {'size': [-1, self.num_channels, self.height, self.width],
                                            'type': [torch.Tensor]},
                                 'questions': {'size': [-1, -1, -1], 'type': [torch.Tensor]},
                                 'targets': {'size': [-1, 1], 'type': [torch.Tensor]}
                                 }

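    # Sizing sketch (the numbers below are illustrative assumptions, not values
    # read from ConvInputModel or from a given problem class): if conv4 outputs
    # 24 feature maps and question_size is 11, each coordinate-tagged object
    # vector has k = 24 + 2 = 26 features, so a pair of objects concatenated
    # with the question yields input_size = 2 * 26 + 11 = 63 for the g_theta MLP.
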
    def build_coord_tensor(self, batch_size, d):
        """
        Create the tensor containing the spatial relative coordinate of each \
        region (1 pixel) in the feature maps of the ``ConvInputModel``. These \
        spatial relative coordinates are used to 'tag' the regions.

        :param batch_size: batch size
        :type batch_size: int

        :param d: size of 1 feature map
        :type d: int

        :return: tensor of shape [batch_size x 2 x d x d]

        """
        coords = torch.linspace(-1 / 2., 1 / 2., d)
        x = coords.unsqueeze(0).repeat(d, 1)
        y = coords.unsqueeze(1).repeat(1, d)
        ct = torch.stack((x, y)).type(self.app_state.dtype)  # [2 x d x d]

        # broadcast to all batches
        # [batch_size x 2 x d x d]
        ct = ct.unsqueeze(0).repeat(batch_size, 1, 1, 1)

        # indicate that we do not track gradient for this tensor
        ct.requires_grad = False

        return ct

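    # Worked example of the coordinate tensor, derived from the code above for
    # d = 2 (i.e. 4 regions per image):
    #
    #     coords = torch.linspace(-0.5, 0.5, 2)   # tensor([-0.5000,  0.5000])
    #     x = coords.unsqueeze(0).repeat(2, 1)    # [[-0.5, 0.5], [-0.5, 0.5]]
    #     y = coords.unsqueeze(1).repeat(1, 2)    # [[-0.5, -0.5], [0.5, 0.5]]
    #
    # Hence ct[b, 0, i, j] is the normalized column position of region (i, j)
    # and ct[b, 1, i, j] its normalized row position, both in [-0.5, 0.5].
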
    def forward(self, data_dict):
        """
        Runs the ``RelationalNetwork`` model.

        :param data_dict: DataDict({'images', 'questions', ...}) containing:

            - images [batch_size, num_channels, height, width],
            - questions [batch_size, question_size]

        :type data_dict: utils.DataDict

        :returns: Predictions of the model [batch_size, nb_classes]

        """
        images = data_dict['images'].type(self.app_state.dtype)
        questions = data_dict['questions']
        question_size = questions.shape[-1]

        # step 1: encode images
        feature_maps = self.cnn_model(images)
        batch_size = feature_maps.shape[0]
        # number of kernels in the final convolutional layer
        k = feature_maps.shape[1]
        d = feature_maps.shape[2]  # size of 1 feature map

        # step 2: 'tag' all regions in feature_maps with their relative spatial
        # coordinates
        ct = self.build_coord_tensor(batch_size, d)  # [batch_size x 2 x d x d]
        x_ct = torch.cat([feature_maps, ct], 1)  # [batch_size x (k+2) x d x d]
        # update number of channels
        k += 2

        # step 3: form all possible pairs of regions in feature_maps (d ** 2 regions -> d ** 4 pairs!)
        # flatten out feature_maps: [batch_size x k x d x d] -> [batch_size x k x (d ** 2)]
        x_ct = x_ct.view(batch_size, k, d ** 2)
        x_ct = x_ct.transpose(2, 1)  # [batch_size x (d ** 2) x k]

        x_i = x_ct.unsqueeze(1)  # [batch_size x 1 x (d ** 2) x k]
        # [batch_size x (d ** 2) x (d ** 2) x k]
        x_i = x_i.repeat(1, (d ** 2), 1, 1)

        # step 4: add the question everywhere
        # [batch_size, (d ** 2), question_size]
        questions = questions.unsqueeze(1).repeat(1, d ** 2, 1)
        # [batch_size, (d ** 2), 1, question_size]
        questions = questions.unsqueeze(2)

        x_j = x_ct.unsqueeze(2)  # [batch_size x (d ** 2) x 1 x k]
        # [batch_size x (d ** 2) x 1 x (k + question_size)]
        x_j = torch.cat([x_j, questions], dim=-1)
        # [batch_size x (d ** 2) x (d ** 2) x (k + question_size)]
        x_j = x_j.repeat(1, 1, (d ** 2), 1)

        # generate all pairs
        # [batch_size, (d ** 2), (d ** 2), 2 * k + question_size]
        x = torch.cat([x_i, x_j], dim=-1)

        # step 5: pass pairs through pair_network
        # reshape for passing through network
        input_size = 2 * k + question_size
        x = x.view(batch_size * (d ** 4), input_size)
        x_g = self.pair_network(x)

        # reshape again & element-wise sum on the second dimension
        x_g = x_g.view(batch_size, (d ** 4), 256)
        x_f = x_g.sum(1)

        # step 6: pass sum of pairs through sum_network
        x_out = self.sum_network(x_f)

        return x_out

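# The unsqueeze/repeat/cat pattern in forward() above builds the Cartesian
# product of the object set with itself. A minimal, self-contained sketch of
# the same trick on toy sizes (illustrative only, independent of this model):
#
#     import torch
#     b, n, k = 1, 3, 2                            # batch, nb objects, features
#     objs = torch.arange(b * n * k, dtype=torch.float).view(b, n, k)
#     x_i = objs.unsqueeze(1).repeat(1, n, 1, 1)   # [b x n x n x k]
#     x_j = objs.unsqueeze(2).repeat(1, 1, n, 1)   # [b x n x n x k]
#     pairs = torch.cat([x_i, x_j], dim=-1)        # [b x n x n x 2k]
#     # pairs[0, i, j] == concat(objs[0, j], objs[0, i]): all n ** 2 pairs
#
# Note the memory cost: with n = d ** 2 regions, the pair tensor grows as
# d ** 4, which is why the RN is applied to small feature maps.
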
if __name__ == '__main__':
    """Unit test for the RelationalNetwork on SortOfCLEVR."""
    from miprometheus.utils.app_state import AppState
    from miprometheus.utils.param_interface import ParamInterface
    from torch.utils.data import DataLoader
    app_state = AppState()

    from miprometheus.problems.image_text_to_class.sort_of_clevr import SortOfCLEVR

    problem_params = ParamInterface()
    problem_params.add_config_params({'data_folder': '~/data/sort-of-clevr/',
                                      'split': 'train',
                                      'regenerate': False,
                                      'dataset_size': 10000,
                                      'img_size': 128})

    # create problem
    sort_of_clevr = SortOfCLEVR(problem_params)
    print('Problem {} instantiated.'.format(sort_of_clevr.name))

    # instantiate DataLoader object
    batch_size = 64
    problem = DataLoader(sort_of_clevr, batch_size=batch_size,
                         collate_fn=sort_of_clevr.collate_fn)

    model_params = ParamInterface()
    model_params.add_config_params({})

    model = RelationalNetwork(model_params, sort_of_clevr.default_values)
    print('Model {} instantiated.'.format(model.name))
    model.app_state.visualize = True

    # perform handshaking between RN & SortOfCLEVR
    model.handshake_definitions(sort_of_clevr.data_definitions)

    # generate a batch
    for i_batch, sample in enumerate(problem):
        print('Sample # {} - {}'.format(i_batch, sample['images'].shape), type(sample))
        logits = model(sample)
        sort_of_clevr.plot_preprocessing(sample, logits)
        model.plot(sample, logits)
        print(logits.shape)

    print('Unit test completed.')