
Input

I have a torch tensor as follows.

The shape of this input_tensor is torch.Size([4, 4]).

input_tensor = 
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

I want to build a tensor that stacks up the sub-tensors obtained by moving a non-overlapping (2, 2) window over the input_tensor above.

Output

My desired output is as follows.

The shape of this output_tensor is torch.Size([8, 2]).

output = 
tensor([[ 0,  1],
        [ 4,  5],
        [ 2,  3],
        [ 6,  7],
        [ 8,  9],
        [12, 13],
        [10, 11],
        [14, 15]])
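To pin down the window order asked for here (left to right along each band of rows, then down to the next band), a minimal reference sketch using plain slicing might look like the following. The helper name stack_windows_loop and its bh/bw parameters are hypothetical, and it assumes the input's height and width are divisible by the window size:

```python
import torch

def stack_windows_loop(x: torch.Tensor, bh: int, bw: int) -> torch.Tensor:
    """Hypothetical reference helper: slide a non-overlapping (bh, bw) window
    over a 2-D tensor, left to right, then top to bottom, and stack the
    windows vertically."""
    h, w = x.shape
    # Assumption: the window tiles the input exactly.
    assert h % bh == 0 and w % bw == 0, "shape must be divisible by the window"
    windows = []
    for r in range(0, h, bh):        # move the window down (outer loop)
        for c in range(0, w, bw):    # move the window right (inner loop)
            windows.append(x[r:r + bh, c:c + bw])
    return torch.cat(windows, dim=0)

input_tensor = torch.arange(16).view(4, 4)
output = stack_windows_loop(input_tensor, 2, 2)
print(output.shape)  # torch.Size([8, 2])
```

Any faster version can be checked against this loop on small inputs.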

My code is as follows.

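import torch

input_tensor = torch.arange(16).view(4, 4)

x = list(torch.chunk(input_tensor, chunks=2, dim=0))  # two 2x4 row bands
for i, t in enumerate(x):
    # split each band into two 2x2 blocks and stack them vertically
    x[i] = torch.cat(torch.chunk(t, chunks=2, dim=1))
output_tensor = torch.cat(x)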
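The same band-then-block idea can be written as a single nested comprehension. This sketch swaps torch.chunk (which cuts into a given number of pieces) for torch.split (which cuts into pieces of a given size), so the window size becomes a parameter; bh and bw are illustrative names, and the input is assumed to be divisible by the window:

```python
import torch

input_tensor = torch.arange(16).view(4, 4)
bh, bw = 2, 2  # window height and width (assumed; the question uses 2 x 2)

# Cut the tensor into horizontal bands of bh rows, cut each band into
# bw-column blocks, stack the blocks of each band, then stack the bands.
output_tensor = torch.cat([
    torch.cat(torch.split(band, bw, dim=1))
    for band in torch.split(input_tensor, bh, dim=0)
])
```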

Is there a simpler or easier way to get the result I want?

Won chul Shin

2 Answers


You can use torch.split() together with torch.cat() as follows:

output_tensor = torch.cat(torch.split(input_tensor, 2, dim=1))

The output will be:

output = 
tensor([[ 0,  1],
        [ 4,  5],
        [ 8,  9],
        [12, 13],
        [ 2,  3],
        [ 6,  7],
        [10, 11],
        [14, 15]])
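Note that splitting only along dim=1 yields the blocks column by column ([8, 9] comes before [2, 3]), which differs from the order requested in the question. A minimal loop-free sketch that produces the requested order, assuming height and width are divisible by the window size, reshapes into a 4-D view and permutes the axes before flattening (bh and bw are illustrative names):

```python
import torch

input_tensor = torch.arange(16).view(4, 4)
bh, bw = 2, 2  # window size (assumed to divide the 4 x 4 input evenly)
h, w = input_tensor.shape

# View as (block_row, row_in_block, block_col, col_in_block) ...
blocks = input_tensor.view(h // bh, bh, w // bw, bw)
# ... then reorder to (block_row, block_col, row_in_block, col_in_block)
# so blocks are visited left to right, then top to bottom.
output = blocks.permute(0, 2, 1, 3).reshape(-1, bw)
```

The permute makes the tensor non-contiguous, so reshape (rather than view) is used; it copies the data once.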
Maryam Bahrami

What you are looking for is unfolding the tensor:

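import torch
import torch.nn.functional as nnf

# nnf.unfold works on image-like (N, C, H, W) input, hence the view to (1, 1, 4, 4)
input_tensor = torch.arange(16.).view(1, 1, 4, 4)
# The result has shape (1, 4, 4): column l is the l-th 2x2 window, flattened.
# Drop the batch dim before .T (.T on a 3-D tensor is deprecated).
output = nnf.unfold(input_tensor, kernel_size=2, stride=2, padding=0).squeeze(0).T.reshape(8, 2)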

More on unfolding and folding can be found here.
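A closely related option is the tensor method Tensor.unfold(dimension, size, step), which works directly on the 2-D tensor without adding batch and channel dimensions. A minimal sketch, assuming the same 4 x 4 input and 2 x 2 non-overlapping windows:

```python
import torch

input_tensor = torch.arange(16).view(4, 4)

# Tensor.unfold(dimension, size, step) appends a new trailing axis of length
# `size`. Unfolding rows, then columns, gives shape (2, 2, 2, 2) where
# windows[i, j] is the 2 x 2 block at block-row i, block-column j.
windows = input_tensor.unfold(0, 2, 2).unfold(1, 2, 2)
output = windows.reshape(-1, 2)
```

Because windows[i, j, k, m] equals input_tensor[2*i + k, 2*j + m], flattening the first three axes visits the blocks left to right, then top to bottom.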

Shai