I am trying to create a transform that shuffles the patches of each image in a batch.
I aim to use it in the same manner as the rest of the transformations in torchvision:
from torchvision import transforms

trans = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ShufflePatches(patch_size=(16, 16)),  # our new transform
])
More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size patch_size, shuffle them, and regroup into a single image.
Given the image (of size 224x224):
Using ShufflePatches(patch_size=(112,112)) I would like to produce the output image:
I think the solution has to do with torch.nn.functional.unfold and torch.nn.functional.fold, but I didn't manage to get any further.
Any help would be appreciated!
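To make the intended behaviour concrete, here is a minimal sketch of the unfold/fold idea. It is not a tested or definitive implementation. It assumes a float tensor of shape BxCxHxW whose height and width are divisible by the patch size, and the choice of an independent permutation per image is my own assumption about what "shuffle" should mean:

```python
import torch
import torch.nn.functional as F


class ShufflePatches:
    """Shuffle non-overlapping patches of each image in a BxCxHxW batch."""

    def __init__(self, patch_size):
        self.patch_size = patch_size  # (patch_height, patch_width)

    def __call__(self, x):
        # Assumption: x is a float tensor of shape (B, C, H, W), with H and W
        # divisible by the patch height and width.
        b, c, h, w = x.shape
        # Shape (B, C*ph*pw, L): one column per patch, L = (H/ph) * (W/pw).
        patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size)
        # Draw an independent random permutation of the patch columns per image.
        shuffled = torch.stack(
            [p[:, torch.randperm(p.shape[-1])] for p in patches]
        )
        # Patches do not overlap, so fold places each one back without summing.
        return F.fold(shuffled, output_size=(h, w),
                      kernel_size=self.patch_size, stride=self.patch_size)


if __name__ == "__main__":
    batch = torch.rand(4, 3, 224, 224)
    out = ShufflePatches(patch_size=(112, 112))(batch)
    print(out.shape)
```

One caveat: inside transforms.Compose, the tensor coming out of ToTensor() is a single CxHxW image, so this sketch would need an unsqueeze(0) before the call and a squeeze(0) after it to be used there.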