core.trainers.base_trainer
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes

| BaseTrainer | Helper class that provides a standard way to create an ABC using inheritance. |

Module Contents
- class core.trainers.base_trainer.BaseTrainer(task: dict[str, str | Any], model: dict[str, Any], outputs: dict[str, str | int], dataset: dict[str, str | float], optimizer: dict[str, str | float], loss_functions: dict[str, str | float], evaluation_metrics: dict[str, str], identifier: str, local_rank: int, timestamp_id: str | None = None, run_dir: str | None = None, is_debug: bool = False, print_every: int = 100, seed: int | None = None, logger: str = 'wandb', amp: bool = False, cpu: bool = False, name: str = 'ocp', slurm=None, gp_gpus: int | None = None, inference_only: bool = False)

Bases: abc.ABC

Helper class that provides a standard way to create an ABC using inheritance.
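Because BaseTrainer is an abc.ABC, it cannot be instantiated directly; concrete trainers subclass it and implement the abstract train() listed below. A minimal sketch, assuming fairchem is installed; the config dict values are illustrative placeholders, not a validated fairchem config schema:

```python
from fairchem.core.trainers.base_trainer import BaseTrainer


class MyTrainer(BaseTrainer):
    """Concrete trainer: only the abstract `train` must be implemented."""

    def train(self, disable_eval_tqdm: bool = False) -> None:
        # A real implementation iterates over the training dataloader,
        # calls self._backward(loss), and periodically runs
        # self.validate() and self.save().
        raise NotImplementedError


# Construction mirrors the signature documented above. The values below
# are placeholders: __init__ builds the model, datasets, optimizer, etc.,
# so a real run needs valid configs and data paths.
trainer = MyTrainer(
    task={},
    model={"name": "my_registered_model"},   # placeholder model name
    outputs={"energy": "energy"},            # placeholder
    dataset={"src": "data/train.lmdb"},      # placeholder path
    optimizer={"optimizer": "AdamW"},        # placeholder
    loss_functions={"energy": "mae"},        # placeholder
    evaluation_metrics={"energy": "mae"},    # placeholder
    identifier="example-run",
    local_rank=0,
    is_debug=True,
    cpu=True,
)
```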
- name
- is_debug
- cpu
- epoch = 0
- step = 0
- timestamp_id: str
- config
- scaler
- elementrefs
- normalizers
- train_dataset = None
- val_dataset = None
- test_dataset = None
- best_val_metric = None
- primary_metric = None
- abstract train(disable_eval_tqdm: bool = False) → None

Run model training iterations.
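Schematically, a subclass implementation of train() advances the documented epoch and step counters and leans on the other members on this page. In the sketch below, self.train_loader, self.compute_loss, and self.max_epochs are assumed names for illustration, not documented members; validate, update_best, _backward, save, and primary_metric are used as documented:

```python
def train(self, disable_eval_tqdm: bool = False) -> None:
    # Schematic training loop for a BaseTrainer subclass (sketch only).
    for epoch in range(self.max_epochs):          # assumed attribute
        self.epoch = epoch
        for batch in self.train_loader:           # assumed attribute
            loss = self.compute_loss(batch)       # hypothetical helper
            self._backward(loss)                  # handles AMP grad scaling
            self.step += 1
        val_metrics = self.validate(split="val", disable_tqdm=disable_eval_tqdm)
        self.update_best(self.primary_metric, val_metrics,
                         disable_eval_tqdm=disable_eval_tqdm)
        self.save(metrics=val_metrics, checkpoint_file="checkpoint.pt")
```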
- static _get_timestamp(device: torch.device, suffix: str | None) → str
- load(inference_only: bool) → None
- static set_seed(seed) → None
- load_seed_from_config() → None
- load_logger() → None
- get_sampler(dataset, batch_size: int, shuffle: bool) → fairchem.core.common.data_parallel.BalancedBatchSampler
- get_dataloader(dataset, sampler) → torch.utils.data.DataLoader
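The two methods above wire a dataset to a standard PyTorch DataLoader through a BalancedBatchSampler, which balances batch load across ranks in distributed runs. Usage sketch, assuming trainer is a concrete subclass with its datasets already loaded:

```python
# Build a sampler, then a loader, from the documented signatures above.
sampler = trainer.get_sampler(trainer.train_dataset, batch_size=32, shuffle=True)
loader = trainer.get_dataloader(trainer.train_dataset, sampler)

for batch in loader:
    ...  # forward pass, loss, backward
```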
- load_datasets() → None
- load_references_and_normalizers()

Load or create element references and normalizers from the config.
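Element references and normalizers standardize regression targets before the loss: a per-element, composition-dependent linear reference is subtracted from extensive targets such as energy, and the residual is scaled toward zero mean and unit variance. A minimal sketch of that arithmetic with made-up statistics; the actual reference and normalizer objects (the elementrefs and normalizers attributes above) are built from the dataset config:

```python
import torch

# Made-up per-element references and target statistics, for illustration.
element_refs = {1: -0.5, 8: -2.1}   # e.g. H, O reference energies (eV)
mean, std = 0.0, 1.7

def normalize(energy: torch.Tensor, atomic_numbers: list[int]) -> torch.Tensor:
    # Remove the composition-dependent baseline, then standardize.
    ref = sum(element_refs[z] for z in atomic_numbers)
    return (energy - ref - mean) / std

def denormalize(pred: torch.Tensor, atomic_numbers: list[int]) -> torch.Tensor:
    # Invert the transform to recover a physical prediction.
    ref = sum(element_refs[z] for z in atomic_numbers)
    return pred * std + mean + ref

target = torch.tensor(-14.2)   # raw energy for, say, H2O
z = [1, 1, 8]
assert torch.allclose(denormalize(normalize(target, z), z), target)
```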
- load_task()
- load_model() → None
- property _unwrapped_model
- load_checkpoint(checkpoint_path: str, checkpoint: dict | None = None, inference_only: bool = False) → None
- load_loss() → None
- load_optimizer() → None
- load_extras() → None
- save(metrics=None, checkpoint_file: str = 'checkpoint.pt', training_state: bool = True) → str | None
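save() writes a checkpoint file and, per its annotation, returns its path or None (e.g. when nothing is written), which pairs with load_checkpoint() above for restoring state. Usage sketch; trainer_kwargs is a hypothetical stand-in for the same config dicts used to build the original trainer:

```python
# Persist the current model/optimizer state after a validation pass.
val_metrics = trainer.validate(split="val")
ckpt_path = trainer.save(metrics=val_metrics, checkpoint_file="best.pt",
                         training_state=True)

# Later (e.g. in a fresh process), rebuild the trainer and restore weights.
new_trainer = MyTrainer(**trainer_kwargs)   # hypothetical: same configs, with inference_only=True
if ckpt_path is not None:
    new_trainer.load_checkpoint(ckpt_path, inference_only=True)
```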
- update_best(primary_metric, val_metrics, disable_eval_tqdm: bool = True) → None
- _aggregate_metrics(metrics)
- validate(split: str = 'val', disable_tqdm: bool = False)
- _backward(loss) → None
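_backward pairs with the scaler attribute set up when amp=True: gradient scaling keeps small FP16 gradients from underflowing. Below is a sketch of the standard PyTorch pattern such a method typically wraps, not fairchem's exact code; the real _backward may also clip gradients or handle distributed syncing:

```python
import torch

def _backward(self, loss: torch.Tensor) -> None:
    # Standard AMP-aware backward/step (sketch).
    self.optimizer.zero_grad()
    if self.scaler is not None:              # GradScaler, set when amp=True
        self.scaler.scale(loss).backward()   # scale loss to avoid underflow
        self.scaler.step(self.optimizer)     # unscales grads, then steps
        self.scaler.update()                 # adjust scale for next step
    else:
        loss.backward()
        self.optimizer.step()
```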
- save_results(predictions: dict[str, numpy.typing.NDArray], results_file: str | None, keys: collections.abc.Sequence[str] | None = None) → None
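save_results takes a dict of named prediction arrays plus a results file name. A plausible NumPy equivalent of what gets written; the compressed-.npz layout here is an assumption, not a documented guarantee:

```python
import numpy as np

# Hypothetical predictions keyed by output name.
predictions = {
    "ids": np.array(["0_0", "0_1"]),
    "energy": np.array([-1.23, -4.56]),
}
# Assumed on-disk layout: one compressed .npz keyed by prediction name.
np.savez_compressed("predictions.npz", **predictions)

loaded = np.load("predictions.npz")
print(loaded["energy"])   # [-1.23 -4.56]
```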