
In this example network from the PyTorch tutorial:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

net = Net()
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
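The `16 * 6 * 6` in `fc1` only works because the input is 32x32. This small plain-Python sketch traces the spatial size through the layers. The helper names `conv_out` and `pool_out` are hypothetical, not PyTorch functions. It assumes what the tutorial code uses: `Conv2d` with its defaults of stride 1 and no padding, and `max_pool2d` with its default stride, which equals the window size.

```python
# Track the side length of a square feature map through the tutorial's layers.

def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Side length after a convolution (standard formula, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, window: int) -> int:
    """Side length after max pooling whose stride equals the window size."""
    return (size - window) // window + 1


size = 32
size = pool_out(conv_out(size, 3), 2)  # conv1: 32 -> 30, pool: 30 -> 15
size = pool_out(conv_out(size, 3), 2)  # conv2: 15 -> 13, pool: 13 -> 6
flat_features = 16 * size * size       # conv2 produces 16 channels
print(size, flat_features)             # 6 576
```

With a different input size, the `nn.Linear(16 * 6 * 6, 120)` line would need to change to match.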

Why is the forward() method never called explicitly? How does just calling net(input) end up calling forward() (which, as far as I understand, is what happens)? Also, I don't understand what this line means:

super(Net, self).__init__()

I imagine super() calls the constructor of a parent class, but …?

SheppLogan

1 Answer


If you look at PyTorch's implementation of Module, you'll see that forward is called from inside the special method __call__:

class Module(object):
   ...
   def __call__(self, *input, **kwargs):
      ...
      result = self.forward(*input, **kwargs)
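
To make that mechanism concrete, here is a minimal plain-Python sketch. MiniModule and Doubler are made-up names, and MiniModule is a heavily simplified, hypothetical stand-in for nn.Module that keeps only the part of __call__ that hands the arguments to forward(). The real nn.Module.__call__ also runs any registered hooks.

```python
# Hypothetical, simplified stand-in for nn.Module: it only routes
# instance calls to forward(). The real __call__ also runs hooks.
class MiniModule:
    def __call__(self, *args, **kwargs):
        # Python runs this when an instance is called like a function
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must define forward()")


class Doubler(MiniModule):
    def forward(self, x):
        return 2 * x


net = Doubler()
print(net(21))          # 42: net(21) -> MiniModule.__call__ -> Doubler.forward
print(net.forward(21))  # 42: same result here, but __call__ is bypassed
```

Because the real __call__ runs hooks around forward(), PyTorch's documentation recommends calling the module instance, net(input), rather than net.forward(input).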

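Because Net inherits from Module and overrides its __init__ constructor, you also need to call the parent's constructor explicitly with super(Net, self).__init__(). That call runs Module's own initialization, which sets up the internal bookkeeping PyTorch uses to register parameters and submodules. Without it, a line such as self.conv1 = nn.Conv2d(1, 6, 3) fails. In Python 3 you can also write the call simply as super().__init__().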
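
To see why the parent constructor matters, here is a toy sketch in plain Python. Base, Layer, GoodNet and BadNet are hypothetical names. Base is a much simplified imitation of nn.Module: its __init__ creates the dictionary that its __setattr__ later uses to record child modules. A subclass that skips super().__init__() therefore breaks as soon as it assigns a submodule. Real nn.Module also raises an AttributeError in that situation.

```python
# Hypothetical, much simplified imitation of how nn.Module tracks submodules.
class Base:
    def __init__(self):
        # Bookkeeping that must exist before any submodule is assigned
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, Base):
            # Register the child; fails if Base.__init__ never ran
            self._modules[name] = value
        object.__setattr__(self, name, value)


class Layer(Base):
    pass  # inherits Base.__init__ unchanged


class GoodNet(Base):
    def __init__(self):
        super().__init__()   # same as super(GoodNet, self).__init__()
        self.layer = Layer()


class BadNet(Base):
    def __init__(self):
        # Parent constructor never called, so _modules does not exist
        self.layer = Layer()


print(list(GoodNet()._modules))  # ['layer']
try:
    BadNet()
except AttributeError as err:
    print("BadNet failed:", err)
```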

Elliot
  • Thanks, great answer. This is something they could have put in the comments. – SheppLogan Aug 30 '19 at 12:02
  • So the __call__ method is called by the constructor of the Module class, if I understood right? – SheppLogan Aug 30 '19 at 12:05
  • Nope, it is called when you call an instance of the class, so in the example: out = net(input). See: https://stackoverflow.com/questions/9663562/what-is-the-difference-between-init-and-call – Elliot Aug 30 '19 at 12:08
  • I don't see any explicit call to __call__, if I can put it that way, lol – SheppLogan Aug 30 '19 at 12:09
  • Yes, that happens implicitly through Python's internal implementation: calling an object runs the __call__ method defined on its class. You can learn more at https://eli.thegreenplace.net/2012/03/23/python-internals-how-callables-work/ ! – Elliot Aug 30 '19 at 12:13
  • Great, thanks for your fantastic answers! – SheppLogan Aug 30 '19 at 13:21
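
The difference discussed in these comments can be shown without PyTorch. In this generic Python sketch, Counter is a hypothetical class: __init__ runs once when the instance is created, while __call__ runs every time the instance itself is called. The last lines show that an implicit call like counter(5) looks __call__ up on the class, not on the instance. That is why no explicit call to __call__ appears anywhere in the tutorial code.

```python
class Counter:
    def __init__(self):
        # Runs once, when the instance is created: Counter()
        self.calls = 0

    def __call__(self, x):
        # Runs every time the instance itself is called: counter(x)
        self.calls += 1
        return x + 1


counter = Counter()            # __init__ runs here
print(counter(1), counter(2))  # __call__ runs twice: prints "2 3"
print(counter.calls)           # 2

# counter(5) behaves like looking __call__ up on the type and calling it:
print(type(counter).__call__(counter, 5))  # 6

# Shadowing __call__ on the instance does not change counter(...),
# because implicit calls use the class's __call__:
counter.__call__ = lambda x: "instance attribute"
print(counter(5))  # still 6
```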