
Let's say I have a tensor of shape (1, 64, 128, 128) and I want to create a tensor of shape (1, 64, 255) holding the sums of all diagonals of every (128, 128) matrix (1 main diagonal, 127 below it and 127 above it, 255 in total). What I am currently doing is the following:

import torch

x = torch.rand(1, 64, 128, 128)

diag_sums = torch.zeros(1, 64, 255)
for k in range(-127, 128):
    # offset k runs from -127 (bottom-left corner) to 127 (top-right corner)
    diag_sums[:, :, k + 127] = torch.diagonal(x, offset=k, dim1=-2, dim2=-1).sum(dim=2)

This is slow because it uses a Python loop, so the 255 offsets are not processed in parallel.

I don't think this can be done using torch.diagonal since the function explicitly uses a single int for the offset parameter. If I could pass a list there, this would work, but I guess it would be complicated to implement (requiring changes in PyTorch itself).

I think it could be possible to implement this using torch.einsum, but I cannot think of a way to do it.

So this is my question: how do I get the tensor described above?

Hristo Vrigazov

2 Answers


Have you considered using torch.nn.functional.conv2d?
You can sum the diagonals with a diagonal filter sliding across the tensor with appropriate zero padding.

import torch
import torch.nn.functional as nnf

# construct a diagonal filter using `eye` function, shape it appropriately
f = torch.eye(x.shape[2])[None, None,...].repeat(x.shape[1], 1, 1, 1)
# compute the diagonal sum with appropriate zero padding
conv_diag_sums = nnf.conv2d(x, f, padding=(x.shape[2]-1,0), groups=x.shape[1])[..., 0]

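Note that the result is in the reverse order of the one you computed in the loop: the sum for offset k ends up at index 127 - k instead of k + 127. Filling the reference tensor accordingly: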

diag_sums = torch.zeros(1, 64, 255)
for k in range(-127, 128):
    diag_sums[:, :, 127 - k] = torch.diagonal(x, offset=k, dim1=-2, dim2=-1).sum(dim=2)

# compare
(conv_diag_sums == diag_sums).all()

This evaluates to True: the two results are the same.
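For anyone who wants to try the comparison on a small input, here is a minimal self-contained sketch. The helper names `diag_sums_loop` and `diag_sums_conv` are made up for illustration, square matrices are assumed, and `torch.allclose` is used instead of `==` only to be tolerant of floating-point summation order.

```python
import torch
import torch.nn.functional as nnf


def diag_sums_loop(x):
    # Reference: one torch.diagonal call per offset, stored at index (n - 1 - k)
    n = x.shape[-1]
    out = torch.zeros(*x.shape[:-2], 2 * n - 1, dtype=x.dtype)
    for k in range(-(n - 1), n):
        out[..., n - 1 - k] = torch.diagonal(x, offset=k, dim1=-2, dim2=-1).sum(dim=-1)
    return out


def diag_sums_conv(x):
    # Depthwise convolution with an identity-matrix kernel, zero-padded along rows
    c, n = x.shape[1], x.shape[-1]
    f = torch.eye(n, dtype=x.dtype)[None, None, ...].repeat(c, 1, 1, 1)
    return nnf.conv2d(x, f, padding=(n - 1, 0), groups=c)[..., 0]


torch.manual_seed(0)
x_small = torch.rand(2, 3, 5, 5)
same = torch.allclose(diag_sums_conv(x_small), diag_sums_loop(x_small), atol=1e-6)
print(same)
```

With the 3x3 matrix [[0, 1, 2], [3, 4, 5], [6, 7, 8]] the convolution returns the sums in the order [2, 6, 12, 10, 6], i.e. from the top-right corner down to the bottom-left one.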

Shai

Shai's answer works, but it seems to perform a lot of multiplications because the kernel is as large as the matrix itself. I figured out a way to do this for my use case, based on this answer to a similar NumPy question: https://stackoverflow.com/a/35074207/6636290
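To put a rough number on "a lot of multiplications", here is a back-of-the-envelope count per channel. It assumes a dense convolution that does not exploit the zeros in the kernel or in the padding, which may not match what a particular backend actually does.

```python
n = 128                      # matrix side length from the question
outputs = 2 * n - 1          # 255 diagonals, one output position each
conv_macs = outputs * n * n  # each output position applies the full n x n kernel
direct_adds = n * n          # each element belongs to exactly one diagonal

print(conv_macs, direct_adds, conv_macs // direct_adds)
# prints: 4177920 16384 255
```

Under that assumption the convolution does about 255 times more arithmetic than summing every element once.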

I am doing the following:

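import numpy as np
import torch

# `a` is a 2-D NumPy array; i + j labels the anti-diagonal each element lies on
digitized = np.sum(np.indices(a.shape), axis=0).ravel()
digitized_tensor = torch.Tensor(digitized).int()
a_tensor = torch.Tensor(a)
torch.bincount(digitized_tensor, a_tensor.view(-1))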

It would be great to do this entirely in PyTorch (without NumPy's indices function), but this answers the question.
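One possible way to stay entirely in PyTorch is sketched below. It is not taken from the linked answer: it builds the bin labels with `torch.arange` and accumulates with `Tensor.index_add_` so that batched input is handled in one call. The function name `diag_sums_scatter` is hypothetical, square matrices are assumed, and the labels use `col - row` (the offset convention of `torch.diagonal`) shifted by `n - 1`, so the output order matches the loop in the question.

```python
import torch


def diag_sums_scatter(x):
    # x: (..., n, n); returns (..., 2n - 1) with offset k stored at index k + n - 1
    n = x.shape[-1]
    rows = torch.arange(n, device=x.device).unsqueeze(1)
    cols = torch.arange(n, device=x.device).unsqueeze(0)
    # offset of each element is col - row; shift it into the range 0 .. 2n - 2
    bins = (cols - rows + n - 1).reshape(-1)
    flat = x.reshape(*x.shape[:-2], n * n)
    out = torch.zeros(*x.shape[:-2], 2 * n - 1, dtype=x.dtype, device=x.device)
    # add every element into the slot of its diagonal, along the last dimension
    out.index_add_(out.dim() - 1, bins, flat)
    return out


x = torch.rand(1, 64, 128, 128)
diag_sums = diag_sums_scatter(x)
print(diag_sums.shape)  # torch.Size([1, 64, 255])
```

Using `rows + cols` as the label instead would group elements by anti-diagonal, which is what the sum of `np.indices` does.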

Hristo Vrigazov