core.units.mlip_unit.predict
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
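Classes
| Class | Description |
|---|---|
| MLIPPredictUnitProtocol | Base class for protocol classes. |
| MLIPPredictUnit | The PredictUnit is an interface that can be used to organize your prediction logic. The core of it is the predict_step method. |
| ParallelMLIPPredictUnit | Base class for protocol classes. |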
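Functions
| Function | Description |
|---|---|
| collate_predictions | |
| get_dataset_to_tasks_map | Create a mapping from dataset names to their associated tasks. |
| _run_server_process | Run the prediction server in a separate process. |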
Module Contents
- core.units.mlip_unit.predict.collate_predictions(predict_fn)
- class core.units.mlip_unit.predict.MLIPPredictUnitProtocol
Bases: Protocol
Base class for protocol classes.
Protocol classes are defined as:
    class Proto(Protocol):
        def meth(self) -> int:
            ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example:
    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:
    class GenProto[T](Protocol):
        def meth(self) -> T:
            ...
- predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool) -> dict
- property dataset_to_tasks: dict[str, list]
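Because MLIPPredictUnitProtocol derives from Protocol, any object that exposes a predict(data, undo_element_references) method and a dataset_to_tasks property satisfies it structurally, with no inheritance required. The sketch below illustrates this with a hypothetical mirror protocol (PredictorLike) and a toy predictor (ConstantPredictor); both names are illustrative only, and data is typed as object instead of AtomicData so the example runs without fairchem installed.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class PredictorLike(Protocol):
    """Hypothetical mirror of the two members MLIPPredictUnitProtocol declares."""

    def predict(self, data: object, undo_element_references: bool) -> dict: ...

    @property
    def dataset_to_tasks(self) -> dict[str, list]: ...


class ConstantPredictor:
    """Hypothetical predictor: it never inherits from PredictorLike, it just has the members."""

    def predict(self, data: object, undo_element_references: bool = True) -> dict:
        # A real predictor would run a model on `data`; this one returns a constant.
        return {"energy": 0.0}

    @property
    def dataset_to_tasks(self) -> dict[str, list]:
        return {"dataset_a": ["energy"]}


def run_prediction(predictor: PredictorLike, data: object) -> dict:
    # Static type checkers accept any object whose members match PredictorLike.
    return predictor.predict(data, undo_element_references=True)


# Usage: the runtime check only looks for the attribute names.
print(isinstance(ConstantPredictor(), PredictorLike))  # True
print(run_prediction(ConstantPredictor(), data=None))  # {'energy': 0.0}
```

As the inherited docstring notes, isinstance against a runtime_checkable protocol checks attribute presence only, so static type checking is what actually enforces the signatures.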
- class core.units.mlip_unit.predict.MLIPPredictUnit(inference_model_path: str, device: str = 'cpu', overrides: dict | None = None, inference_settings: fairchem.core.units.mlip_unit.InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None)
Bases: torchtnt.framework.PredictUnit[fairchem.core.datasets.atomic_data.AtomicData], MLIPPredictUnitProtocol
The PredictUnit is an interface that can be used to organize your prediction logic. The core of it is the predict_step method, an abstract method where you define the code you want to run on each iteration of the dataloader.
To use the PredictUnit, create a class that subclasses PredictUnit. Then implement the predict_step method on your class, and optionally implement any of the hooks that let you control the behavior of the loop at different points. In addition, you can override get_next_predict_batch to modify the default batch-fetching behavior. Below is a simple example of a user’s subclass of PredictUnit that implements a basic predict_step.
    from typing import Tuple

    import torch
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import PredictUnit

    # specify the type of the data in each batch of the dataloader to allow for typechecking
    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyPredictUnit(PredictUnit[Batch]):
        def __init__(
            self,
            module: torch.nn.Module,
        ):
            super().__init__()
            self.module = module

        def predict_step(self, state: State, data: Batch) -> torch.Tensor:
            inputs, targets = data
            outputs = self.module(inputs)
            return outputs

    predict_unit = MyPredictUnit(module=...)
- atom_refs
- tasks
- _dataset_to_tasks
- device
- lazy_model_intialized = False
- inference_mode
- merged_on = None
- property direct_forces: bool
- property dataset_to_tasks: dict[str, list]
- set_seed(seed: int)
- move_to_device()
- predict_step(state: torchtnt.framework.State, data: fairchem.core.datasets.atomic_data.AtomicData) -> dict[str, torch.tensor]
Core required method for the user to implement. This method is called at each iteration of the predict dataloader and can return any data the user wishes. Optionally, it can be decorated with @torch.inference_mode() for improved performance.
- Parameters:
state – a State object containing metadata about the prediction run.
data – one batch of prediction data.
- get_composition_charge_spin_dataset(data)
- predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) -> dict[str, torch.tensor]
- core.units.mlip_unit.predict.get_dataset_to_tasks_map(tasks: Sequence[fairchem.core.units.mlip_unit.mlip_unit.Task]) -> dict[str, list[fairchem.core.units.mlip_unit.mlip_unit.Task]]
Create a mapping from dataset names to their associated tasks.
- Parameters:
tasks – A sequence of Task objects to be organized by dataset.
- Returns:
A dictionary mapping dataset names (str) to lists of the Task objects associated with that dataset.
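get_dataset_to_tasks_map inverts a list of tasks into a per-dataset lookup. A minimal sketch of that grouping follows, assuming each task records the datasets it applies to in a datasets attribute (an assumption; the Task fields are not documented on this page). ToyTask and dataset_to_tasks_map are hypothetical names, not the fairchem implementation.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Sequence


@dataclass
class ToyTask:
    """Hypothetical stand-in for Task, keeping only what this sketch needs."""

    name: str
    datasets: list[str] = field(default_factory=list)  # assumed attribute


def dataset_to_tasks_map(tasks: Sequence[ToyTask]) -> dict[str, list[ToyTask]]:
    """Group tasks under every dataset they list, preserving task order."""
    mapping: defaultdict[str, list[ToyTask]] = defaultdict(list)
    for task in tasks:
        for dataset in task.datasets:
            mapping[dataset].append(task)
    # Return a plain dict so lookups of unknown datasets raise KeyError.
    return dict(mapping)


# Usage: one task shared by two datasets, one task specific to the first.
tasks = [ToyTask("energy", ["dataset_a", "dataset_b"]), ToyTask("forces", ["dataset_a"])]
by_dataset = dataset_to_tasks_map(tasks)
print({name: [t.name for t in ts] for name, ts in by_dataset.items()})
# {'dataset_a': ['energy', 'forces'], 'dataset_b': ['energy']}
```

A task that serves several datasets appears in each of their lists, which is what lets a single prediction head be reused across datasets in this sketch.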
- core.units.mlip_unit.predict._run_server_process(predictor_config, port, num_workers, ready_queue)
Run the prediction server in a separate process.
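The port and ready_queue parameters suggest a common startup handshake: the server binds its socket and reports readiness on a queue, while the parent blocks on that queue before sending any work. The sketch below shows that handshake under stated assumptions: run_server, start_server, and read_all are hypothetical helpers, and it uses a thread instead of a process so it runs without an if __name__ == "__main__" guard. A process-based variant would swap in multiprocessing.Process and multiprocessing.Queue and add that guard.

```python
import queue
import socket
import threading


def run_server(port: int, ready_queue: queue.Queue) -> None:
    """Hypothetical server loop: bind, report readiness, answer a single client."""
    with socket.create_server(("127.0.0.1", port)) as server:
        # Publish the bound port; port=0 lets the OS choose a free one.
        ready_queue.put(server.getsockname()[1])
        conn, _ = server.accept()
        with conn:
            conn.sendall(b"hello from server")


def start_server(timeout: float = 5.0) -> tuple[threading.Thread, int]:
    """Start run_server in the background and block until it reports ready."""
    ready_queue: queue.Queue = queue.Queue()
    worker = threading.Thread(target=run_server, args=(0, ready_queue), daemon=True)
    worker.start()
    # Raises queue.Empty if the server does not come up within `timeout` seconds.
    port = ready_queue.get(timeout=timeout)
    return worker, port


def read_all(sock: socket.socket) -> bytes:
    """Read until the peer closes the connection."""
    chunks = []
    while chunk := sock.recv(4096):
        chunks.append(chunk)
    return b"".join(chunks)


# Usage: the client connects only after the handshake, so it never races the bind.
worker, port = start_server()
with socket.create_connection(("127.0.0.1", port)) as client:
    print(read_all(client))  # b'hello from server'
worker.join(timeout=5)
```

Waiting on the queue with a timeout turns a server that fails to start into an immediate, visible error instead of a client that hangs on connect.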
- class core.units.mlip_unit.predict.ParallelMLIPPredictUnit(inference_model_path: str, device: str = 'cpu', overrides: dict | None = None, inference_settings: fairchem.core.units.mlip_unit.InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None, server_config: dict | None = None, client_config: dict | None = None)
Bases: MLIPPredictUnitProtocol
- server_process = None
- _dataset_to_tasks
- _start_server_process(predict_unit_config, port, workers)
Start the server process and wait for it to be ready.
- cleanup()
- __del__()
- predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) -> dict[str, torch.tensor]
Send the data to the remote server and return its predictions.
- property dataset_to_tasks: dict[str, list]
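ParallelMLIPPredictUnit.predict does not run the model locally: per its docstring, it sends the data to a remote server and returns the server's predictions. The sketch below shows such a client/server round trip in miniature. The wire format (length-prefixed JSON), the helper names, RemotePredictor, and toy_model are all hypothetical, since this page does not document the real transport, and AtomicData is replaced by a plain dict so the example runs with the standard library alone.

```python
import json
import queue
import socket
import struct
import threading


def send_msg(sock: socket.socket, obj: dict) -> None:
    """Send one JSON message, prefixed with its 4-byte big-endian length."""
    payload = json.dumps(obj).encode("utf-8")
    sock.sendall(struct.pack("!I", len(payload)) + payload)


def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes; a single recv() may return fewer."""
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("connection closed mid-message")
        buf += chunk
    return buf


def recv_msg(sock: socket.socket) -> dict:
    (length,) = struct.unpack("!I", recv_exact(sock, 4))
    return json.loads(recv_exact(sock, length))


def toy_model(data: dict) -> dict:
    # Hypothetical model: "energy" is just the sum of the atomic numbers.
    return {"energy": float(sum(data["atomic_numbers"]))}


def serve_one_request(ready_queue: queue.Queue) -> None:
    """Hypothetical server: report its port, then answer exactly one request."""
    with socket.create_server(("127.0.0.1", 0)) as server:
        ready_queue.put(server.getsockname()[1])
        conn, _ = server.accept()
        with conn:
            send_msg(conn, toy_model(recv_msg(conn)))


class RemotePredictor:
    """Hypothetical client whose predict() delegates all work to the server."""

    def __init__(self, port: int) -> None:
        self.port = port

    def predict(self, data: dict) -> dict:
        with socket.create_connection(("127.0.0.1", self.port)) as sock:
            send_msg(sock, data)
            return recv_msg(sock)


# Usage: start the server, wait for its port, then predict remotely.
ready: queue.Queue = queue.Queue()
threading.Thread(target=serve_one_request, args=(ready,), daemon=True).start()
predictor = RemotePredictor(port=ready.get(timeout=5))
print(predictor.predict({"atomic_numbers": [8, 1, 1]}))  # {'energy': 10.0}
```

The length prefix matters because TCP is a byte stream: without it, the receiver cannot tell where one message ends and the next begins.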