
I have a UNet segmentation model that outputs 5 classes, and I would like to find the optimal threshold value for each class using the precision-recall curve:

import numpy as np
from sklearn.metrics import precision_recall_curve

def optimal_threshold_precision_recall_curve(Y_orig, y_pred, classes=5):
    for c in range(classes):
        precision, recall, thresholds = precision_recall_curve(Y_orig[:, c].ravel(), y_pred[:, c].ravel())

        # pick the threshold where precision and recall are closest (break-even point)
        optimal_threshold = sorted(zip(np.abs(precision[:-1] - recall[:-1]), thresholds), key=lambda i: i[0])[0][1]
        print("Youden - Ideal threshold is: ", optimal_threshold)

When I pass a single prediction (and its corresponding ground-truth mask), I get the desired optimal threshold for each class.

However, when I pass another image I get different threshold values, and when I pass multiple predictions (my whole validation set) I get what look like averaged threshold values, which are not the true optimal thresholds. How can I calculate the optimal thresholds for a validation set?
