core.units.mlip_unit.predict#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Attributes#
- `ray_installed`

Classes#
- `MLIPPredictUnitProtocol`: Base class for protocol classes.
- `MLIPPredictUnit`: The PredictUnit is an interface that can be used to organize your prediction logic.
- `MLIPWorker`
- `ParallelMLIPPredictUnit`: Base class for protocol classes.

Functions#
- `collate_predictions`
- `get_dataset_to_tasks_map`: Create a mapping from dataset names to their associated tasks.
- `move_tensors_to_cpu`: Recursively move all PyTorch tensors in a nested data structure to CPU.
Module Contents#
- core.units.mlip_unit.predict.ray_installed = True#
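This flag presumably records whether Ray could be imported. The snippet below shows the standard guard pattern for such a flag; it is an assumption about how the module sets it, not a quote of the source.

```python
# Typical import-guard pattern for an optional dependency (illustrative only).
try:
    import ray  # noqa: F401
    ray_installed = True
except ImportError:
    ray_installed = False
```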
- core.units.mlip_unit.predict.collate_predictions(predict_fn)#
- class core.units.mlip_unit.predict.MLIPPredictUnitProtocol#
Bases: `Protocol`

Base class for protocol classes.

Protocol classes are defined as:

```python
class Proto(Protocol):
    def meth(self) -> int: ...
```

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example:

```python
class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check
```

See PEP 544 for details. Protocol classes decorated with `@typing.runtime_checkable` act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as:

```python
class GenProto[T](Protocol):
    def meth(self) -> T: ...
```
- predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool) dict#
- property dataset_to_tasks: dict[str, list]#
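As an illustration of the structural subtyping described above, any object that provides a compatible `predict` method and `dataset_to_tasks` property satisfies this protocol. The sketch below is a hypothetical minimal implementation written only to show the protocol's shape; it is not part of the library.

```python
from fairchem.core.datasets.atomic_data import AtomicData
from fairchem.core.units.mlip_unit.predict import MLIPPredictUnitProtocol


class ConstantPredictUnit:
    """Hypothetical stand-in that satisfies MLIPPredictUnitProtocol structurally."""

    def predict(self, data: AtomicData, undo_element_references: bool) -> dict:
        # Return an empty prediction for every batch; a real unit would run a model.
        return {}

    @property
    def dataset_to_tasks(self) -> dict[str, list]:
        return {}


def run(unit: MLIPPredictUnitProtocol, batch: AtomicData) -> dict:
    # Type-checks because ConstantPredictUnit matches the protocol's shape.
    return unit.predict(batch, undo_element_references=True)
```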
- class core.units.mlip_unit.predict.MLIPPredictUnit(inference_model_path: str, device: str = 'cpu', overrides: dict | None = None, inference_settings: fairchem.core.units.mlip_unit.InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None, assert_on_nans: bool = False)#
Bases: `torchtnt.framework.PredictUnit[fairchem.core.datasets.atomic_data.AtomicData]`, `MLIPPredictUnitProtocol`

The PredictUnit is an interface that can be used to organize your prediction logic. The core of it is the `predict_step`, which is an abstract method where you can define the code you want to run each iteration of the dataloader.

To use the PredictUnit, create a class which subclasses `PredictUnit`. Then implement the `predict_step` method on your class, and optionally implement any of the hooks which allow you to control the behavior of the loop at different points. In addition, you can override `get_next_predict_batch` to modify the default batch fetching behavior.

Below is a simple example of a user's subclass of `PredictUnit` that implements a basic `predict_step`. (A usage sketch of `MLIPPredictUnit` itself appears after the member list below.)

```python
from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit

# Specify the type of the data in each batch of the dataloader to allow for typechecking.
Batch = Tuple[torch.Tensor, torch.Tensor]


class MyPredictUnit(PredictUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
    ):
        super().__init__()
        self.module = module

    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        inputs, targets = data
        outputs = self.module(inputs)
        return outputs


predict_unit = MyPredictUnit(module=...)
```
- atom_refs#
- tasks#
- _dataset_to_tasks#
- device#
- lazy_model_intialized = False#
- inference_mode#
- merged_on = None#
- assert_on_nans#
- property direct_forces: bool#
- property dataset_to_tasks: dict[str, list]#
- set_seed(seed: int)#
- move_to_device()#
- predict_step(state: torchtnt.framework.State, data: fairchem.core.datasets.atomic_data.AtomicData) dict[str, torch.tensor]#
Core required method for user to implement. This method will be called at each iteration of the predict dataloader, and can return any data the user wishes. Optionally can be decorated with `@torch.inference_mode()` for improved performance.
- Parameters:
state – a `State` object containing metadata about the prediction run.
data – one batch of prediction data.
- get_composition_charge_spin_dataset(data)#
- predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) dict[str, torch.tensor]#
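A minimal usage sketch of `MLIPPredictUnit` as an in-process predictor. The checkpoint path is a placeholder, and `AtomicData.from_ase` is assumed to be the conversion entry point; the exact constructor may differ in your fairchem version.

```python
import torch
from ase.build import bulk
from fairchem.core.datasets.atomic_data import AtomicData
from fairchem.core.units.mlip_unit.predict import MLIPPredictUnit

# Hypothetical checkpoint path; substitute a real inference checkpoint.
unit = MLIPPredictUnit(
    inference_model_path="/path/to/inference_ckpt.pt",
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=41,
)

# Build a single-structure batch. AtomicData.from_ase is assumed here;
# use whatever conversion your fairchem version provides.
atoms = bulk("Cu", "fcc", a=3.6)
batch = AtomicData.from_ase(atoms)

preds = unit.predict(batch)  # dict of output name -> tensor
for key, value in preds.items():
    print(key, tuple(value.shape))
```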
- core.units.mlip_unit.predict.get_dataset_to_tasks_map(tasks: Sequence[fairchem.core.units.mlip_unit.mlip_unit.Task]) dict[str, list[fairchem.core.units.mlip_unit.mlip_unit.Task]]#
Create a mapping from dataset names to their associated tasks.
- Parameters:
tasks – A sequence of Task objects to be organized by dataset
- Returns:
A dictionary mapping dataset names (str) to lists of Task objects that are associated with that dataset
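The grouping this function performs can be pictured with a short sketch. The `datasets` attribute on each task used below is an assumption about the task schema made purely for illustration, not a documented field.

```python
from collections import defaultdict


def group_tasks_by_dataset(tasks):
    # Group each task under every dataset name it declares.
    # Assumes each task exposes an iterable of dataset names
    # (called `datasets` here only for illustration).
    mapping = defaultdict(list)
    for task in tasks:
        for dataset_name in task.datasets:
            mapping[dataset_name].append(task)
    return dict(mapping)
```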
- core.units.mlip_unit.predict.move_tensors_to_cpu(data)#
Recursively move all PyTorch tensors in a nested data structure to CPU.
- Parameters:
data – Input data structure (dict, list, tuple, tensor, or other)
- Returns:
Data structure with all tensors moved to CPU
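The documented behavior amounts to a small recursive walk over common containers; a minimal sketch (not the library's implementation) could look like this:

```python
import torch


def move_to_cpu(data):
    # Recurse through dicts, lists, and tuples, moving any tensor to CPU
    # and leaving non-tensor leaves untouched.
    if isinstance(data, torch.Tensor):
        return data.cpu()
    if isinstance(data, dict):
        return {key: move_to_cpu(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_cpu(item) for item in data)
    return data
```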
- class core.units.mlip_unit.predict.MLIPWorker(worker_id: int, world_size: int, predictor_config: dict, master_port: int | None = None, master_address: str | None = None)#
- worker_id#
- world_size#
- predictor_config#
- master_address#
- master_port#
- is_setup = False#
- get_master_address_and_port()#
- _distributed_setup(worker_id: int, master_port: int, world_size: int, predictor_config: dict, master_address: str)#
- predict(data: fairchem.core.datasets.atomic_data.AtomicData) dict[str, torch.tensor] | None#
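The `master_address` and `master_port` arguments suggest the usual torch.distributed rendezvous between workers. The sketch below shows only that generic pattern; it is an assumption about what `_distributed_setup` involves, not a copy of it.

```python
import os

import torch.distributed as dist


def distributed_setup_sketch(worker_id: int, world_size: int,
                             master_address: str, master_port: int) -> None:
    # Standard environment-variable rendezvous for torch.distributed.
    os.environ["MASTER_ADDR"] = master_address
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend="gloo", rank=worker_id, world_size=world_size)
```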
- class core.units.mlip_unit.predict.ParallelMLIPPredictUnit(inference_model_path: str, device: str = 'cpu', overrides: dict | None = None, inference_settings: fairchem.core.units.mlip_unit.InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None, assert_on_nans: bool = False, num_workers: int = 1, num_workers_per_node: int = 8)#
Bases: `MLIPPredictUnitProtocol`

Base class for protocol classes. (See `MLIPPredictUnitProtocol` above for the full `Protocol` docstring.)
- _dataset_to_tasks#
- workers#
- predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) dict[str, torch.tensor]#
- property dataset_to_tasks: dict[str, list]#
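For multi-worker inference, `ParallelMLIPPredictUnit` exposes the same protocol-level `predict` entry point as `MLIPPredictUnit`. The sketch below only illustrates construction; the checkpoint path is a placeholder, the worker counts depend on your hardware, and Ray is assumed to be available since the module tracks it via `ray_installed`.

```python
from fairchem.core.units.mlip_unit.predict import ParallelMLIPPredictUnit

# Hypothetical checkpoint path; adjust worker counts to your cluster.
parallel_unit = ParallelMLIPPredictUnit(
    inference_model_path="/path/to/inference_ckpt.pt",
    device="cuda",
    num_workers=4,
    num_workers_per_node=4,
)

# Same API as the single-process unit:
# preds = parallel_unit.predict(batch)
```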