core.trainers
Submodules
Classes
BaseTrainer | Helper class that provides a standard way to create an ABC using inheritance.
OCPTrainer | Trainer class for the Structure to Energy & Force (S2EF) and Initial State to Relaxed State (IS2RS) tasks.
Package Contents
- class core.trainers.BaseTrainer(task: dict[str, str | Any], model: dict[str, Any], outputs: dict[str, str | int], dataset: dict[str, str | float], optimizer: dict[str, str | float], loss_functions: dict[str, str | float], evaluation_metrics: dict[str, str], identifier: str, local_rank: int, timestamp_id: str | None = None, run_dir: str | None = None, is_debug: bool = False, print_every: int = 100, seed: int | None = None, logger: str = 'wandb', amp: bool = False, cpu: bool = False, name: str = 'ocp', slurm=None, gp_gpus: int | None = None, inference_only: bool = False)
Bases: abc.ABC
Abstract base class for trainers, built on abc.ABC (a helper class that provides a standard way to create an ABC using inheritance). Subclasses must implement the abstract train() method.
- name
- is_debug
- cpu
- epoch = 0
- step = 0
- timestamp_id: str
- commit_hash
- logger_name
- config
- scaler
- elementrefs
- normalizers
- train_dataset = None
- val_dataset = None
- test_dataset = None
- best_val_metric = None
- primary_metric = None
- abstract train(disable_eval_tqdm: bool = False) -> None
Run model training iterations.
- static _get_timestamp(device: torch.device, suffix: str | None) -> str
- load(inference_only: bool) -> None
- static set_seed(seed) -> None
- load_seed_from_config() -> None
- load_logger() -> None
- get_sampler(dataset, batch_size: int, shuffle: bool) -> fairchem.core.common.data_parallel.BalancedBatchSampler
- get_dataloader(dataset, sampler) -> torch.utils.data.DataLoader
- load_datasets() -> None
- load_references_and_normalizers()
Load or create element references and normalizers from the config.
- load_task()
- load_model() -> None
- property _unwrapped_model
- load_checkpoint(checkpoint_path: str, checkpoint: dict | None = None, inference_only: bool = False) -> None
- load_loss() -> None
- load_optimizer() -> None
- load_extras() -> None
- save(metrics=None, checkpoint_file: str = 'checkpoint.pt', training_state: bool = True) -> str | None
- update_best(primary_metric, val_metrics, disable_eval_tqdm: bool = True) -> None
- _aggregate_metrics(metrics)
- validate(split: str = 'val', disable_tqdm: bool = False)
- _backward(loss) -> None
- save_results(predictions: dict[str, numpy.typing.NDArray], results_file: str | None, keys: collections.abc.Sequence[str] | None = None) -> None
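BaseTrainer derives from abc.ABC and declares train() abstract, so it cannot be instantiated directly; a concrete trainer such as OCPTrainer has to override train(). The sketch below illustrates that contract with hypothetical stand-in classes (MiniBaseTrainer, MiniTrainer) instead of the real fairchem classes, whose constructors need full configuration dictionaries. The toy train() loop only advances counters and is an assumption, not fairchem's training logic.

```python
import abc


class MiniBaseTrainer(abc.ABC):
    """Hypothetical stand-in mirroring BaseTrainer's abstract contract."""

    def __init__(self) -> None:
        # Counters named after the documented BaseTrainer attributes.
        self.epoch = 0
        self.step = 0

    @abc.abstractmethod
    def train(self, disable_eval_tqdm: bool = False) -> None:
        """Run model training iterations."""


class MiniTrainer(MiniBaseTrainer):
    """Hypothetical concrete subclass; a real trainer would iterate over batches."""

    def train(self, disable_eval_tqdm: bool = False) -> None:
        # Placeholder loop: advance the step counter instead of optimizing a model.
        for _ in range(3):
            self.step += 1
        self.epoch += 1


def can_instantiate(cls: type) -> bool:
    """Return True if cls can be constructed, False if it is still abstract."""
    try:
        cls()
    except TypeError:
        return False
    return True


trainer = MiniTrainer()
trainer.train()
print(can_instantiate(MiniBaseTrainer), can_instantiate(MiniTrainer))  # False True
print(trainer.epoch, trainer.step)  # 1 3
```

Because the check happens at construction time, a subclass that forgets to implement train() fails immediately with a TypeError rather than partway through a run.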
- class core.trainers.OCPTrainer(task: dict[str, str | Any], model: dict[str, Any], outputs: dict[str, str | int], dataset: dict[str, str | float], optimizer: dict[str, str | float], loss_functions: dict[str, str | float], evaluation_metrics: dict[str, str], identifier: str, local_rank: int, timestamp_id: str | None = None, run_dir: str | None = None, is_debug: bool = False, print_every: int = 100, seed: int | None = None, logger: str = 'wandb', amp: bool = False, cpu: bool = False, name: str = 'ocp', slurm=None, gp_gpus: int | None = None, inference_only: bool = False)
Bases: fairchem.core.trainers.base_trainer.BaseTrainer
Trainer class for the Structure to Energy & Force (S2EF) and Initial State to Relaxed State (IS2RS) tasks.
Note: Examples of configurations for task, model, dataset and optimizer can be found in configs/ocp_s2ef and configs/ocp_is2rs.
- Parameters:
task (dict) – Task configuration.
model (dict) – Model configuration.
outputs (dict) – Output property configuration.
dataset (dict) – Dataset configuration. The dataset needs to be a SinglePointLMDB dataset.
optimizer (dict) – Optimizer configuration.
loss_functions (dict) – Loss function configuration.
evaluation_metrics (dict) – Evaluation metrics configuration.
identifier (str) – Experiment identifier that is appended to the log directory.
run_dir (str, optional) – Path to the run directory where logs are saved. (default: None)
is_debug (bool, optional) – Run in debug mode. (default: False)
print_every (int, optional) – Frequency of printing logs, in training steps. (default: 100)
seed (int, optional) – Random number seed. (default: None)
logger (str, optional) – Type of logger to use. (default: wandb)
amp (bool, optional) – Run using automatic mixed precision. (default: False)
slurm (dict) – Slurm configuration, currently kept for bookkeeping only. (default: {})
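OCPTrainer shares BaseTrainer's argument list, in which only task, model, outputs, dataset, optimizer, loss_functions, evaluation_metrics, identifier and local_rank lack defaults. The sketch below checks a candidate set of arguments against a stub that copies the documented signature, using the standard library's inspect.Signature.bind. The stub (ocp_trainer_stub) and every argument value are hypothetical placeholders; the empty dictionaries stand in for real configuration sections such as those in configs/ocp_s2ef.

```python
from __future__ import annotations

import inspect
from typing import Any


def ocp_trainer_stub(
    task: dict[str, str | Any],
    model: dict[str, Any],
    outputs: dict[str, str | int],
    dataset: dict[str, str | float],
    optimizer: dict[str, str | float],
    loss_functions: dict[str, str | float],
    evaluation_metrics: dict[str, str],
    identifier: str,
    local_rank: int,
    timestamp_id: str | None = None,
    run_dir: str | None = None,
    is_debug: bool = False,
    print_every: int = 100,
    seed: int | None = None,
    logger: str = "wandb",
    amp: bool = False,
    cpu: bool = False,
    name: str = "ocp",
    slurm=None,
    gp_gpus: int | None = None,
    inference_only: bool = False,
) -> None:
    """Hypothetical stub copying the documented OCPTrainer signature; it does nothing."""


def required_parameters(func) -> list[str]:
    """Return the names of parameters that have no default value."""
    return [
        p.name
        for p in inspect.signature(func).parameters.values()
        if p.default is inspect.Parameter.empty
    ]


# Placeholder arguments: empty dicts stand in for real configuration sections.
kwargs = {
    "task": {},
    "model": {},
    "outputs": {},
    "dataset": {},
    "optimizer": {},
    "loss_functions": {},
    "evaluation_metrics": {},
    "identifier": "debug-run",
    "local_rank": 0,
    "is_debug": True,
}

sig = inspect.signature(ocp_trainer_stub)
bound = sig.bind(**kwargs)
bound.apply_defaults()  # fill in the documented defaults for omitted arguments
print(required_parameters(ocp_trainer_stub))
print(bound.arguments["print_every"], bound.arguments["logger"])  # 100 wandb

# Omitting a required argument makes bind() raise TypeError.
incomplete = {k: v for k, v in kwargs.items() if k != "local_rank"}
try:
    sig.bind(**incomplete)
except TypeError as err:
    print("rejected:", err)
```

This kind of check catches a missing identifier or local_rank before any dataset or model is loaded, but it does not validate the contents of the configuration dictionaries.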
- train(disable_eval_tqdm: bool = False) -> None
Run model training iterations.
- _denorm_preds(target_key: str, prediction: torch.Tensor, batch: torch_geometric.data.Batch)
Convert a batch of model outputs into raw predictions by denormalizing them and adding element references.
- _forward(batch)
- _compute_loss(out, batch) -> torch.Tensor
- _compute_metrics(out, batch, evaluator, metrics=None)
- predict(data_loader, per_image: bool = True, results_file: str | None = None, disable_tqdm: bool = False)
- run_relaxations(split='val')
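_denorm_preds turns normalized model outputs back into raw predictions by denormalizing them and adding element references. The sketch below illustrates that idea under two assumptions: an affine normalizer (raw = normalized * rmsd + mean) and linear element references, where each structure's reference energy is the sum of per-element values over its atoms. The function denorm_preds, its parameters and all numbers are hypothetical; the fairchem method works on torch tensors and torch_geometric batches rather than Python lists.

```python
from __future__ import annotations

from collections.abc import Sequence


def denorm_preds(
    normalized: Sequence[float],
    atomic_numbers: Sequence[Sequence[int]],
    mean: float,
    rmsd: float,
    element_refs: dict[int, float],
) -> list[float]:
    """Hypothetical per-structure denormalization followed by reference addition."""
    raw = []
    for value, numbers in zip(normalized, atomic_numbers):
        # Undo the normalization: scale by rmsd, then shift by mean.
        energy = value * rmsd + mean
        # Add back the linear element reference summed over the structure's atoms.
        energy += sum(element_refs[z] for z in numbers)
        raw.append(energy)
    return raw


# Two toy structures: H2O (Z = 1, 1, 8) and H2 (Z = 1, 1); all values are made up.
refs = {1: -0.5, 8: -2.0}
print(denorm_preds([0.5, -1.0], [[1, 1, 8], [1, 1]], mean=0.25, rmsd=2.0, element_refs=refs))
# [-1.75, -2.75]
```

Denormalizing before adding references matters: the references are in raw energy units, so they can only be added once the prediction has been mapped back out of the normalized space.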