# Source code for miprometheus.models.vqa_baselines.stacked_attention_networks.multi_hops_stacked_attention_model

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
multi_hops_stacked_attention_model.py: Implementation of a Stacked Attention Network (SAN).

This is a variant of the Stacked Attention Network model, in which several attention hops are done over the \
question words.

Inspiration drawn partially from the following paper:

@article{DBLP:journals/corr/KazemiE17,
  author    = {Vahid Kazemi and
               Ali Elqursh},
  title     = {Show, Ask, Attend, and Answer: {A} Strong Baseline For Visual Question
               Answering},
  journal   = {CoRR},
  volume    = {abs/1704.03162},
  year      = {2017},
  url       = {http://arxiv.org/abs/1704.03162},
  archivePrefix = {arXiv},
  eprint    = {1704.03162},
  timestamp = {Mon, 13 Aug 2018 16:47:10 +0200},
  biburl    = {https://dblp.org/rec/bib/journals/corr/KazemiE17},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

"""
__author__ = "Vincent Marois & Younes Bouhadjar"

import torch
import numpy as np
import torch.nn as nn

from miprometheus.models.model import Model
from miprometheus.models.vqa_baselines.stacked_attention_networks.stacked_attention_layer import StackedAttentionLayer


class MultiHopsStackedAttentionNetwork(Model):
    """
    Implementation of a Stacked Attention Network (SAN), with several attention hops over the \
    question words.

    The implementation details are very similar to the ``StackedAttentionNetwork``, to the difference that \
    it uses an ``LSTMCell`` instead of an ``LSTM``: the question is unrolled word by word and, after each word, \
    one attention hop over the encoded image is performed using the current LSTM hidden state as query.

    .. warning::

        This implementation has only been tested on ``ShapeColorQuery`` so far.

    """

    def __init__(self, params, problem_default_values_):
        """
        Constructor class of ``MultiHopsStackedAttentionNetwork`` model.

            - Parses the parameters,
            - Instantiates the CNN model: A simple, 4-layers one, or a pretrained one,
            - Instantiates an LSTMCell for the questions encoding,
            - Instantiates a 3-layers MLP as classifier.

        :param params: dict of parameters (read from configuration ``.yaml`` file).
        :type params: utils.ParamInterface

        :param problem_default_values_: default values coming from the ``Problem`` class.
        :type problem_default_values_: dict

        """
        # call base constructor
        super(MultiHopsStackedAttentionNetwork, self).__init__(params, problem_default_values_)

        # Parse default values received from problem.
        try:
            self.height = problem_default_values_['height']
            self.width = problem_default_values_['width']
            self.num_channels = problem_default_values_['num_channels']  # number of channels
            self.question_encoding_size = problem_default_values_['question_size']
            self.nb_classes = problem_default_values_['num_classes']
            self.num_words = problem_default_values_['seq_length']
        except KeyError:
            # NOTE(review): only ``num_words`` gets a fallback here; any attribute whose key was not
            # read before the KeyError stays unset and will fail later in this constructor.
            self.num_words = params['default_nb_hops']
            self.logger.warning("Couldn't retrieve one or more value(s) from problem_default_values_.")

        self.name = 'MultiHopsStackedAttentionNetwork'

        # Instantiate CNN for image encoding
        if params['use_pretrained_cnn']:
            from miprometheus.models.vqa_baselines import PretrainedImageEncoding
            self.cnn = PretrainedImageEncoding(params['pretrained_cnn']['name'],
                                               params['pretrained_cnn']['num_layers'])
        else:
            from miprometheus.models import ConvInputModel
            self.cnn = ConvInputModel()
        self.image_encoding_channels = self.cnn.get_output_nb_filters()

        # Instantiate LSTM for question encoding.
        # The hidden size equals the number of image feature channels so that the hidden state
        # can be fed directly to the attention layer as query, and concatenated to the classifier input.
        self.hidden_size = self.image_encoding_channels
        self.lstm = nn.LSTMCell(input_size=self.question_encoding_size, hidden_size=self.hidden_size, bias=True)

        # Retrieve attention layer parameters
        self.mid_features_attention = params['attention_layer']['nb_nodes']

        # Instantiate class for attention
        self.apply_attention = StackedAttentionLayer(question_image_encoding_size=self.image_encoding_channels,
                                                     key_query_size=self.mid_features_attention)

        # Instantiate MLP for classifier.
        # Input = one attended feature vector per word/hop (``num_words`` of them) + the final hidden state,
        # each assumed to be ``image_encoding_channels`` wide.
        nb_hidden_nodes = params['classifier']['nb_hidden_nodes']
        input_size = (self.num_words + 1) * self.image_encoding_channels
        self.fc1 = nn.Linear(in_features=input_size, out_features=nb_hidden_nodes)
        self.fc2 = nn.Linear(nb_hidden_nodes, nb_hidden_nodes)
        self.fc3 = nn.Linear(nb_hidden_nodes, self.nb_classes)

        # The question length must match ``num_words``, otherwise ``fc1``'s input size is wrong.
        self.data_definitions = {
            'images': {'size': [-1, self.num_channels, self.height, self.width], 'type': [torch.Tensor]},
            'questions': {'size': [-1, self.num_words, self.question_encoding_size], 'type': [torch.Tensor]},
            'targets': {'size': [-1, self.nb_classes], 'type': [torch.Tensor]}
        }

    def init_hidden_states(self, batch_size):
        """
        Initialize the hidden and cell states of the LSTM to 0.

        :param batch_size: Size of the batch.
        :type batch_size: int

        :return: hx, cx: hidden and cell states initialized to 0, each of shape [batch_size, hidden_size].

        """
        hx = torch.zeros(batch_size, self.hidden_size).type(self.app_state.dtype)
        cx = torch.zeros(batch_size, self.hidden_size).type(self.app_state.dtype)

        return hx, cx

    def forward(self, data_dict):
        """
        Runs the ``MultiHopsStackedAttentionNetwork`` model.

        :param data_dict: DataDict({'images', 'questions', ...}) where:

            - images: [batch_size, num_channels, height, width],
            - questions: [batch_size, num_words, question_encoding_size]

        :type data_dict: utils.DataDict

        :returns: Predictions (logits): [batch_size, nb_classes]

        """
        # Cast both inputs to the application dtype (device + precision) so that they match the
        # LSTM / CNN parameters and the zero-initialized hidden states.
        images = data_dict['images'].type(self.app_state.dtype)
        questions = data_dict['questions'].type(self.app_state.dtype)

        # get batch size
        batch_size = images.shape[0]

        # 1. Encode the images
        encoded_images = self.cnn(images)

        # flatten the spatial dimensions: [batch, channels, H, W] -> [batch, H*W, channels]
        encoded_images = encoded_images.view(encoded_images.size(0), encoded_images.size(1), -1).transpose(1, 2)

        # 2. Encode the questions word by word, doing one attention hop per word.
        hx, cx = self.init_hidden_states(batch_size)
        hops = []
        for i in range(questions.size(1)):
            hx, cx = self.lstm(questions[:, i, :], (hx, cx))

            # 3. Go through the ``StackedAttentionLayer``, using the current hidden state as query.
            hops.append(self.apply_attention(encoded_images, hx))

        v_features = torch.cat(hops, dim=-1)

        # 4. Classify based on the attended features of every hop and the final question encoding.
        combined = torch.cat([v_features, hx], dim=1)

        x = torch.nn.functional.relu(self.fc1(combined))
        x = torch.nn.functional.relu(self.fc2(x))
        # Only drop units while training: F.dropout defaults to training=True otherwise.
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        logits = self.fc3(x)

        return logits

    def plot(self, data_dict, predictions, sample=0):
        """
        Displays the image, the predicted & ground truth answers.

        :param data_dict: DataDict({'images', 'questions', 'targets'}) where:

            - images: [batch_size, num_channels, height, width],
            - questions: [batch_size, num_words, question_encoding_size]
            - targets: ground-truth answers (see ``data_definitions``)

        :type data_dict: utils.DataDict

        :param predictions: Prediction.
        :type predictions: torch.tensor

        :param sample: Index of sample in batch (DEFAULT: 0).
        :type sample: int

        :return: False if visualization is disabled, None otherwise.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        import matplotlib.pyplot as plt

        images = data_dict['images']
        questions = data_dict['questions']
        targets = data_dict['targets']

        # Get sample. Move tensors to host memory and drop the autograd graph: numpy / matplotlib
        # cannot consume CUDA tensors or tensors that require grad.
        image = images[sample].detach().cpu()
        target = targets[sample]
        prediction = np.argmax(predictions[sample].detach().cpu().numpy())
        question = questions[sample]

        # Show data.
        plt.title('Prediction: {} (Target: {})'.format(prediction, target))
        plt.xlabel('Q: {}'.format(question))
        # [C, H, W] -> [H, W, C] as expected by imshow.
        plt.imshow(image.permute(1, 2, 0), interpolation='nearest', aspect='auto')

        # Plot!
        plt.show()
if __name__ == '__main__':
    # Smoke test: run MultiHopsStackedAttentionNetwork over batches of ShapeColorQuery.
    from miprometheus.utils.param_interface import ParamInterface
    from miprometheus.utils.app_state import AppState

    # Turn visualization off (assumes AppState is shared with the model, whose plot() then
    # returns False, so the loop below goes through the whole dataset).
    state = AppState()
    state.visualize = False

    from miprometheus.problems import ShapeColorQuery

    dataset_config = {
        'data_folder': '~/data/shape-color-query/',
        'split': 'train',
        'regenerate': False,
        'dataset_size': 10000,
        'img_size': 128,
    }
    problem_cfg = ParamInterface()
    problem_cfg.add_config_params(dataset_config)

    # Build the problem and wrap a DataLoader around it.
    problem = ShapeColorQuery(problem_cfg)

    from torch.utils.data import DataLoader
    loader = DataLoader(
        dataset=problem,
        collate_fn=problem.collate_fn,
        batch_size=64,
        shuffle=True,
        num_workers=4,
    )

    model_config = {
        'use_pretrained_cnn': False,
        'pretrained_cnn': {'name': 'resnet18', 'num_layers': 2},
        'attention_layer': {'nb_nodes': 128},
        'classifier': {'nb_hidden_nodes': 256},
        'default_nb_hops': 3,
    }
    model_cfg = ParamInterface()
    model_cfg.add_config_params(model_config)

    net = MultiHopsStackedAttentionNetwork(model_cfg, problem.default_values)

    for minibatch in loader:
        out = net(minibatch)
        print(out.shape)
        if net.plot(minibatch, out):
            break