
I am trying to get the neighbours of each cell of a matrix in PyTorch using the code below. It works correctly, but it is very time-consuming. Do you have any suggestions to make it faster?

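import torch
import torch.nn.functional as F

def neighbour(x):
    # zero-pad by one cell on every side so border cells also have 8 neighbours
    result = F.pad(input=x, pad=(1, 1, 1, 1), mode='constant', value=0)
    for m in range(1, x.size(0) + 1):
        for n in range(1, x.size(1) + 1):
            # the cell itself plus its 8 neighbours in the padded matrix
            y = torch.Tensor([result[m][n], result[m-1][n-1], result[m-1][n], result[m-1][n+1],
                              result[m][n-1], result[m][n+1], result[m+1][n-1], result[m+1][n], result[m+1][n+1]])
            x[m-1][n-1] = y.mean()

    return x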
atishn90

1 Answer


If you are only after the mean of the 9 elements centered at each pixel, then your best option would be to use a 2D convolution with a constant 3x3 filter:

import torch
import torch.nn.functional as nnf

def mean_filter(x_bchw):
  """
  Calculating the mean of each 3x3 neighborhood.
  input:
    - x_bchw: input tensor of dimensions batch-channel-height-width
  output:
    - y_bchw: each element in y is the average of the 9 corresponding elements in x_bchw
  """
  # define the filter
  box = torch.ones((3, 3), dtype=x_bchw.dtype, device=x_bchw.device, requires_grad=False)  
  box = box / box.sum()
  box = box[None, None, ...].repeat(x_bchw.size(1), 1, 1, 1)
  # use grouped convolution - so each channel is averaged separately.  
  y_bchw = nnf.conv2d(x_bchw, box, padding=1, groups=x_bchw.size(1))
  return y_bchw
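
Since the question works on a plain 2D matrix rather than a batch-channel-height-width tensor, here is a minimal sketch of the same idea for that case. The helper name `neighbour_mean_2d` is hypothetical, and it assumes a floating-point 2D input; the zero padding mirrors the `F.pad(..., value=0)` call in the question, so border cells are still divided by 9.

```python
import torch
import torch.nn.functional as F

def neighbour_mean_2d(x_hw):
    """Mean of each 3x3 neighbourhood of a 2D float tensor (hypothetical helper)."""
    # constant 3x3 box filter whose weights sum to 1
    box = torch.full((1, 1, 3, 3), 1.0 / 9.0, dtype=x_hw.dtype, device=x_hw.device)
    # add batch and channel dimensions, convolve with zero padding, then drop them again
    y = F.conv2d(x_hw[None, None], box, padding=1)
    return y[0, 0]

x = torch.ones(3, 3)
y = neighbour_mean_2d(x)
# the centre cell sees nine ones; a corner cell sees four ones and five padded zeros
```

Unlike the loop in the question, this returns a new tensor instead of overwriting `x` in place.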

However, if you want to apply a more elaborate function over each neighborhood, you may want to use nn.Unfold. This operation flattens each 3x3 neighborhood (or whatever rectangular neighborhood you define) into a vector. Once you have all the vectors, you can apply your function to them.
See this answer for more details on unfold and fold.
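
To illustrate the unfold approach, here is a rough sketch using the functional form `torch.nn.functional.unfold`. The helper name `apply_over_neighbourhoods` and its `reduce_fn` parameter are assumptions made for this example; it expects a floating-point 2D matrix and a function that reduces dimension 1 (the 9 neighbourhood values).

```python
import torch
import torch.nn.functional as F

def apply_over_neighbourhoods(x_hw, reduce_fn):
    """Apply reduce_fn to every zero-padded 3x3 neighbourhood (hypothetical helper)."""
    h, w = x_hw.shape
    # unfold expects a batch-channel-height-width input; result has shape (1, 9, h*w),
    # one column of 9 values per cell
    patches = F.unfold(x_hw[None, None], kernel_size=3, padding=1)
    # reduce_fn collapses the neighbourhood dimension, giving shape (1, h*w)
    out = reduce_fn(patches)
    return out.reshape(h, w)

x = torch.arange(9.0).reshape(3, 3)
local_max = apply_over_neighbourhoods(x, lambda p: p.amax(dim=1))
local_mean = apply_over_neighbourhoods(x, lambda p: p.mean(dim=1))
```

Unfolding materialises 9 values per cell, so it uses roughly nine times the memory of the input; for a plain mean the convolution is likely the cheaper choice.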

Shai