#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
stacked_attention_model.py: Implementation of a Stacked Attention Network (SAN). \
Inspiration drawn partially from the following paper:
@article{DBLP:journals/corr/YangHGDS15,
author = {Zichao Yang and
Xiaodong He and
Jianfeng Gao and
Li Deng and
Alexander J. Smola},
title = {Stacked Attention Networks for Image Question Answering},
journal = {CoRR},
volume = {abs/1511.02274},
year = {2015},
url = {http://arxiv.org/abs/1511.02274},
archivePrefix = {arXiv},
eprint = {1511.02274},
timestamp = {Mon, 13 Aug 2018 16:47:25 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/YangHGDS15},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
__author__ = "Vincent Marois & Younes Bouhadjar"
import torch
import numpy as np
from miprometheus.models.model import Model
from miprometheus.models.vqa_baselines.stacked_attention_networks.stacked_attention_layer import StackedAttentionLayer
class StackedAttentionNetwork(Model):
    """
    Implementation of a Stacked Attention Network (SAN).

    The three major components of SAN are:

        - the image model (CNN model, possibly pretrained),
        - the question model (LSTM based),
        - the stacked attention model.

    .. warning::

        This implementation has only been tested on ``SortOfCLEVR`` so far.

    """

    def __init__(self, params, problem_default_values_):
        """
        Constructor class of ``StackedAttentionNetwork`` model.

            - Parses the parameters,
            - Instantiates the CNN model: A simple, 4-layers one, or a pretrained one,
            - Instantiates an LSTM for the questions encoding,
            - Instantiates a 3-layers MLP as classifier.

        :param params: dict of parameters (read from configuration ``.yaml`` file).
        :type params: utils.ParamInterface

        :param problem_default_values_: default values coming from the ``Problem`` class.
        :type problem_default_values_: dict

        :raises KeyError: if ``problem_default_values_`` lacks any of ``height``, ``width``, \
        ``num_channels``, ``question_size`` or ``num_classes`` (all of them are required below).

        """
        # call base constructor
        super(StackedAttentionNetwork, self).__init__(params, problem_default_values_)

        # Parse default values received from problem.
        try:
            self.height = problem_default_values_['height']
            self.width = problem_default_values_['width']
            self.num_channels = problem_default_values_['num_channels']  # number of channels
            self.question_encoding_size = problem_default_values_['question_size']
            self.nb_classes = problem_default_values_['num_classes']
        except KeyError:
            self.logger.warning("Couldn't retrieve one or more value(s) from problem_default_values_.")
            # All of these values are used further down (LSTM input size, classifier output
            # size, data definitions): continuing would only fail later with a less
            # informative AttributeError, so surface the missing key right away.
            raise

        self.name = 'StackedAttentionNetwork'

        # Instantiate CNN for image encoding: either a pretrained backbone or the simple ConvInputModel.
        if params['use_pretrained_cnn']:
            from miprometheus.models.vqa_baselines import PretrainedImageEncoding
            self.cnn = PretrainedImageEncoding(params['pretrained_cnn']['name'],
                                               params['pretrained_cnn']['num_layers'])
        else:
            from miprometheus.models import ConvInputModel
            self.cnn = ConvInputModel()
        # Both encoders report the number of feature maps they produce.
        self.image_encoding_channels = self.cnn.get_output_nb_filters()

        # Instantiate LSTM for question encoding
        self.hidden_size = params['lstm']['hidden_size']
        self.num_layers = params['lstm']['num_layers']
        self.bidirectional = params['lstm']['bidirectional']
        self.num_dir = 2 if self.bidirectional else 1

        self.lstm = torch.nn.LSTM(input_size=self.question_encoding_size,
                                  hidden_size=self.hidden_size,
                                  num_layers=self.num_layers,
                                  bias=True,
                                  batch_first=True,
                                  dropout=params['lstm']['dropout'],
                                  bidirectional=self.bidirectional)

        # Per-timestep LSTM output size: the hidden states of all directions are concatenated.
        output_question_dim = self.num_dir * self.hidden_size

        # Retrieve attention layer parameters
        self.mid_features_attention = params['attention_layer']['nb_nodes']

        # Question encoding: project the LSTM output onto the image feature dimension, so that
        # question and image features share the size handed to the attention layer below.
        self.ffn = torch.nn.Linear(in_features=output_question_dim, out_features=self.image_encoding_channels)

        # Instantiate class for attention
        self.apply_attention = StackedAttentionLayer(question_image_encoding_size=self.image_encoding_channels,
                                                     key_query_size=self.mid_features_attention)

        # Instantiate MLP for classifier.
        # NOTE(review): fc1 assumes the attention layer returns [batch_size, image_encoding_channels].
        input_size = self.image_encoding_channels
        nb_hidden_nodes = params['classifier']['nb_hidden_nodes']
        self.fc1 = torch.nn.Linear(in_features=input_size, out_features=nb_hidden_nodes)
        self.fc2 = torch.nn.Linear(nb_hidden_nodes, nb_hidden_nodes)
        self.fc3 = torch.nn.Linear(nb_hidden_nodes, self.nb_classes)

        self.data_definitions = {'images': {'size': [-1, self.num_channels, self.height, self.width], 'type': [torch.Tensor]},
                                 'questions': {'size': [-1, self.question_encoding_size], 'type': [torch.Tensor]},
                                 'targets': {'size': [-1, self.nb_classes], 'type': [torch.Tensor]}
                                 }

    def forward(self, data_dict):
        """
        Runs the ``StackedAttentionNetwork`` model.

        :param data_dict: DataDict({'images', 'questions', ...}) where:

            - images: [batch_size, num_channels, height, width],
            - questions: [batch_size, size_question_encoding]

        :type data_dict: utils.DataDict

        :returns: Predictions (unnormalized logits): [batch_size, output_classes]

        """
        images = data_dict['images'].type(self.app_state.dtype)
        questions = data_dict['questions']

        # 1. Encode the images
        encoded_images = self.cnn(images)
        # Flatten every dim after the channel dim and move channels last:
        # [batch, C, ...] -> [batch, num_regions, C] (one feature vector per image region).
        encoded_images = encoded_images.view(encoded_images.size(0), encoded_images.size(1), -1).transpose(1, 2)

        # 2. Encode the questions. Each question is fed as a sequence of length 1
        # ([batch, 1, question_size], batch_first=True); (h_0, c_0) default to zero.
        encoded_question, _ = self.lstm(questions.unsqueeze(1))
        # take layer's last output
        encoded_question = encoded_question[:, -1, :]

        # 3. Go through the ``StackedAttentionLayer``.
        encoded_question = self.ffn(encoded_question)
        encoded_attention = self.apply_attention(encoded_images, encoded_question)

        # 4. Classify based on the result of the stacked attention layer
        x = torch.nn.functional.relu(self.fc1(encoded_attention))
        x = torch.nn.functional.relu(self.fc2(x))
        # The functional dropout's ``training`` argument defaults to True; tie it to the module
        # mode so that dropout is disabled after ``model.eval()``.
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        logits = self.fc3(x)

        return logits

    def plot(self, data_dict, predictions, sample=0):
        """
        Displays the image, the predicted & ground truth answers, and the attention maps of the
        two stacked attention layers.

        :param data_dict: DataDict({'images', 'questions', 'targets'}) where:

            - images: [batch_size, num_channels, height, width],
            - questions: [batch_size, size_question_encoding]
            - targets: [batch_size]

        :type data_dict: utils.DataDict

        :param predictions: Prediction.
        :type predictions: torch.tensor

        :param sample: Index of sample in batch (DEFAULT: 0).
        :type sample: int

        :returns: ``False`` when visualization is disabled; otherwise ``None`` once ``plt.show()`` returns.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        import matplotlib.pyplot as plt

        # Get sample. Tensors are brought to the CPU before ``.numpy()``, which fails on CUDA tensors.
        image = data_dict['images'][sample]
        target = data_dict['targets'][sample]
        question = data_dict['questions'][sample]
        prediction = np.argmax(predictions[sample].detach().cpu().numpy())

        # Show data.
        plt.title('Prediction: {} (Target: {})'.format(prediction, target))
        plt.xlabel('Q: {}'.format(question))
        plt.imshow(image.permute(1, 2, 0).detach().cpu().numpy(),
                   interpolation='nearest', aspect='auto')

        # NOTE(review): ``visualize_attention`` is presumably filled in by the attention layer during
        # forward(); from the indexing below it looks like [batch_size, num_regions, num_layers].
        attention = self.apply_attention.visualize_attention.detach().cpu()
        # The image regions are assumed to form a square grid.
        width_height_attention = int(np.sqrt(attention.size(-2)))

        # get the attention of the 2 layers of stacked attention, reshaped to a 2D grid
        attention_visualize_layer1 = attention[sample, :, 0].numpy().reshape(
            width_height_attention, width_height_attention)
        attention_visualize_layer2 = attention[sample, :, 1].numpy().reshape(
            width_height_attention, width_height_attention)

        _ = plt.figure()
        plt.title('1st attention layer')
        plt.imshow(attention_visualize_layer1,
                   interpolation='nearest', aspect='auto')

        _ = plt.figure()
        plt.title('2nd attention layer')
        plt.imshow(attention_visualize_layer2,
                   interpolation='nearest', aspect='auto')

        # Plot!
        plt.show()
if __name__ == '__main__':
    # Smoke test: run StackedAttentionNetwork on batches of SortOfCLEVR.
    from miprometheus.utils.param_interface import ParamInterface
    from miprometheus.utils.app_state import AppState

    # Turn visualization on; AppState is presumably shared, so the model's plot() sees this flag.
    state = AppState()
    state.visualize = True

    from miprometheus.problems.image_text_to_class.sort_of_clevr import SortOfCLEVR

    # Problem configuration ("loaded parameters").
    problem_config = {
        'data_folder': '~/data/sort-of-clevr/',
        'split': 'train',
        'regenerate': False,
        'dataset_size': 10000,
        'img_size': 128,
    }
    problem_params = ParamInterface()
    problem_params.add_config_params(problem_config)
    problem = SortOfCLEVR(problem_params)

    # Batch the dataset with the problem's own collate function.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset=problem,
                        collate_fn=problem.collate_fn,
                        batch_size=64,
                        shuffle=True,
                        num_workers=4)

    # Model configuration.
    model_config = {
        'use_pretrained_cnn': False,
        'pretrained_cnn': {'name': 'resnet18', 'num_layers': 2},
        'lstm': {'hidden_size': 64, 'num_layers': 1, 'bidirectional': False, 'dropout': 0},
        'attention_layer': {'nb_nodes': 128},
        'classifier': {'nb_hidden_nodes': 256},
    }
    model_params = ParamInterface()
    model_params.add_config_params(model_config)
    model = StackedAttentionNetwork(model_params, problem.default_values)

    for data_dict in loader:
        predictions = model(data_dict)
        print(predictions.shape)
        # NOTE(review): StackedAttentionNetwork.plot returns False or None, so this
        # early exit never fires as written.
        if model.plot(data_dict, predictions):
            break