core.components.train.train_runner#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes#
TrainCheckpointCallback
A Callback is an optional extension that can be used to supplement your loop with additional functionality.

TrainEvalRunner
Represents an abstraction over things that run in a loop and can save/load state.

Functions#

get_most_recent_viable_checkpoint_path
Module Contents#
- core.components.train.train_runner.get_most_recent_viable_checkpoint_path(checkpoint_dir: str | None) str | None #
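This function carries no docstring here. As an illustration only, the sketch below shows what a "most recent viable checkpoint" lookup typically does; the `step_<n>` directory naming and the `.metadata` marker file are assumptions for the example, not fairchem's confirmed convention:

```python
import os


def get_most_recent_viable_checkpoint_path_sketch(checkpoint_dir):
    """Illustrative sketch: return the newest checkpoint subdirectory that
    finished writing, or None. Naming/marker conventions are assumed."""
    if checkpoint_dir is None or not os.path.isdir(checkpoint_dir):
        return None
    # Assume checkpoints are saved as step_<n> directories that contain a
    # ".metadata" marker once fully written (assumption for illustration).
    candidates = []
    for name in os.listdir(checkpoint_dir):
        path = os.path.join(checkpoint_dir, name)
        if name.startswith("step_") and os.path.exists(
            os.path.join(path, ".metadata")
        ):
            candidates.append((int(name.split("_")[1]), path))
    if not candidates:
        return None
    # Highest step number wins; partially written checkpoints are skipped.
    return max(candidates)[1]
```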
- class core.components.train.train_runner.TrainCheckpointCallback(checkpoint_every_n_steps: int, max_saved_checkpoints: int = 2)#
Bases:
torchtnt.framework.callback.Callback
A Callback is an optional extension that can be used to supplement your loop with additional functionality. Good candidates for such logic are ones that can be re-used across units. Callbacks are generally not intended for modeling code; this should go in your Unit. To write your own callback, subclass the Callback class and add your own code into the hooks.
Below is an example of a basic callback which prints a message at various points during execution.
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit

class PrintingCallback(Callback):
    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        print("Starting training")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        print("Ending training")

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        print("Starting evaluation")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        print("Ending evaluation")

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        print("Starting prediction")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        print("Ending prediction")
To use a callback, instantiate the class and pass it in the callbacks parameter to the train(), evaluate(), predict(), or fit() entry point.

printing_callback = PrintingCallback()
train(train_unit, train_dataloader, callbacks=[printing_callback])
- checkpoint_every_n_steps#
- max_saved_checkpoints#
- save_callback = None#
- load_callback = None#
- checkpoint_dir = None#
- set_runner_callbacks(save_callback: callable, load_callback: callable, checkpoint_dir: str) None #
- on_train_step_start(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) None #
Hook called before a new train step starts.
- on_train_end(state: torchtnt.framework.state.State, unit: torchtnt.framework.unit.TTrainUnit) None #
Hook called after training ends.
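The attributes and hooks above suggest the usual save-every-n-steps-with-rotation pattern. Below is a minimal, framework-free sketch of that pattern; the class and the save/delete callables are illustrative stand-ins, not fairchem's actual implementation:

```python
class CheckpointRotationSketch:
    """Illustrative sketch of checkpoint-every-n-steps with rotation;
    not fairchem's actual TrainCheckpointCallback."""

    def __init__(self, checkpoint_every_n_steps, max_saved_checkpoints=2):
        self.checkpoint_every_n_steps = checkpoint_every_n_steps
        self.max_saved_checkpoints = max_saved_checkpoints
        self.saved = []  # paths of checkpoints still on disk, oldest first

    def on_train_step_start(self, step, save_fn, delete_fn):
        # Save on the configured cadence, then drop the oldest checkpoints
        # so at most max_saved_checkpoints remain on disk.
        if step > 0 and step % self.checkpoint_every_n_steps == 0:
            path = f"checkpoints/step_{step}"
            save_fn(path)
            self.saved.append(path)
            while len(self.saved) > self.max_saved_checkpoints:
                delete_fn(self.saved.pop(0))
```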
- class core.components.train.train_runner.TrainEvalRunner(train_dataloader: torch.utils.data.dataloader, eval_dataloader: torch.utils.data.dataloader, train_eval_unit: torchtnt.framework.TrainUnit | torchtnt.framework.EvalUnit | torch.distributed.checkpoint.stateful.Stateful, callbacks: list[torchtnt.framework.callback.Callback] | None = None, max_epochs: int | None = 1, evaluate_every_n_steps: int | None = None, max_steps: int | None = None, save_inference_ckpt: bool = True)#
Bases:
fairchem.core.components.runner.Runner
Represents an abstraction over things that run in a loop and can save/load state.
ie: Trainers, Validators, Relaxation all fall in this category.
Note
When running with the fairchemv2 cli, the job_config attribute is set at runtime to the one given in the config file.
- job_config#
a managed attribute that gives access to the job config
- Type:
DictConfig
- train_dataloader#
- eval_dataloader#
- train_eval_unit#
- callbacks#
- max_epochs#
- max_steps#
- evaluate_every_n_steps#
- save_inference_ckpt#
- checkpoint_callback#
- run() None #
- save_state(checkpoint_location: str, is_preemption: bool = False) bool #
- load_state(checkpoint_location: str | None) None #
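The attributes above describe a training loop that pauses for evaluation every evaluate_every_n_steps and stops at max_steps. A minimal, framework-free sketch of that control flow follows; the function and its callable parameters are illustrative, not fairchem's actual run() implementation:

```python
def train_eval_loop_sketch(train_steps, evaluate_every_n_steps, max_steps,
                           train_step_fn, eval_fn):
    """Illustrative sketch of interleaving evaluation into a train loop;
    not TrainEvalRunner.run() itself."""
    step = 0
    for _ in range(train_steps):
        # Stop early once max_steps is reached, if one was configured.
        if max_steps is not None and step >= max_steps:
            break
        train_step_fn(step)
        step += 1
        # Pause for evaluation on the configured cadence.
        if evaluate_every_n_steps and step % evaluate_every_n_steps == 0:
            eval_fn(step)
    return step
```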