core.units.mlip_unit.predict#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes#
MLIPPredictUnit: The PredictUnit is an interface that can be used to organize your prediction logic.
Functions#
collate_predictions(predict_fn)
get_dataset_to_tasks_map(tasks): Create a mapping from dataset names to their associated tasks.
Module Contents#
- core.units.mlip_unit.predict.collate_predictions(predict_fn)#
- class core.units.mlip_unit.predict.MLIPPredictUnit(inference_model_path: str, device: str = 'cpu', overrides: dict | None = None, inference_settings: fairchem.core.units.mlip_unit.InferenceSettings | None = None, seed: int = 41)#
Bases: torchtnt.framework.PredictUnit[fairchem.core.datasets.atomic_data.AtomicData]

The PredictUnit is an interface that can be used to organize your prediction logic. The core of it is the predict_step, which is an abstract method where you can define the code you want to run each iteration of the dataloader.

To use the PredictUnit, create a class which subclasses PredictUnit. Then implement the predict_step method on your class, and then you can optionally implement any of the hooks which allow you to control the behavior of the loop at different points. In addition, you can override get_next_predict_batch to modify the default batch fetching behavior.

Below is a simple example of a user’s subclass of PredictUnit that implements a basic predict_step.

```python
from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit

# Specify the type of the data in each batch of the dataloader to allow for typechecking.
Batch = Tuple[torch.Tensor, torch.Tensor]


class MyPredictUnit(PredictUnit[Batch]):
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module

    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        inputs, targets = data
        outputs = self.module(inputs)
        return outputs


predict_unit = MyPredictUnit(module=...)
```
- tasks#
- dataset_to_tasks#
- device#
- lazy_model_intialized = False#
- inference_mode#
- merged_on = None#
- property direct_forces: bool#
- property datasets: list[str]#
- seed(seed: int)#
- move_to_device()#
- predict_step(state: torchtnt.framework.State, data: fairchem.core.datasets.atomic_data.AtomicData) → dict[str, torch.tensor]#
Core required method for user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. Optionally can be decorated with @torch.inference_mode() for improved performance.
- Parameters:
state – a State object containing metadata about the prediction run.
data – one batch of prediction data.
- get_composition_charge_spin_dataset(data)#
- predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) → dict[str, torch.tensor]#
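The predict_step entry above notes that the method can optionally be decorated with @torch.inference_mode(). The following is a minimal sketch of what that decoration does, using only PyTorch. TinyPredictor is a hypothetical stand-in written for this illustration; it is not MLIPPredictUnit and does not subclass PredictUnit, and the Linear module is an arbitrary placeholder model.

```python
import torch


class TinyPredictor:
    """Hypothetical stand-in for a PredictUnit subclass; only predict_step is sketched."""

    def __init__(self, module: torch.nn.Module) -> None:
        self.module = module

    @torch.inference_mode()  # disables autograd tracking for everything inside the method
    def predict_step(self, state: object, data: torch.Tensor) -> torch.Tensor:
        # `state` is unused here; a real unit would receive a torchtnt State object.
        return self.module(data)


predictor = TinyPredictor(torch.nn.Linear(3, 1))
output = predictor.predict_step(state=None, data=torch.randn(4, 3))

# The result is an inference tensor: it carries no autograd graph.
print(output.requires_grad, output.is_inference())
```

Because inference tensors cannot later take part in autograd, this decoration suits prediction-only code paths; whether it is appropriate for a model that obtains forces by differentiating the energy depends on the model and is not covered by this sketch.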
- core.units.mlip_unit.predict.get_dataset_to_tasks_map(tasks: Sequence[fairchem.core.units.mlip_unit.mlip_unit.Task]) → dict[str, list[fairchem.core.units.mlip_unit.mlip_unit.Task]]#
Create a mapping from dataset names to their associated tasks.
- Parameters:
tasks – A sequence of Task objects to be organized by dataset
- Returns:
A dictionary mapping dataset names (str) to lists of Task objects that are associated with that dataset
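The grouping described above can be sketched in a few lines of standard-library Python. This is an illustration of the documented behavior, not the fairchem implementation: ExampleTask is a hypothetical stand-in for Task, and the assumption that each task exposes a name and a datasets list of dataset names is made for this example only. The dataset names used are arbitrary placeholders.

```python
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass, field


@dataclass
class ExampleTask:
    """Hypothetical stand-in for Task; the real class may have different fields."""

    name: str
    datasets: list[str] = field(default_factory=list)


def group_tasks_by_dataset(tasks: Sequence[ExampleTask]) -> dict[str, list[ExampleTask]]:
    """Map each dataset name to the tasks that list it (assumed semantics)."""
    grouped: dict[str, list[ExampleTask]] = defaultdict(list)
    for task in tasks:
        for dataset_name in task.datasets:
            grouped[dataset_name].append(task)
    # Return a plain dict so that missing keys raise KeyError instead of
    # silently creating empty lists.
    return dict(grouped)


energy = ExampleTask(name="energy", datasets=["dataset_a", "dataset_b"])
forces = ExampleTask(name="forces", datasets=["dataset_b"])
mapping = group_tasks_by_dataset([energy, forces])
task_names = {dataset: [task.name for task in tasks] for dataset, tasks in mapping.items()}
```

A task that lists several datasets appears once under each of them, so the lists for different datasets can share the same task object.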