I'm working on unsupervised learning techniques and I've been reading about the contrastive loss function. Specifically, in the paper Momentum Contrast for Unsupervised Visual Representation Learning (MoCo), the loss function is described mathematically as:

$$\mathcal{L}_q = -\log \frac{\exp(q \cdot k_+ / \tau)}{\sum_{i=0}^{K} \exp(q \cdot k_i / \tau)}$$

where $q$ is an encoded query, $k_+$ is the single positive key it should match, the sum runs over that positive key plus the $K$ negative keys in the queue, and $\tau$ is a temperature hyper-parameter.
However, the pseudocode they show is as follows:
# f_q, f_k: encoder networks for query and key
# queue: dictionary as a queue of K keys (CxK)
# m: momentum
# t: temperature
f_k.params = f_q.params # initialize
for x in loader: # load a minibatch x with N samples
    x_q = aug(x) # a randomly augmented version
    x_k = aug(x) # another randomly augmented version
    q = f_q.forward(x_q) # queries: NxC
    k = f_k.forward(x_k) # keys: NxC
    k = k.detach() # no gradient to keys
    # positive logits: Nx1
    l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
    # negative logits: NxK
    l_neg = mm(q.view(N,C), queue.view(C,K))
    # logits: Nx(1+K)
    logits = cat([l_pos, l_neg], dim=1)
    # contrastive loss, Eqn.(1)
    labels = zeros(N) # positives are the 0-th
    loss = CrossEntropyLoss(logits/t, labels)
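To make the shapes concrete for myself, I tried a minimal runnable version of those few lines with random tensors standing in for the encoder outputs and the queue (the values of N, C and K below are made up, and I added the explicit .view(N, 1) that the pseudocode glosses over):

import torch
import torch.nn.functional as F

N, C, K = 4, 128, 1024                          # made-up batch size, feature dim, queue length
q = F.normalize(torch.randn(N, C), dim=1)       # stands in for f_q.forward(x_q), queries NxC
k = F.normalize(torch.randn(N, C), dim=1)       # stands in for f_k.forward(x_k), keys NxC
queue = F.normalize(torch.randn(C, K), dim=0)   # stands in for the queue of K keys, CxK

# positive logits: one dot product per sample -> Nx1
l_pos = torch.bmm(q.view(N, 1, C), k.view(N, C, 1)).view(N, 1)
# negative logits: every query against every queued key -> NxK
l_neg = torch.mm(q, queue)

logits = torch.cat([l_pos, l_neg], dim=1)       # Nx(1+K)
labels = torch.zeros(N, dtype=torch.long)       # "positives are the 0-th"
t = 0.07                                        # temperature
loss = F.cross_entropy(logits / t, labels)
print(l_pos.shape, l_neg.shape, logits.shape, loss.item())

The shapes match the comments in the pseudocode (Nx1, NxK and Nx(1+K)), so each row of logits holds the positive similarity in column 0 followed by the K similarities against the queue.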
I don't understand how the math and the pseudocode are related. I get that the bmm and mm calls compute the dot products. I also see that l_pos corresponds to the numerator (the query with its single positive key), while l_neg corresponds to the query against the keys from the queue that appear in the denominator. Why are those two tensors concatenated into one logits matrix and then compared against a vector of zeros? Any insight is greatly appreciated.
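To spell out where I get stuck: my reading is that, for a single query $q$, the row that goes into the loss is

$$\text{logits}/t = \left[\, q\cdot k_+/t,\;\; q\cdot k_1/t,\;\; \dots,\;\; q\cdot k_K/t \,\right],$$

and the label 0 marks the first entry as the "correct class" (the comment says the positives are the 0-th). What I don't see is how CrossEntropyLoss over that row, with target 0, turns into the $-\log$ fraction of Eqn. (1).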
