core.components.train.train_runner#

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes#

Checkpointable

Protocol that Units used by this trainer should implement if they want save and resume functionality

TrainCheckpointCallback

A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units.

TrainEvalRunner

Represents an abstraction over things that run in a loop and can save/load state.

Functions#

get_most_recent_viable_checkpoint_path

Module Contents#

class core.components.train.train_runner.Checkpointable#

Bases: Protocol

Protocol that Units used by this trainer should implement if they want save and resume functionality. This is in addition to PyTorch's Stateful protocol because it allows units to implement custom logic that's required for checkpointing.

save_state(checkpoint_location: str) → None#

Save the unit state to a checkpoint path

Parameters:

checkpoint_location – The checkpoint path to save to

load_state(checkpoint_location: str | None) → None#

Loads the state given a checkpoint path

Parameters:

checkpoint_location – The checkpoint path to restore from

core.components.train.train_runner.get_most_recent_viable_checkpoint_path(checkpoint_dir: str | None) → str | None#
class core.components.train.train_runner.TrainCheckpointCallback(checkpoint_every_n_steps: int, max_saved_checkpoints: int = 2)#

Bases: torchtnt.framework.callback.Callback

A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units. Callbacks are generally not intended for modeling code; this should go in your Unit. To write your own callback, subclass the Callback class and add your own code into the hooks.

Below is an example of a basic callback which prints a message at various points during execution.

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit

class PrintingCallback(Callback):
    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        print("Starting training")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        print("Ending training")

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        print("Starting evaluation")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        print("Ending evaluation")

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        print("Starting prediction")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        print("Ending prediction")

To use a callback, instantiate the class and pass it in the callbacks parameter to the train(), evaluate(), predict(), or fit() entry point.

printing_callback = PrintingCallback()
train(train_unit, train_dataloader, callbacks=[printing_callback])
checkpoint_every_n_steps#
max_saved_checkpoints#
save_callback = None#
load_callback = None#
checkpoint_dir = None#
set_runner_callbacks(save_callback: callable, load_callback: callable, checkpoint_dir: str) → None#
on_train_step_start(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) → None#

Hook called before a new train step starts.

on_train_end(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) → None#

Hook called after training ends.

class core.components.train.train_runner.TrainEvalRunner(train_dataloader: torch.utils.data.dataloader, eval_dataloader: torch.utils.data.dataloader, train_eval_unit: torchtnt.framework.TrainUnit | torchtnt.framework.EvalUnit | Checkpointable, callbacks: list[torchtnt.framework.callback.Callback] | None = None, max_epochs: int | None = 1, evaluate_every_n_steps: int | None = None, max_steps: int | None = None)#

Bases: fairchem.core.components.runner.Runner

Represents an abstraction over things that run in a loop and can save/load state.

i.e., Trainers, Validators, and Relaxation all fall in this category.

Note

When running with the fairchemv2 cli, the job_config attribute is set at runtime to the one given in the config file.

job_config#

a managed attribute that gives access to the job config

Type:

DictConfig

train_dataloader#
eval_dataloader#
train_eval_unit#
callbacks#
max_epochs#
max_steps#
evaluate_every_n_steps#
checkpoint_callback#
run() → None#
save_state(checkpoint_location: str, is_preemption: bool = False) → bool#
load_state(checkpoint_location: str | None) → None#