
I'm working on unsupervised learning techniques and I've been reading about the contrastive loss function. Specifically, the paper Momentum Contrast for Unsupervised Visual Representation Learning (MoCo) describes the loss function mathematically as:

$$\mathcal{L}_q = -\log \frac{\exp(q \cdot k_+ / \tau)}{\sum_{i=0}^{K} \exp(q \cdot k_i / \tau)} \qquad (1)$$

However, the pseudocode they give is as follows:

# f_q, f_k: encoder networks for query and key
# queue: dictionary as a queue of K keys (CxK) 
# m: momentum
# t: temperature
f_k.params = f_q.params # initialize
for x in loader: # load a minibatch x with N samples
    x_q = aug(x) # a randomly augmented version
    x_k = aug(x) # another randomly augmented version
    q = f_q.forward(x_q) # queries: NxC 
    k = f_k.forward(x_k) # keys: NxC
    k = k.detach() # no gradient to keys

    # positive logits: Nx1
    l_pos = bmm(q.view(N,1,C), k.view(N,C,1))

    # negative logits: NxK
    l_neg = mm(q.view(N,C), queue.view(C,K))

    # logits: Nx(1+K)
    logits = cat([l_pos, l_neg], dim=1)

    # contrastive loss, Eqn.(1)
    labels = zeros(N) # positives are the 0-th
    loss = CrossEntropyLoss(logits/t, labels)
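For reference, here is a minimal runnable PyTorch sketch of just the loss part of that pseudocode. It assumes random L2-normalized tensors stand in for the encoder outputs and the queue (the encoders, augmentation, momentum update and queue update are all left out), and the sizes N, C, K are arbitrary. Two small differences from the pseudocode: `bmm` actually returns an N x 1 x 1 tensor, so it is reshaped to N x 1, and the labels are created as integers because `F.cross_entropy` with class-index targets expects a long tensor, while `torch.zeros(N)` defaults to float.

```python
import torch
import torch.nn.functional as F

def moco_logits(q, k, queue):
    """q, k: N x C normalized features; queue: C x K negative keys.
    Returns logits of shape N x (1+K) with the positive in column 0."""
    N, C = q.shape
    # positive logits: one dot product per (query, key) pair -> N x 1
    l_pos = torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1)
    # negative logits: every query against every queued key -> N x K
    l_neg = torch.mm(q, queue)
    # concatenate so column 0 is the positive, columns 1..K the negatives
    return torch.cat([l_pos, l_neg], dim=1)

def contrastive_loss(q, k, queue, t=0.07):
    logits = moco_logits(q, k, queue)
    # the positive is always in column 0, so every target is class 0
    labels = torch.zeros(q.shape[0], dtype=torch.long)
    return F.cross_entropy(logits / t, labels)

# hypothetical sizes, for illustration only
N, C, K = 4, 8, 16
torch.manual_seed(0)
q = F.normalize(torch.randn(N, C), dim=1)
k = F.normalize(torch.randn(N, C), dim=1)
queue = F.normalize(torch.randn(C, K), dim=0)  # each column is one key

print(moco_logits(q, k, queue).shape)  # torch.Size([4, 17])
print(contrastive_loss(q, k, queue))
```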

I don't understand how the math and the pseudocode are related. I get that the bmm and mm functions compute dot products, and that l_pos corresponds to the numerator (the query with its single positive key) while l_neg corresponds to the negative terms in the denominator (the query with the keys from the queue). Why are those two tensors concatenated and then compared against labels of all zeros? Any insight is greatly appreciated.

Brian
  • 141

Answer

I've dug into this and I get what's going on now.

The big idea is that you can treat this just like classification. If you had K+1 categories and a one-hot label for the true category, you'd want the predicted score for the true category to be large and all the others to be small.

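If the vectors q and k+ are close together, their dot product (l_pos) will be large. The dot products of q with the keys in the queue (l_neg) will be small if those keys are far from q. l_pos is the logit for the true class, and l_neg holds the logits for the other K classes. So we concatenate l_pos and l_neg into a single Nx(1+K) tensor with the positive in column 0, and set every label to 0. CrossEntropyLoss then applies the softmax and the negative log-likelihood. Written out for one query, cross-entropy with target 0 on logits/τ is -log(exp(q·k+/τ) / Σ_{i=0..K} exp(q·k_i/τ)), which is exactly Eq. (1). The denominator sums over the positive as well as the negatives, which is why l_pos has to sit in the same logits tensor as l_neg.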
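To check that the cross-entropy call really is Eq. (1), the sketch below computes the loss both ways for a single query: once with `torch.nn.functional.cross_entropy` and target 0, and once directly from the formula. The logit values are made up for illustration; they stand in for q·k+ followed by three q·k_i negatives.

```python
import torch
import torch.nn.functional as F

t = 0.07  # temperature (tau in the paper)
# made-up logits for one query: the positive first, then K = 3 negatives
logits = torch.tensor([[0.9, 0.2, -0.1, 0.4]])

# what the pseudocode does: cross-entropy against class 0
ce = F.cross_entropy(logits / t, torch.tensor([0]))

# Eq. (1) written out: -log( exp(q.k+/t) / sum_i exp(q.k_i/t) ),
# where the sum runs over the positive AND all negatives
num = torch.exp(logits[0, 0] / t)
den = torch.exp(logits[0] / t).sum()
eq1 = -torch.log(num / den)

print(ce.item(), eq1.item())  # the two values agree up to float error
```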

I convinced myself further with a small example:

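import numpy as np
import torch

def norm_tensor(x):
    # L2-normalize a single row vector and convert it to a torch tensor
    x = np.array(x)
    x = x / np.linalg.norm(x)
    return torch.tensor(x)

t = 0.07  # temperature
q = norm_tensor([[1.0, 0.01, 0.0]])
k_p = norm_tensor([[1.0, 0.02, 0.0]])  # positive key, close to q
k_neg = norm_tensor([[1, 1, 1]])       # negative key, farther from q
print(q, k_p, k_neg)
# q     -> [[1.0000, 0.0100, 0.0000]]
# k_p   -> [[0.9998, 0.0200, 0.0000]]
# k_neg -> [[0.5774, 0.5774, 0.5774]]

pos = torch.mm(q, k_p.transpose(1, 0))    # positive logit, 1x1
neg = torch.mm(q, k_neg.transpose(1, 0))  # negative logit, 1x1
print(pos, neg)
# pos -> [[1.0000]]
# neg -> [[0.5831]]

logits = torch.cat((pos, neg), dim=1)  # 1x2, positive in column 0
labels = torch.tensor([0])             # the positive is class 0
loss_func = torch.nn.CrossEntropyLoss()
print(loss_func(logits / t, labels))
# loss -> 0.0026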

The loss is small, as expected, because q is close to k+ and farther away from the negative key.
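As a counter-check, here is a sketch of the same toy setup with the roles swapped, so that label 0 points at the far-away key. It uses `F.normalize` in place of the `norm_tensor` helper above (equivalent for these single-row inputs); `loss_for` is a hypothetical helper introduced only for this comparison. The loss should now be large instead of near zero.

```python
import torch
import torch.nn.functional as F

t = 0.07
q = F.normalize(torch.tensor([[1.0, 0.01, 0.0]]), dim=1)
near = F.normalize(torch.tensor([[1.0, 0.02, 0.0]]), dim=1)
far = F.normalize(torch.tensor([[1.0, 1.0, 1.0]]), dim=1)

def loss_for(pos_key, neg_key):
    # positive logit in column 0, negative logit in column 1
    logits = torch.cat((q @ pos_key.T, q @ neg_key.T), dim=1)
    return F.cross_entropy(logits / t, torch.tensor([0]))

good = loss_for(near, far)  # the "positive" really is the close key
bad = loss_for(far, near)   # label 0 now points at the far key
print(good.item(), bad.item())
```

Because the logits are divided by a small temperature, the gap between the two dot products is amplified, so getting the ordering wrong is penalized heavily.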
