
I have a multi-label classification task in which each sample can belong to any of 7 categories. I train a multi-label classification model (i.e., XLMRobertaForSequenceClassification). When I evaluate the model on a test set, I see that one specific category has the highest score for most examples, so the confusion matrix has lots of zeros except for the column of that category. I'm wondering what can contribute to this. Is it because the training set has more samples with a positive label from that specific category? Here are the statistics of the categories in the training set (in total, there are ~14M samples in the training set):

cat1: 112,803

cat2: 859,448

cat3: 155,382

cat4: 13,816

cat5: 34,242

cat6: 36,104

cat7: 1,206,626

And here are the average scores predicted by the model on the test set for each category:

cat1: 0.03

cat2: 0.07

cat3: 0.04

cat4: 0.03

cat5: 0.07

cat6: 0.11

cat7: 0.29

My concern is that the average scores of the categories do not align with the frequency of the samples in the training set. For instance, 'cat2' has the second-highest frequency, but its average score is smaller than that of 'cat6', which has far fewer samples in the training set. Is there any other reason, in addition to the frequency of categories in the training set, contributing to the unbalanced predictions? (Note that the test set has an almost identical number of samples from each category.)
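For concreteness, here is a minimal sketch (plain Python, using the counts and average scores quoted above, and treating the ~14M training-set size as exact) that puts each category's training-set positive rate next to its average predicted score:

```python
# Compare each category's positive rate in the training set with the
# average predicted score reported for the test set (figures from above).
train_counts = {
    "cat1": 112_803, "cat2": 859_448, "cat3": 155_382, "cat4": 13_816,
    "cat5": 34_242, "cat6": 36_104, "cat7": 1_206_626,
}
avg_scores = {
    "cat1": 0.03, "cat2": 0.07, "cat3": 0.04, "cat4": 0.03,
    "cat5": 0.07, "cat6": 0.11, "cat7": 0.29,
}
n_train = 14_000_000  # approximate total number of training samples

for cat, count in train_counts.items():
    rate = count / n_train
    print(f"{cat}: train positive rate = {rate:.4f}, avg test score = {avg_scores[cat]:.2f}")
```

Running this shows, for example, that cat2 is positive in roughly 6% of training samples but averages a score of 0.07, while cat6 is positive in under 0.3% of samples yet averages 0.11, which is the mismatch described above.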

A User

1 Answer


The exact mechanism will be hard to pin down, but you are correct to note that the generally high occurrence of that category drives it to be predicted most of the time. Think of it this way: in the absence of compelling evidence, which category would you predict is most likely? You can formalize this by treating the predicted probability as a posterior probability that is calculated from, among other things, the overall class probability, the “prior” I discuss here. (That link mentions a logistic regression as the model, but the idea applies to other models, too.)
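As a toy illustration of that point (not code from the linked answer; the likelihoods below are made up, and the problem is treated as single-label only to keep the Bayes arithmetic simple): when the evidence barely distinguishes the categories, the posterior collapses onto the prior, which here is dominated by cat7.

```python
import numpy as np

# Toy illustration of posterior ∝ likelihood × prior. The priors are the
# training counts from the question, normalized; the likelihoods are invented.
prior = np.array([112_803, 859_448, 155_382, 13_816,
                  34_242, 36_104, 1_206_626], dtype=float)
prior /= prior.sum()

def posterior(likelihood, prior):
    unnorm = likelihood * prior
    return unnorm / unnorm.sum()

# Weak evidence: the likelihoods barely differ, so the posterior stays close
# to the prior and the most frequent category (cat7) wins by default.
weak = np.array([0.50, 0.52, 0.49, 0.51, 0.50, 0.50, 0.51])
print(posterior(weak, prior).round(3))

# Strong evidence for the rarest category (cat4): it takes roughly a 1000:1
# likelihood ratio to overcome cat7's ~87:1 prior advantage over cat4.
strong = np.array([0.001, 0.001, 0.001, 0.999, 0.001, 0.001, 0.001])
print(posterior(strong, prior).round(3))
```

With seven independent sigmoid outputs the bookkeeping differs a little from this single-label toy, but the qualitative effect is the same: when the input gives the model little to go on, the score drifts toward the label's base rate.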

An explanation for why you do not see similar behavior for category two could be that category two is harder to distinguish from the other categories (perhaps it is even hard to distinguish from category seven, so the model goes with the generally more probable category).
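One way to probe that hypothesis is to look at how the model scores cat2 and cat7 together on the test set. A minimal sketch, assuming you keep the per-example probabilities in an array `scores` of shape `(n_examples, 7)` and the binary ground truth in `labels` of the same shape (both names are hypothetical placeholders):

```python
import numpy as np

# `scores` and `labels` stand in for your evaluation outputs: predicted
# probabilities and binary ground truth, both shaped (n_examples, 7).
# Random data is used here only so the snippet runs on its own.
rng = np.random.default_rng(0)
scores = rng.random((1000, 7))                       # replace with real predictions
labels = (rng.random((1000, 7)) > 0.9).astype(int)   # replace with real labels

CAT2, CAT7 = 1, 6  # zero-based column indices for cat2 and cat7

# On examples whose true labels include cat2, where does the probability mass go?
is_cat2 = labels[:, CAT2] == 1
print("mean cat2 score on true-cat2 examples:", scores[is_cat2, CAT2].mean())
print("mean cat7 score on true-cat2 examples:", scores[is_cat2, CAT7].mean())

# A high correlation between the two score columns also suggests the model
# struggles to tell the two categories apart.
print("corr(cat2, cat7):", np.corrcoef(scores[:, CAT2], scores[:, CAT7])[0, 1])
```

If the cat7 scores on true-cat2 examples rival or exceed the cat2 scores, that supports the explanation that cat2 is being absorbed by the generally more probable cat7.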

Dave
  • Somewhat tangential to the main question: “multi-label” is the terminology typically used when several categorical outcomes are simultaneously possible (e.g., the picture contains a dog AND a cat, not just one type of animal). This appears to be what is typically called a “multi-class” problem. – Dave Aug 30 '23 at 01:56
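For reference, a minimal sketch of the distinction the comment draws (not from the thread; the logits are invented): a multi-label head scores each category independently with a sigmoid, while a multi-class head couples the categories through a softmax so exactly one is chosen.

```python
import torch

logits = torch.tensor([1.2, -0.3, 0.5])  # invented scores for three categories

# Multi-label: each category gets an independent sigmoid probability, so
# several can be positive at once (the picture contains a dog AND a cat).
multi_label = torch.sigmoid(logits)         # ~[0.77, 0.43, 0.62]; sum is unconstrained

# Multi-class: exactly one category applies, so the probabilities are coupled
# through a softmax and sum to 1.
multi_class = torch.softmax(logits, dim=0)  # sums to 1

print(multi_label, multi_label.sum())
print(multi_class, multi_class.sum())
```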