def restart_from_checkpoint(checkpoint_path, restore_objects=None, **kwargs):
    """Re-start training or inference from a previous checkpoint.

    If ``checkpoint_path`` is None or does not point to an existing file,
    this is a no-op. Otherwise the checkpoint is loaded onto CPU and its
    entries are restored into the given objects in place.

    Args:
        checkpoint_path (str): Path to checkpoint file.
        restore_objects (dict): Dict whose values are overwritten in place
            with the matching entries from the checkpoint (e.g. the epoch
            counter). Keys missing from the checkpoint are left untouched.
        **kwargs: Objects with a ``load_state_dict`` method (model,
            optimizer, scheduler, ...) to reload from the checkpoint; each
            keyword name must match the key used when the checkpoint was
            saved.

    Returns:
        None

    Example:
        # run once to create checkpoint; run again to load checkpoint
        import torch
        model = torch.nn.Linear(10, 5)
        optimizer = torch.optim.Adam(model.parameters())
        num_epochs = 10
        to_restore = {"epoch": 0}
        # if the checkpoint does not exist, this is a no-op
        restart_from_checkpoint(
            "checkpoint.pth",
            restore_objects=to_restore,
            model=model,
            optimizer=optimizer,
        )
        start_epoch = to_restore["epoch"]
        for epoch in range(start_epoch, num_epochs):
            # load data, move to GPU, pass through model, calculate loss,
            # step optimizer, etc.
            checkpoint = {
                "epoch": epoch + 1,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(checkpoint, "checkpoint.pth")
    """
    # Missing checkpoint is not an error: first run of a job has nothing
    # to resume from, so silently fall through to fresh training.
    if checkpoint_path is None or not os.path.isfile(checkpoint_path):
        return
    logger.info(f"Found checkpoint at {checkpoint_path}")

    # Load onto CPU so restoring works regardless of which device(s) the
    # checkpoint was saved from.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Load states from checkpoint into the supplied stateful objects.
    for key, model in kwargs.items():
        if key in checkpoint and model is not None:
            try:
                # Prefer non-strict loading so partially-matching state
                # dicts (e.g. changed heads) still restore what they can.
                msg = model.load_state_dict(checkpoint[key], strict=False)
                logger.info(
                    f"Loaded '{key}' from checkpoint '{checkpoint_path}' with msg {msg}"
                )
            except TypeError:
                # Some objects (e.g. optimizers) don't accept `strict`.
                msg = model.load_state_dict(checkpoint[key])
                logger.info(f"Loaded '{key}' from checkpoint '{checkpoint_path}'")
            except ValueError:
                # Incompatible state dict: warn and continue (best-effort).
                # NOTE: logger.warn is deprecated; use logger.warning.
                logger.warning(
                    f"Failed to load '{key}' from checkpoint '{checkpoint_path}'"
                )
        else:
            logger.info(f"Key '{key}' not found in checkpoint '{checkpoint_path}'")

    # Reload important scalar variables (epoch counter, etc.) in place so
    # the caller sees the restored values through the same dict object.
    if restore_objects is not None:
        for var_name in restore_objects:
            if var_name in checkpoint:
                restore_objects[var_name] = checkpoint[var_name]
                logger.info(
                    f"Loaded '{var_name}' from checkpoint '{checkpoint_path}'"
                )