
Given a tensor of shape (3, 256, 256), I would like to convolve over it or loop through it pixel by pixel so that it returns a tensor of shape (1, 256, 256).

This may sound a bit confusing, so here is my code so far so you know what I mean.

import torch.nn as nn
import torch.nn.functional as F


class MorphNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(3, 8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        # The input here is shape (3, 256, 256)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Returned shape should be (1, 256, 256)
        return x

As you can see, my linear layer accepts an input of size 3, which matches the depth (channel dimension) of my original tensor. What is the best way of looping over all 256x256 pixels to return a tensor of shape (1, 256, 256)?


1 Answer


A linear layer that takes a 3-dimensional input and produces an 8-dimensional output, applied independently at each pixel, is mathematically equivalent to a convolution with a kernel of spatial size 1x1 (I strongly recommend that you actually "do the math" and convince yourself that this is indeed correct).
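
To see this concretely, here is a minimal sketch (my own check, not part of the original answer) that copies the linear weights into a 1x1 conv kernel and confirms the two layers give the same output on a random image:

import torch
import torch.nn as nn

lin = nn.Linear(3, 8)
conv = nn.Conv2d(3, 8, kernel_size=1, bias=True)

# Copy the linear weights (8, 3) into the conv kernel (8, 3, 1, 1)
with torch.no_grad():
    conv.weight.copy_(lin.weight.view(8, 3, 1, 1))
    conv.bias.copy_(lin.bias)

x = torch.randn(1, 3, 256, 256)                       # one (N, C, H, W) image
out_conv = conv(x)                                    # (1, 8, 256, 256)
# For the linear layer, channels must sit on the last dim
out_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_conv, out_lin, atol=1e-6))   # True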

Therefore, you can use the following model, replacing the linear layers with nn.Conv2d layers:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MorphNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        self.c1 = nn.Conv2d(3, 8, kernel_size=1, bias=True)
        self.c2 = nn.Conv2d(8, 1, kernel_size=1, bias=True)

    def forward(self, x):
        # The input here is shape (3, 256, 256)
        x = F.relu(self.c1(x))
        x = self.c2(x)
        # Returned shape should be (1, 256, 256)
        return x
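
For reference, applying this model to an image of the shape you describe could look like the following sketch (nn.Conv2d conventionally expects a leading batch dimension, so one is added and removed here):

net = MorphNetwork()
x = torch.rand(3, 256, 256)           # a single (C, H, W) image
y = net(x.unsqueeze(0)).squeeze(0)    # add/remove the batch dim for Conv2d
print(y.shape)                        # torch.Size([1, 256, 256])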

If you insist on using a nn.Linear layer, you can unfold your input, apply the linear layer, and then fold the result back into an image, as sketched below.
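
A minimal sketch of that idea (my own variant, reusing the imports above; it uses permute to move the channel dimension to the last axis instead of nn.Unfold/nn.Fold, since nn.Linear acts on the last dimension, and the class name MorphNetworkLinear is hypothetical):

class MorphNetworkLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        # x: (3, 256, 256) -> (256, 256, 3), channels on the last dim
        x = x.permute(1, 2, 0)
        x = F.relu(self.fc1(x))   # (256, 256, 8)
        x = self.fc2(x)           # (256, 256, 1)
        # Move the channel dim back to the front: (1, 256, 256)
        return x.permute(2, 0, 1)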
