
I am looking for a PyTorch equivalent of TensorFlow's BestExporter, which saves a checkpoint whenever the model improves and keeps only the N most recent of those best checkpoints, but I cannot find one.

Here is my implementation:

import os

import torch


class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's
    validation loss is lower than the lowest loss seen so far, save the
    model state.
    """
    def __init__(self, best_valid_loss=float('inf')):
        self.best_valid_loss = best_valid_loss

    def __call__(self, current_valid_loss, epoch, model, optimizer):
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # Make sure the output directory exists before saving.
            os.makedirs('weights', exist_ok=True)
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            },
                # Use epoch+1 here too, so the file name matches the printed epoch.
                f'weights/best_at_epoch_{epoch+1}_with_loss_{current_valid_loss}.pth')

However, this keeps every checkpoint that improves on the previous best loss, so the files pile up over training instead of being capped at N. Is there a high-level PyTorch API equivalent to BestExporter in TensorFlow?
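One way to close the gap without a new dependency is to remember the paths written so far and delete the oldest ones once there are more than N, which mirrors BestExporter's `exports_to_keep` behaviour. The sketch below assumes that semantics; the class name `SaveBestModels`, the `save_fn` callback and the `max_to_keep` parameter are hypothetical names chosen for illustration, not part of any PyTorch API. `save_fn` is injected so the logic does not depend on torch; in a training loop you would pass `torch.save`. If an extra library is acceptable, PyTorch Lightning's `ModelCheckpoint` callback (with its `monitor` and `save_top_k` arguments) provides similar behaviour; check its documentation for the exact options.

```python
import os
from collections import deque
from typing import Any, Callable, Deque, Dict


class SaveBestModels:
    """Hypothetical helper: save a checkpoint whenever validation loss improves,
    keeping only the `max_to_keep` most recent ones (older files are deleted).

    `save_fn(obj, path)` writes a checkpoint; pass `torch.save` in PyTorch.
    """

    def __init__(self, save_fn: Callable[[Any, str], None], max_to_keep: int = 5,
                 directory: str = "weights") -> None:
        if max_to_keep < 1:
            raise ValueError("max_to_keep must be at least 1")
        self.save_fn = save_fn
        self.max_to_keep = max_to_keep
        self.directory = directory
        self.best_valid_loss = float("inf")
        self.saved_paths: Deque[str] = deque()  # oldest checkpoint first
        os.makedirs(directory, exist_ok=True)

    def __call__(self, current_valid_loss: float, epoch: int,
                 state: Dict[str, Any]) -> bool:
        """Save `state` if the loss improved; return True when a file was written."""
        if current_valid_loss >= self.best_valid_loss:
            return False  # no improvement, nothing saved
        self.best_valid_loss = current_valid_loss
        path = os.path.join(
            self.directory,
            f"best_at_epoch_{epoch + 1}_with_loss_{current_valid_loss:.4f}.pth",
        )
        self.save_fn(state, path)
        self.saved_paths.append(path)
        # Drop the oldest checkpoints beyond the limit, deleting their files
        # unless a newer entry happens to point at the same path.
        while len(self.saved_paths) > self.max_to_keep:
            old = self.saved_paths.popleft()
            if old not in self.saved_paths and os.path.exists(old):
                os.remove(old)
        return True
```

In a PyTorch loop this would be used roughly as `saver = SaveBestModels(torch.save, max_to_keep=5)` and then, once per epoch, `saver(val_loss, epoch, {'epoch': epoch + 1, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()})`. Because a checkpoint is only written on a strict improvement, the N most recent files are also the N lowest-loss ones among those saved.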

Lerner Zhang
