#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
stacked_attention_layer: Implements the Attention Layer as described in section 3.3 of the following paper:
@article{DBLP:journals/corr/YangHGDS15,
author = {Zichao Yang and
Xiaodong He and
Jianfeng Gao and
Li Deng and
Alexander J. Smola},
title = {Stacked Attention Networks for Image Question Answering},
journal = {CoRR},
volume = {abs/1511.02274},
year = {2015},
url = {http://arxiv.org/abs/1511.02274},
archivePrefix = {arXiv},
eprint = {1511.02274},
timestamp = {Mon, 13 Aug 2018 16:47:25 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/YangHGDS15},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
__author__ = "Vincent Marois & Younes Bouhadjar"
import torch
from torch.nn import Module
from miprometheus.utils.app_state import AppState
class StackedAttentionLayer(Module):
    """
    Stacks several ``AttentionLayer`` modules to enable multi-step reasoning.

    Each layer attends over the image using the query refined by the previous
    layer (the first layer uses the encoded question), so the query is
    progressively refined from one layer to the next.
    """

    def __init__(self, question_image_encoding_size,
                 key_query_size, num_att_layers=2):
        """
        Constructor of the ``StackedAttentionLayer`` class.

        :param question_image_encoding_size: Size of the images & questions encoding.
        :type question_image_encoding_size: int

        :param key_query_size: Size of the Key & Query, considered the same for both in this implementation.
        :type key_query_size: int

        :param num_att_layers: Number of ``AttentionLayer`` to use.
        :type num_att_layers: int
        """
        # call base constructor
        super(StackedAttentionLayer, self).__init__()

        # Attention maps of the most recent visualized forward pass
        # (None until a forward pass runs with visualization enabled).
        self.visualize_attention = None

        # Build one *distinct* AttentionLayer per step. Multiplying a
        # one-element list (``[layer] * n``) would repeat the same module
        # object, so all steps would share a single set of weights.
        self.san = torch.nn.ModuleList(
            [AttentionLayer(question_image_encoding_size, key_query_size)
             for _ in range(num_att_layers)])

    def forward(self, encoded_image, encoded_question):
        """
        Apply stacked attention.

        :param encoded_image: output of the image encoding (CNN + FC layer), should be of shape \
        [batch_size, width * height, num_channels_encoded_image]
        :type encoded_image: torch.tensor

        :param encoded_question: Last hidden layer of the LSTM, of shape [batch_size, question_encoding_size]
        :type encoded_question: torch.tensor

        :return: u: refined query after the last attention layer, of shape \
        [batch_size, num_channels_encoded_image]. If the stack contains no layer, \
        ``encoded_question`` is returned unchanged.
        """
        # The query starts as the encoded question and is refined layer after layer.
        u = encoded_question

        # Read the visualization flag once for the whole pass.
        visualize = AppState().visualize
        attention_maps = []

        for att_layer in self.san:
            # Feed the query refined by the previous layer, not the raw question,
            # otherwise the outputs of all but the last layer are thrown away.
            u, attention_prob = att_layer(encoded_image, u)

            if visualize:
                attention_maps.append(attention_prob)

        if attention_maps:
            # One column per attention layer, concatenated along the last dim.
            # Overwrites the maps of the previous call instead of growing
            # the stored tensor on every forward pass.
            self.visualize_attention = torch.cat(attention_maps, dim=-1)

        return u
class AttentionLayer(Module):
    """
    Implements one layer of the Stacked Attention mechanism.

    Reference: Section 3.3 of the paper cited above.
    """

    def __init__(self, question_image_encoding_size, key_query_size=512):
        """
        Constructor of the ``AttentionLayer`` class.

        :param question_image_encoding_size: Size of the images & questions encoding.
        :type question_image_encoding_size: int

        :param key_query_size: Size of the Key & Query, considered the same for both in this implementation.
        :type key_query_size: int
        """
        # call base constructor
        super(AttentionLayer, self).__init__()

        # fully connected layer to construct the key
        self.ff_image = torch.nn.Linear(question_image_encoding_size, key_query_size)

        # fully connected layer to construct the query
        self.ff_ques = torch.nn.Linear(question_image_encoding_size, key_query_size)

        # fully connected layer to construct the attention from the query and key
        self.ff_attention = torch.nn.Linear(key_query_size, 1)

    def forward(self, encoded_image, encoded_question):
        """
        Applies one layer of stacked attention over the image & question.

        :param encoded_image: output of the image encoding (CNN + FC layer), should be of shape \
        [batch_size, width * height, num_channels_encoded_image]
        :type encoded_image: torch.tensor

        :param encoded_question: Last hidden layer of the LSTM, of shape [batch_size, question_encoding_size]
        :type encoded_question: torch.tensor

        :returns:
            - "Refined query vector" (weighted sum of the image vectors, combined with the question vector), \
            of shape [batch_size, num_channels_encoded_image]
            - Attention weights, of shape [batch_size, width * height, 1], normalized (softmax) over \
            the ``width * height`` dimension.
        """
        # Get the key: [batch_size, width * height, key_query_size]
        key = self.ff_image(encoded_image)

        # Get the query; unsqueeze so it broadcasts over every image region:
        # [batch_size, 1, key_query_size]
        query = self.ff_ques(encoded_question).unsqueeze(dim=1)

        # torch.tanh replaces the deprecated torch.nn.functional.tanh
        weighted_key_query = torch.tanh(key + query)

        # One attention score per image region: [batch_size, width * height, 1]
        weighted_key_query = self.ff_attention(weighted_key_query)

        # Normalize the scores over the image regions (second-to-last dim)
        attention_prob = torch.nn.functional.softmax(weighted_key_query, dim=-2)

        # Attention-weighted sum of the image vectors over the regions
        vi_attended = (attention_prob * encoded_image).sum(dim=1)

        # Combine with the question vector to form the refined query
        u = vi_attended + encoded_question

        return u, attention_prob