core.units.mlip_unit.predict

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

MLIPPredictUnitProtocol

Structural interface for MLIP prediction units (a predict method and a dataset_to_tasks property).

MLIPPredictUnit

The PredictUnit is an interface that can be used to organize your prediction logic. Its core is predict_step, an abstract method where you define the code to run on each iteration of the dataloader.

ParallelMLIPPredictUnit

Predict unit that starts a server in a separate process and sends data to it for prediction.

Functions

collate_predictions(predict_fn)

get_dataset_to_tasks_map(tasks) → dict[str, list[Task]]

Create a mapping from dataset names to their associated tasks.

_run_server_process(predictor_config, port, ...)

Run the server in a separate process.

Module Contents

core.units.mlip_unit.predict.collate_predictions(predict_fn)
class core.units.mlip_unit.predict.MLIPPredictUnitProtocol

Bases: Protocol

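Structural interface for MLIP prediction units. Any object that provides a predict(data, undo_element_references) method and a dataset_to_tasks property satisfies it; MLIPPredictUnit and ParallelMLIPPredictUnit both declare it as a base.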

Protocol classes are defined as:

from typing import Protocol

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example:

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:

class GenProto[T](Protocol):
    def meth(self) -> T:
        ...
predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool) → dict
property dataset_to_tasks: dict[str, list]
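Because MLIPPredictUnitProtocol only lists predict and dataset_to_tasks, any class with those members matches it structurally. The sketch below illustrates this with a hypothetical PredictUnitLike protocol that mirrors the two documented members (it is not the fairchem class) and toy classes whose names and return values are made up. EchoPredictor never inherits from the protocol, yet it passes an isinstance check because @runtime_checkable only tests that the attributes exist.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PredictUnitLike(Protocol):
    """Hypothetical mirror of the documented MLIPPredictUnitProtocol members."""

    def predict(self, data: Any, undo_element_references: bool) -> dict: ...

    @property
    def dataset_to_tasks(self) -> dict[str, list]: ...


class EchoPredictor:
    """Toy predictor: matches the protocol's shape without inheriting from it."""

    def predict(self, data: Any, undo_element_references: bool = True) -> dict:
        # Illustrative output only; a real unit would run a model on `data`.
        return {"energy": 0.0, "undo_refs": undo_element_references}

    @property
    def dataset_to_tasks(self) -> dict[str, list]:
        # "dataset_a" is a made-up dataset name.
        return {"dataset_a": ["energy"]}


class Unrelated:
    """Has neither member, so it does not satisfy the protocol."""


print(isinstance(EchoPredictor(), PredictUnitLike))  # True
print(isinstance(Unrelated(), PredictUnitLike))  # False
```

A static type checker applies the same structural rule to function parameters annotated with the protocol, and it also checks the signatures, which the runtime isinstance check ignores.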
class core.units.mlip_unit.predict.MLIPPredictUnit(inference_model_path: str, device: str = 'cpu', overrides: dict | None = None, inference_settings: fairchem.core.units.mlip_unit.InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None)

Bases: torchtnt.framework.PredictUnit[fairchem.core.datasets.atomic_data.AtomicData], MLIPPredictUnitProtocol

The PredictUnit is an interface that can be used to organize your prediction logic. The core of it is the predict_step which is an abstract method where you can define the code you want to run each iteration of the dataloader.

To use the PredictUnit, create a class which subclasses PredictUnit. Then implement the predict_step method on your class, and then you can optionally implement any of the hooks which allow you to control the behavior of the loop at different points. In addition, you can override get_next_predict_batch to modify the default batch fetching behavior. Below is a simple example of a user’s subclass of PredictUnit that implements a basic predict_step.

from typing import Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit

# specify type of the data in each batch of the dataloader to allow for typechecking
Batch = Tuple[torch.Tensor, torch.Tensor]

class MyPredictUnit(PredictUnit[Batch]):
    def __init__(
        self,
        module: torch.nn.Module,
    ):
        super().__init__()
        self.module = module

    def predict_step(self, state: State, data: Batch) -> torch.Tensor:
        # targets are not needed at prediction time
        inputs, _targets = data
        outputs = self.module(inputs)
        return outputs

predict_unit = MyPredictUnit(module=...)
atom_refs
tasks
_dataset_to_tasks
device
lazy_model_intialized = False
inference_mode
merged_on = None
property direct_forces: bool
property dataset_to_tasks: dict[str, list]
set_seed(seed: int)
move_to_device()
predict_step(state: torchtnt.framework.State, data: fairchem.core.datasets.atomic_data.AtomicData) → dict[str, torch.tensor]

Core method that the user must implement. It is called at each iteration of the predict dataloader and can return any data the user wishes. It can optionally be decorated with @torch.inference_mode() for improved performance.

Parameters:
  • state – a State object containing metadata about the prediction run.

  • data – one batch of prediction data.
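The contract above (one predict_step call per dataloader batch, returning whatever the unit wants) can be sketched without torchtnt. This is a simplified stand-in for a predict loop, not torchtnt's implementation: DoublingPredictUnit and run_predict are hypothetical names, plain lists stand in for batches, and the State object is replaced by None.

```python
from typing import Any, Iterable


class DoublingPredictUnit:
    """Hypothetical unit: predict_step is called once per batch."""

    def predict_step(self, state: Any, data: list[float]) -> list[float]:
        # Any per-batch computation can go here; this one doubles each value.
        return [2 * x for x in data]


def run_predict(
    unit: DoublingPredictUnit, dataloader: Iterable[list[float]]
) -> list[list[float]]:
    """Simplified predict loop: one predict_step call per batch, outputs collected."""
    state = None  # torchtnt passes a State object here; omitted in this sketch
    return [unit.predict_step(state, batch) for batch in dataloader]


outputs = run_predict(DoublingPredictUnit(), [[1.0, 2.0], [3.0]])
print(outputs)  # [[2.0, 4.0], [6.0]]
```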

get_composition_charge_spin_dataset(data)
predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) → dict[str, torch.tensor]
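This page does not describe what undo_element_references does. Assuming it follows the common linear-referencing convention, the model predicts an energy with a sum of per-element reference energies subtracted, and atom_refs supplies those references. Under that assumption, undoing the references means adding the sum back. The helper name and the reference values below are hypothetical and purely illustrative.

```python
def add_back_element_references(
    referenced_energy: float,
    atomic_numbers: list[int],
    element_refs: dict[int, float],
) -> float:
    """Hypothetical helper: total energy = referenced energy + sum of per-atom references.

    Assumes linear referencing with one reference energy per element.
    """
    return referenced_energy + sum(element_refs[z] for z in atomic_numbers)


# Illustrative numbers only, not real reference energies.
refs = {1: -0.5, 8: -2.0}
total = add_back_element_references(-0.25, [8, 1, 1], refs)
print(total)  # -3.25
```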
core.units.mlip_unit.predict.get_dataset_to_tasks_map(tasks: Sequence[fairchem.core.units.mlip_unit.mlip_unit.Task]) → dict[str, list[fairchem.core.units.mlip_unit.mlip_unit.Task]]

Create a mapping from dataset names to their associated tasks.

Parameters:

tasks – A sequence of Task objects to be organized by dataset.

Returns:

A dictionary mapping dataset names (str) to lists of the Task objects associated with each dataset.
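The grouping described above can be sketched as follows. The Task dataclass here is a hypothetical stand-in that assumes each task carries a list of dataset names in a datasets field; the real fairchem Task may differ, and dataset_to_tasks_sketch is not the library function. A task that lists several datasets appears under each of them.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Sequence


@dataclass
class Task:
    """Hypothetical stand-in for fairchem's Task; only the fields used here."""

    name: str
    datasets: list[str] = field(default_factory=list)


def dataset_to_tasks_sketch(tasks: Sequence[Task]) -> dict[str, list[Task]]:
    """Group tasks under every dataset name they list, preserving input order."""
    grouped: dict[str, list[Task]] = defaultdict(list)
    for task in tasks:
        for dataset_name in task.datasets:
            grouped[dataset_name].append(task)
    return dict(grouped)


tasks = [
    Task("energy", ["dataset_a", "dataset_b"]),
    Task("forces", ["dataset_a"]),
]
mapping = dataset_to_tasks_sketch(tasks)
names = {dataset: [t.name for t in ts] for dataset, ts in mapping.items()}
print(names)  # {'dataset_a': ['energy', 'forces'], 'dataset_b': ['energy']}
```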

core.units.mlip_unit.predict._run_server_process(predictor_config, port, num_workers, ready_queue)

Run the server in a separate process.

class core.units.mlip_unit.predict.ParallelMLIPPredictUnit(inference_model_path: str, device: str = 'cpu', overrides: dict | None = None, inference_settings: fairchem.core.units.mlip_unit.InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None, server_config: dict | None = None, client_config: dict | None = None)

Bases: MLIPPredictUnitProtocol

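Predict unit that starts a server in a separate process and sends data to that server for prediction. It exposes the same predict method and dataset_to_tasks property as MLIPPredictUnitProtocol.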

server_process = None
_dataset_to_tasks
_start_server_process(predict_unit_config, port, workers)

Start the server process and wait for it to be ready.
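_start_server_process is documented only as starting the server and waiting for it to be ready, and _run_server_process receives a ready_queue. One common way to build such a handshake is for the worker to put a message on the queue once it is serving, while the parent blocks on that queue with a timeout. The sketch below shows this pattern under stated assumptions: all names are hypothetical, and it uses threading.Thread with queue.Queue so it runs without process start-method concerns. Swapping in multiprocessing.Process and multiprocessing.Queue gives the separate-process version. The cleanup function mirrors the role that cleanup() and __del__() presumably play.

```python
import queue
import threading
import time


def run_server(port: int, ready_queue: queue.Queue, stop_event: threading.Event) -> None:
    """Hypothetical server loop: signal readiness, then serve until asked to stop."""
    # A real server would bind to `port` and load the model before signalling.
    ready_queue.put(("ready", port))
    while not stop_event.is_set():
        time.sleep(0.01)  # placeholder for handling prediction requests


def start_server(port: int, timeout: float = 5.0):
    """Start the worker and block until it reports that it is ready."""
    ready_queue: queue.Queue = queue.Queue()
    stop_event = threading.Event()
    worker = threading.Thread(
        target=run_server, args=(port, ready_queue, stop_event), daemon=True
    )
    worker.start()
    try:
        status, reported_port = ready_queue.get(timeout=timeout)
    except queue.Empty:
        stop_event.set()
        raise RuntimeError(f"server on port {port} not ready after {timeout}s") from None
    return worker, stop_event, status, reported_port


def cleanup(worker: threading.Thread, stop_event: threading.Event) -> None:
    """Ask the worker to stop and wait for it to exit."""
    stop_event.set()
    worker.join(timeout=5.0)


worker, stop_event, status, reported_port = start_server(port=8001)
print(status, reported_port)  # ready 8001
cleanup(worker, stop_event)
```

The timeout on the queue read keeps the caller from hanging forever if the worker fails during startup.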

cleanup()
__del__()
predict(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) → dict[str, torch.tensor]

Predict method that sends data to the remote server and returns predictions.

property dataset_to_tasks: dict[str, list]