Source code for miprometheus.models.vision.lenet5

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) IBM Corporation 2018
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
lenet5: a classical LeNet-5 model for MNIST digit classification. \
 To be taken as an illustrative example.
 """

__author__ = "Tomasz Kornuta & Vincent Marois"


import torch

# Import useful MI-Prometheus classes.
from miprometheus.models.model import Model


[docs]class LeNet5(Model): """ A classical LeNet-5 model for MNIST digits classification. """
[docs] def __init__(self, params_, problem_default_values_): """ Initializes the ``LeNet5`` model, creates the required layers. :param params: Parameters read from configuration file. :type params: ``miprometheus.utils.ParamInterface`` :param problem_default_values_: dict of parameters values coming from the problem class. :type problem_default_values_: dict """ super(LeNet5, self).__init__(params_, problem_default_values_) self.name = 'LeNet5' # Create the LeNet-5 layers. self.conv1 = torch.nn.Conv2d(1, 6, kernel_size=(5, 5)) self.maxpool1 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2) self.conv2 = torch.nn.Conv2d(6, 16, kernel_size=(5, 5)) self.maxpool2 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2) self.conv3 = torch.nn.Conv2d(16, 120, kernel_size=(5, 5)) self.linear1 = torch.nn.Linear(120, 84) self.linear2 = torch.nn.Linear(84, 10) # Create Model data definitions - indicate what a given model needs. self.data_definitions = { 'images': {'size': [-1, 1, 32, 32], 'type': [torch.Tensor]}}
[docs] def forward(self, data_dict): """ Main forward pass of the ``LeNet5`` model. :param data_dict: DataDict({'images',**}), where: - images: [batch_size, num_channels, width, height] :type data_dict: ``miprometheus.utils.DataDict`` :return: Predictions [batch_size, num_classes] """ # Unpack DataDict. img = data_dict['images'] # Pass inputs through layers. x = self.conv1(img) x = torch.nn.functional.relu(x) x = self.maxpool1(x) x = self.conv2(x) x = torch.nn.functional.relu(x) x = self.maxpool2(x) x = self.conv3(x) x = torch.nn.functional.relu(x) x = x.view(-1, 120) x = self.linear1(x) x = torch.nn.functional.relu(x) x = self.linear2(x) return x # return logits.