
I'm following the Bi-Encoder architecture (see here) to build a dense retrieval (search) system. Formally, my network encodes a query q and an item description d from fixed Sentence Transformers representations, denoted SBERT(q) and SBERT(d), respectively.

It then learns a transformation (the 'pooling' in the picture below) that maximizes the cosine similarity between positive examples (where the query and item description match) and minimizes the similarity between negative examples (randomly assigned query/description pairs). I use an MSE loss.

Bi-Encoder schematic representation
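
For concreteness, here is a minimal sketch of the setup described above (a PyTorch illustration, not my exact code; the class name, layer sizes, and dimensions are placeholders):

    import torch
    import torch.nn as nn

    class ProjectionHead(nn.Module):
        """Small trainable transformation applied on top of the frozen SBERT embeddings."""
        def __init__(self, in_dim=768, out_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(in_dim, 512),
                nn.ReLU(),
                nn.Linear(512, out_dim),
            )

        def forward(self, x):
            return self.net(x)

    head = ProjectionHead()
    loss_fn = nn.MSELoss()

    def step(query_emb, desc_emb, labels):
        # query_emb, desc_emb: precomputed, frozen SBERT(q) and SBERT(d) batches
        # labels: 1.0 for matching query/description pairs, 0.0 for random pairs
        q = head(query_emb)
        d = head(desc_emb)
        sim = nn.functional.cosine_similarity(q, d, dim=-1)
        return loss_fn(sim, labels)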

Now, when I train my network, I observe that it always converges to producing a cosine similarity of 0.5 for all examples, provided my labels are evenly split between 0 and 1. If I adjust the balance of positive/negative examples, it converges to whatever constant value minimizes the MSE loss, still producing (nearly) the same output for both positive and negative examples.
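
As a sanity check of the symptom (not my training code): a constant prediction of 0.5 is exactly what minimizes the MSE when the 0/1 labels are balanced, which gives a loss of 0.25:

    import torch

    labels = torch.tensor([0., 1.] * 500)           # balanced labels
    constant_pred = torch.full_like(labels, 0.5)    # collapsed model output
    print(torch.mean((constant_pred - labels) ** 2).item())  # 0.25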

What could be going wrong? My dataset isn't the largest, only a few thousand examples. I would say that the queries are fairly semantically related to the descriptions, so it shouldn't be too hard to learn this mapping. The offline-computed sentence representations for the queries and descriptions also look reasonable. I have tried smaller and bigger networks for the pooling transformations, all with the same effect.

joko
  • Assuming there really is enough to discriminate positive and negative examples, cases like these are usually a problem with optimization. (1) What happens when you decrease the learning rate? (2) Does the 0.5 metric apply to training, validation, or both? (3) What does training loss look like over epochs? (4) In your data setup—"negative examples (randomly assigned query/description pairs)"—what is the chance that a negative pair is actually a positive? What happens if you enforce that the queries and descriptions do not match? – chicxulub May 12 '23 at 12:27
  • "It then learns a transformation (the 'pooling' in the picture below)" is not right. The pooling operation is usually a mean over token embeddings, so there's nothing to learn. It's the (tied) weights of BERT that are adjusted during training. – chicxulub May 12 '23 at 12:29
  • (5) Can you explain "The offline-computed sentence representations for the queries and descriptions also look reasonable." a bit more? Are these embeddings computed before or after training? And what makes you say they're reasonable—are you saying the precision/recall of your model on the query-description task is high? (IME, trained SBERT embeddings are more isotropic. So I'd be surprised if it's consistently predicting 0.5 similarity while precision/recall are high.) – chicxulub May 12 '23 at 12:35
  • In case it becomes clear that the training code cannot be improved, or if you're curious about simpler approaches, I'd recommend that you evaluate the zero-shot performance of an InstructOR model on your data and task. Here's an example of using an InstructOR model for a retrieval task. – chicxulub May 12 '23 at 12:48
  • Hi @chicxulub, thanks for the hints and sorry for the slow response -- was offline over the weekend. Your comment on the learning rate helped me! In fact, my LR was too low (at 1e-5), which I assume caused the model to end up in a local minimum (of assigning all inputs the same score). So I increased the LR to 2e-3 and the model started learning. Thanks! – joko May 16 '23 at 11:04
  • Regarding (some of) your other comments -- yes, the loss converged to 0.25 (0.5 squared, because of MSE) for both train and validation. And regarding the transformation vs. pooling -- sorry, the image is misleading; I actually do have a small number of fully connected layers transforming the precomputed (and frozen) embeddings into lower dimensions. – joko May 16 '23 at 11:07
  • Thanks for the update @joko. I'm surprised to hear that the learning rate had to be increased, but such is the nature of training NNs. Glad to hear that you have a working solution now! – chicxulub May 16 '23 at 18:11

1 Answer


This was solved by increasing the learning rate from 1e-5 to 2e-3. With the lower learning rate, the model got stuck in a local minimum where it assigned (roughly) the same score to every pair.
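
For reference, the change amounts to the following (a sketch, assuming PyTorch and the small fully connected projection layers mentioned in the comments; the optimizer and layer sizes are illustrative):

    import torch
    import torch.nn as nn

    # Small trainable projection over the precomputed, frozen SBERT embeddings
    head = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 256))

    # Too low -- the model collapsed to ~0.5 cosine similarity for every pair:
    # optimizer = torch.optim.Adam(head.parameters(), lr=1e-5)

    # Higher learning rate -- the model starts to separate positives from negatives:
    optimizer = torch.optim.Adam(head.parameters(), lr=2e-3)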

joko