I am looking for a PyTorch equivalent of TensorFlow's BestExporter, which keeps only the most recent N best checkpoints, but I cannot find one.
Here is my implementation:
```python
import torch


class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's
    validation loss is lower than the lowest loss seen so far, save the
    model state.
    """
    def __init__(self, best_valid_loss=float('inf')):
        self.best_valid_loss = best_valid_loss

    def __call__(self, current_valid_loss, epoch, model, optimizer):
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            # Use epoch+1 in the filename too, so it matches the saved 'epoch' field.
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'weights/best_at_epoch_{epoch+1}_with_loss_{current_valid_loss}.pth')
```
It saves every checkpoint whose validation loss beats the previous best, so old checkpoints keep accumulating on disk and nothing is ever cleaned up. Is there a high-level PyTorch API equivalent to TensorFlow's BestExporter?
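For reference, here is a minimal sketch of how the class above could be extended by hand to mimic BestExporter's garbage collection, assuming the loss is passed in as a plain Python float. The exports_to_keep argument name is borrowed from BestExporter; save_dir and the deque bookkeeping are hypothetical additions of mine, not part of any PyTorch API.

```python
import os
from collections import deque

import torch


class SaveBestModel:
    """Save a checkpoint whenever validation loss improves, keeping only the
    most recent `exports_to_keep` of them on disk.

    Hypothetical extension of the class above; not a PyTorch API.
    """

    def __init__(self, best_valid_loss=float('inf'), exports_to_keep=5,
                 save_dir='weights'):
        self.best_valid_loss = best_valid_loss
        self.exports_to_keep = exports_to_keep
        self.save_dir = save_dir
        self._saved_paths = deque()  # oldest checkpoint path on the left
        os.makedirs(save_dir, exist_ok=True)

    def __call__(self, current_valid_loss, epoch, model, optimizer):
        if current_valid_loss >= self.best_valid_loss:
            return None  # not a new best: nothing is saved
        self.best_valid_loss = current_valid_loss
        path = os.path.join(
            self.save_dir,
            f'best_at_epoch_{epoch + 1}_with_loss_{current_valid_loss:.4f}.pth',
        )
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, path)
        self._saved_paths.append(path)
        # Delete the oldest checkpoints once more than exports_to_keep exist.
        while len(self._saved_paths) > self.exports_to_keep:
            old_path = self._saved_paths.popleft()
            if os.path.exists(old_path):
                os.remove(old_path)
        return path
```

Because a file is written only when the loss sets a new record, the most recent N files are also the N lowest losses among the saved ones. A scheme that also keeps non-record epochs (a true "top k") would need a heap keyed on loss instead of a FIFO queue.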