
I don't understand what squeeze and unsqueeze do to a tensor, even after looking at the docs and related questions.

I tried to understand it by exploring it myself in Python. I first created a random tensor with

>>> x = torch.rand(3, 2, dtype=torch.float)
>>> x
tensor([[0.3703, 0.9588],
        [0.8064, 0.9716],
        [0.9585, 0.7860]])

But regardless of how I squeeze it, I end up with the same results:

>>> torch.equal(x.squeeze(0), x.squeeze(1))
True

If I now try to unsqueeze it, I get the following:

>>> x.unsqueeze(1)
tensor([[[0.3703, 0.9588]],
        [[0.8064, 0.9716]],
        [[0.9585, 0.7860]]])
>>> x.unsqueeze(0)
tensor([[[0.3703, 0.9588],
         [0.8064, 0.9716],
         [0.9585, 0.7860]]])
>>> x.unsqueeze(-1)
tensor([[[0.3703],
         [0.9588]],
        [[0.8064],
         [0.9716]],
        [[0.9585],
         [0.7860]]])

However, if I now create a tensor x = torch.tensor([1,2,3,4]) and unsqueeze it, it appears that 1 and -1 make it a column, whereas 0 leaves it the same.

>>> x.unsqueeze(0)
tensor([[1, 2, 3, 4]])
>>> x.unsqueeze(1)
tensor([[1],
        [2],
        [3],
        [4]])
>>> x.unsqueeze(-1)
tensor([[1],
        [2],
        [3],
        [4]])

Can someone explain what squeeze and unsqueeze do to a tensor? And what's the difference between passing the arguments 0, 1 and -1?

iacob
Mark Shaio
  • Does this answer your question? [What does "unsqueeze" do in Pytorch?](https://stackoverflow.com/questions/57237352/what-does-unsqueeze-do-in-pytorch) – iacob Jan 21 '21 at 01:28
  • Note: `-1` is just an alias for the final dimension, i.e. `1` in a 2d tensor. – iacob Feb 14 '21 at 17:49
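For unsqueeze in particular, a negative dim is counted against the result, which has one more dimension than the input, so -1 always places the new axis last. A small sketch using tensors shaped like the ones in the question (the values themselves are arbitrary):

```python
import torch

x1 = torch.tensor([1, 2, 3, 4])  # 1-d, shape [4]
x2 = torch.rand(3, 2)            # 2-d, shape [3, 2]

# For unsqueeze, a negative dim counts from the end of the *output* shape,
# which has x.dim() + 1 dimensions, so dim=-1 means "append at the end".
print(x1.unsqueeze(-1).shape)  # torch.Size([4, 1]), same as unsqueeze(1)
print(x2.unsqueeze(-1).shape)  # torch.Size([3, 2, 1]), same as unsqueeze(2)
print(x2.unsqueeze(-2).shape)  # torch.Size([3, 1, 2]), same as unsqueeze(1)

assert torch.equal(x1.unsqueeze(-1), x1.unsqueeze(1))
assert torch.equal(x2.unsqueeze(-1), x2.unsqueeze(2))
```

So on the 1d tensor from the question, -1 and 1 coincide, but on a 2d tensor -1 corresponds to 2.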

3 Answers


Here is a visual representation of what squeeze/unsqueeze do for an effectively 2d matrix:

[Figure: squeeze/unsqueeze applied to an effectively 2d matrix]

When you unsqueeze a tensor, it is ambiguous which dimension you want to 'unsqueeze' it across (as a row, a column, etc.). The dim argument resolves this: it is the position of the new size-1 dimension in the output.

Hence the resulting unsqueezed tensors have the same information, but the indices used to access them are different.
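A small sketch of that point, using an arbitrary 3x2 tensor: each unsqueeze keeps every element, and only the position of the extra index 0 changes.

```python
import torch

x = torch.arange(6).reshape(3, 2)  # shape [3, 2]

a = x.unsqueeze(0)  # shape [1, 3, 2]: new axis in front
b = x.unsqueeze(1)  # shape [3, 1, 2]: new axis in the middle
c = x.unsqueeze(2)  # shape [3, 2, 1]: new axis at the end

# Every element is still there; only where the extra 0 index goes changes.
for i in range(3):
    for j in range(2):
        assert a[0, i, j] == x[i, j]
        assert b[i, 0, j] == x[i, j]
        assert c[i, j, 0] == x[i, j]

# The element count is unchanged, and squeezing the new axis gives x back.
assert a.numel() == b.numel() == c.numel() == x.numel()
assert torch.equal(b.squeeze(1), x)
```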

iacob

Simply put, unsqueeze() "adds" a superficial size-1 dimension to the tensor (at the specified dimension), while squeeze() removes all superficial size-1 dimensions from the tensor (or, if you pass dim, only that dimension, and only if its size is 1).

Look at the tensor's shape attribute to see this easily. In your last case it would be:

import torch

tensor = torch.tensor([1, 0, 2, 3, 4])
tensor.shape # torch.Size([5])
tensor.unsqueeze(dim=0).shape # [1, 5]
tensor.unsqueeze(dim=1).shape # [5, 1]

unsqueeze is useful for feeding a single sample to a network (which expects the first dimension to be the batch). For images it would be:

# 3 channels, 32 height, 32 width
tensor = torch.randn(3, 32, 32)
# 1 batch, 3 channels, 32 height, 32 width
tensor.unsqueeze(dim=0).shape # [1, 3, 32, 32]
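To illustrate the batch dimension, here is a rough sketch with three made-up random "images" (the `images` list is purely illustrative): unsqueezing each one into a batch of one and concatenating along dim 0 gives the same result as torch.stack.

```python
import torch

# Three hypothetical single images, each shaped [channels, height, width].
images = [torch.randn(3, 32, 32) for _ in range(3)]

# unsqueeze(0) turns each [3, 32, 32] sample into a batch of one: [1, 3, 32, 32].
single = images[0].unsqueeze(0)
print(single.shape)  # torch.Size([1, 3, 32, 32])

# Concatenating those batches-of-one along dim 0 builds a real batch,
# which is exactly what torch.stack does in a single call.
batch = torch.cat([img.unsqueeze(0) for img in images], dim=0)
print(batch.shape)  # torch.Size([3, 3, 32, 32])
assert torch.equal(batch, torch.stack(images, dim=0))
```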

squeeze can be seen in action if you create a tensor with size-1 dimensions, e.g. like this:

# 3 channels, 32 height, 32 width and some unnecessary size-1 dimensions
tensor = torch.randn(3, 1, 32, 1, 32, 1)
# 1 batch, 3 channels, 32 height, 32 width again
tensor.squeeze().unsqueeze(0).shape # [1, 3, 32, 32]
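This also explains the result in the question: squeeze(dim) only removes that dimension when its size is 1, and otherwise returns a tensor of the same shape. A 3x2 tensor has no size-1 dimension, so squeeze(0) and squeeze(1) both leave it unchanged. A short check:

```python
import torch

x = torch.rand(3, 2)  # no size-1 dimensions at all

# squeeze(dim) only removes dimension `dim` if its size is 1; otherwise it
# returns a tensor with the same shape, so both calls below are no-ops.
assert x.squeeze(0).shape == (3, 2)
assert x.squeeze(1).shape == (3, 2)
assert torch.equal(x.squeeze(0), x.squeeze(1))

y = torch.randn(3, 1, 32, 1, 32, 1)
print(y.squeeze().shape)   # torch.Size([3, 32, 32]): all size-1 dims removed
print(y.squeeze(1).shape)  # torch.Size([3, 32, 1, 32, 1]): only dim 1 removed
print(y.squeeze(0).shape)  # torch.Size([3, 1, 32, 1, 32, 1]): dim 0 has size 3
```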
iacob
Szymon Maszke
  1. torch.unsqueeze(input, dim) → Tensor

    >>> a = torch.randn(4, 4, 4)
    >>> torch.unsqueeze(a, 0).size()
    torch.Size([1, 4, 4, 4])
    >>> torch.unsqueeze(a, 1).size()
    torch.Size([4, 1, 4, 4])
    >>> torch.unsqueeze(a, 2).size()
    torch.Size([4, 4, 1, 4])
    >>> torch.unsqueeze(a, 3).size()
    torch.Size([4, 4, 4, 1])
    
  2. torch.squeeze(input, dim=None, out=None) → Tensor

    >>> b = torch.randn(4, 1, 4)
    >>> b
    tensor([[[ 1.2912, -1.9050,  1.4771,  1.5517]],

            [[-0.3359, -0.2381, -0.3590,  0.0406]],

            [[-0.2460, -0.2326,  0.4511,  0.7255]],

            [[-0.1456, -0.0857, -0.8443,  1.1423]]])
    >>> b.size()
    torch.Size([4, 1, 4])
    >>> c = b.squeeze(1)
    >>> b
    tensor([[[ 1.2912, -1.9050,  1.4771,  1.5517]],

            [[-0.3359, -0.2381, -0.3590,  0.0406]],

            [[-0.2460, -0.2326,  0.4511,  0.7255]],

            [[-0.1456, -0.0857, -0.8443,  1.1423]]])
    >>> b.size()
    torch.Size([4, 1, 4])
    >>> c
    tensor([[ 1.2912, -1.9050,  1.4771,  1.5517],
            [-0.3359, -0.2381, -0.3590,  0.0406],
            [-0.2460, -0.2326,  0.4511,  0.7255],
            [-0.1456, -0.0857, -0.8443,  1.1423]])
    >>> c.size()
    torch.Size([4, 4])
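The size patterns above follow a simple rule. As a sketch, the two helpers below (`unsqueeze_shape` and `squeeze_shape` are hypothetical names, not part of PyTorch) predict the resulting shape in plain Python, and the loops compare them against the real calls:

```python
import torch

def unsqueeze_shape(shape, dim):
    """Hypothetical helper: predict the shape produced by unsqueeze(dim)."""
    ndim = len(shape)
    if not -ndim - 1 <= dim <= ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    if dim < 0:
        dim += ndim + 1  # negative dims count from the end of the output
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])

def squeeze_shape(shape, dim=None):
    """Hypothetical helper: predict the shape produced by squeeze(dim)."""
    if dim is None:
        return tuple(s for s in shape if s != 1)  # drop every size-1 dim
    if dim < 0:
        dim += len(shape)
    if shape[dim] != 1:
        return tuple(shape)  # not size 1: squeeze is a no-op
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])

a = torch.randn(4, 4, 4)
for d in range(-4, 4):  # valid unsqueeze dims for a 3-d tensor
    assert a.unsqueeze(d).shape == unsqueeze_shape(a.shape, d)

b = torch.randn(4, 1, 4)
for d in (None, 0, 1, 2, -1, -2):
    got = b.squeeze() if d is None else b.squeeze(d)
    assert got.shape == squeeze_shape(b.shape, d)
```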
    
iacob
tomgtbst