
I have two tensors that should partially overlap each other to form a larger tensor. To illustrate:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]])
b = torch.Tensor([[5, 6, 7], [5, 6, 7]])

a = [[1 2 3]    b = [[5 6 7]
     [1 2 3]]        [5 6 7]]

I want to combine the two tensors and have them partially overlap by a single column, with the average being taken for those elements that overlap.

e.g.

result = [[1 2 4 6 7]
          [1 2 4 6 7]]

The first two columns are the first two columns of 'a'. The last two columns are the last two columns of 'b'. The middle column is the average of 'a's last column and 'b's first column.

I know how to merge two tensors side by side or in a new dimension. But doing this eludes me.

Can anyone help?

Avi Chapman
    Can you explain a bit more what you mean by partial overlap? The example you have given is not completely obvious. E.g. a result of dimension `2x5` from the combination of two tensors of dimension `2x3` is not very natural. – asymptote Aug 27 '19 at 02:13

1 Answer


This is not a trivial operation, and this solution is not very trivial or intuitive either.

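Looking at result with shape=(2, 5), you can think of a and b as two 2x3 patches of result taken with stride=2: a covers columns 0 to 2, b covers columns 2 to 4, and the two patches share the middle column.
[Illustration: result with patch a in green and patch b in blue]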

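We can use PyTorch's unfold to "recover" the green (a) and blue (b) patches from result ("recover" up to the averaged values). unfold works on batched 4D input, so result is first reshaped to (1, 1, 2, 5):

import torch
from torch.nn import functional as nnf

# result is the desired 2x5 tensor from the question
result = torch.tensor([[1., 2., 4., 6., 7.], [1., 2., 4., 6., 7.]])
recovered = nnf.unfold(result.view(1, 1, 2, 5), kernel_size=(2, 3), stride=2)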

The result is:

tensor([[[1., 4.],
         [2., 6.],
         [4., 7.],
         [1., 4.],
         [2., 6.],
         [4., 7.]]])

The patches were recovered (as column vectors).
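As a small self-contained check of the unfold step, the sketch below builds result by hand (an assumption, since the answer only shows it as the target), reshapes it to 4D, and turns each recovered column back into a 2x3 patch. The names patch_a and patch_b are illustrative only.

```python
import torch
from torch.nn import functional as nnf

# Target tensor from the question, written out by hand for this check.
result = torch.tensor([[1., 2., 4., 6., 7.],
                       [1., 2., 4., 6., 7.]])

# Reshape to (batch, channel, height, width) before unfolding.
recovered = nnf.unfold(result.view(1, 1, 2, 5), kernel_size=(2, 3), stride=2)

# Each column of `recovered` is one flattened 2x3 patch (row-major order).
patch_a = recovered[0, :, 0].reshape(2, 3)  # columns 0..2 of result
patch_b = recovered[0, :, 1].reshape(2, 3)  # columns 2..4 of result
print(patch_a)
print(patch_b)
```

Both patches contain the averaged value 4 in the shared column, which is why a and b are only recovered "up to the averaged values".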

Now that we understand how to get a and b from result, we can use fold to perform the "inverse" operation and go from a and b to result.
First, we need to flatten and concatenate a and b into the shape fold expects (mimicking the output of unfold: two flattened patches of 2x3 = 6 elements each):

uf = torch.cat((a.view(1, 6, 1), b.view(1, 6, 1)), dim=2)

We can now "fold" the patches

raw = nnf.fold(uf, (2,5), kernel_size=(2,3), stride=2)

We are not there yet: where patches overlap, fold sums the overlapping elements, resulting in

tensor([[[[1., 2., 8., 6., 7.],
          [1., 2., 8., 6., 7.]]]])

To count how many elements were summed for each entry in result, we can simply "fold" an all ones tensor

counter = nnf.fold(torch.ones_like(uf), (2, 5), kernel_size=(2, 3), stride=2)
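To make the role of counter concrete, here is a minimal standalone sketch that rebuilds uf from the a and b given in the question and prints the folded all-ones tensor. Nothing here goes beyond the calls already used above.

```python
import torch
from torch.nn import functional as nnf

a = torch.Tensor([[1, 2, 3], [1, 2, 3]])
b = torch.Tensor([[5, 6, 7], [5, 6, 7]])

# Two flattened 2x3 patches, stacked along the last dimension: shape (1, 6, 2).
uf = torch.cat((a.view(1, 6, 1), b.view(1, 6, 1)), dim=2)

# Folding ones counts how many patches contribute to each output position.
counter = nnf.fold(torch.ones_like(uf), (2, 5), kernel_size=(2, 3), stride=2)
print(counter)
```

Only the middle column has a count of 2, so dividing by counter halves the summed column and leaves the rest unchanged.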

And finally, we can recover result:

result = raw / counter
tensor([[[[1., 2., 4., 6., 7.],
          [1., 2., 4., 6., 7.]]]])

Piece of cake.
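For convenience, the steps above can be wrapped in one function. This is only a sketch: the name overlap_average and its overlap parameter are hypothetical, and it assumes a and b are 2D float tensors of the same shape with overlap smaller than their width.

```python
import torch
from torch.nn import functional as nnf

def overlap_average(a, b, overlap=1):
    """Join a and b side by side, averaging `overlap` shared columns (hypothetical helper)."""
    h, w = a.shape
    out_w = 2 * w - overlap
    # Mimic the output of unfold: shape (1, h*w, 2), one flattened patch per column.
    patches = torch.stack((a.reshape(-1), b.reshape(-1)), dim=1).unsqueeze(0)
    # Horizontal stride places the second patch so it overlaps the first by `overlap` columns.
    stride = (1, w - overlap)
    raw = nnf.fold(patches, (h, out_w), kernel_size=(h, w), stride=stride)
    counter = nnf.fold(torch.ones_like(patches), (h, out_w), kernel_size=(h, w), stride=stride)
    return (raw / counter)[0, 0]  # drop the batch and channel dimensions

a = torch.Tensor([[1, 2, 3], [1, 2, 3]])
b = torch.Tensor([[5, 6, 7], [5, 6, 7]])
merged = overlap_average(a, b)
print(merged)
```

With the tensors from the question this gives the 2x5 result shown above. For just two tensors, slicing and torch.cat would also work; the fold approach is mainly attractive because it extends naturally to more patches.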

Shai
  • I followed your first block of code to get `recovered`, then I already get `NotImplementedError: Input Error: Only 4D input Tensors are supported (got 2D)`. I am using PyTorch 1.6.0 – Raven Cheuk Oct 24 '20 at 03:43