core.units.mlip_unit._batch_serve#

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes#

BatchPredictServer

Ray Serve deployment that batches incoming inference requests.

Functions#

setup_batch_predict_server(...)

Set up and deploy a BatchPredictServer for batched inference.

Module Contents#

class core.units.mlip_unit._batch_serve.BatchPredictServer(predict_unit_ref, max_batch_size: int, batch_wait_timeout_s: float, split_oom_batch: bool = True)#

Ray Serve deployment that batches incoming inference requests.

predict_unit#
split_oom_batch#
configure_batching(max_batch_size: int = 32, batch_wait_timeout_s: float = 0.05)#
get_predict_unit_attribute(attribute_name: str) → Any#
async predict(data_list: list[fairchem.core.datasets.atomic_data.AtomicData], undo_element_references: bool = True) → list[dict]#

Process a batch of AtomicData objects.

Parameters:

data_list – List of AtomicData objects (automatically batched by Ray Serve)

Returns:

List of prediction dictionaries, one per input
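
The batching described above is provided by Ray Serve's request-batching machinery. The following is a minimal, standalone sketch of that pattern (a hypothetical EchoBatcher deployment, not the fairchem implementation) to show how per-request calls end up arriving at a single list-taking method like predict:

```python
# Hypothetical sketch of Ray Serve request batching, the mechanism behind
# `predict` receiving a list of inputs. Not the fairchem implementation.
from ray import serve


@serve.deployment
class EchoBatcher:
    # Queued per-request inputs are delivered here as one list, once
    # `max_batch_size` requests have accumulated or `batch_wait_timeout_s`
    # has elapsed, whichever comes first.
    @serve.batch(max_batch_size=32, batch_wait_timeout_s=0.05)
    async def predict(self, items: list) -> list:
        # Must return one result per input, in the same order.
        return [{"echo": item} for item in items]

    async def __call__(self, item):
        # Each caller passes a single item; the @serve.batch decorator queues
        # it and resolves with that caller's individual result.
        return await self.predict(item)
```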

async __call__(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) → dict#

Main entry point for inference requests.

Parameters:

data – Single AtomicData object

Returns:

Prediction dictionary for this system
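
A hypothetical usage sketch of this entry point, assuming `handle` is a DeploymentHandle returned by setup_batch_predict_server and `atomic_data` is an already-constructed AtomicData object for one system:

```python
# `handle` and `atomic_data` are assumed to exist (see lead-in above).
# Calling the handle invokes __call__ on the deployment; concurrent calls
# are grouped into one batched forward pass by Ray Serve.
response = handle.remote(atomic_data, undo_element_references=True)
prediction = response.result()  # blocks until the batched prediction returns
print(sorted(prediction))       # prediction keys for this single system
```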

_split_predictions(predictions: dict, batch: fairchem.core.datasets.atomic_data.AtomicData) → list[dict]#

Split batched predictions back into individual system predictions.

Parameters:
  • predictions – Dictionary of batched prediction tensors

  • batch – The batched AtomicData used for inference

Returns:

List of prediction dictionaries, one per system
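
For illustration, the splitting can be sketched as follows. This is not the library's code; it assumes per-system tensors (e.g. energies) have one row per system and per-atom tensors (e.g. forces) can be split by each system's atom count:

```python
# Illustrative sketch only: separate batched outputs into per-system dicts
# using the number of atoms in each system.
import torch


def split_predictions_sketch(predictions: dict, natoms: torch.Tensor) -> list[dict]:
    natoms_list = natoms.tolist()
    n_systems = len(natoms_list)
    per_system: list[dict] = [{} for _ in range(n_systems)]
    for key, tensor in predictions.items():
        if tensor.shape[0] == n_systems:
            # Per-system quantity: one row per system.
            chunks = list(tensor)
        else:
            # Per-atom quantity: split along the atom dimension.
            chunks = torch.split(tensor, natoms_list, dim=0)
        for i, chunk in enumerate(chunks):
            per_system[i][key] = chunk
    return per_system
```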

core.units.mlip_unit._batch_serve.setup_batch_predict_server(predict_unit: fairchem.core.units.mlip_unit.MLIPPredictUnit, max_batch_size: int = 32, batch_wait_timeout_s: float = 0.1, split_oom_batch: bool = True, num_replicas: int = 1, ray_actor_options: dict | None = None, deployment_name: str = 'predict-server', route_prefix: str = '/predict') → ray.serve.handle.DeploymentHandle#

Set up and deploy a BatchPredictServer for batched inference.

Parameters:
  • predict_unit – An MLIPPredictUnit instance to use for batched inference

  • max_batch_size – Maximum number of systems per batch.

  • batch_wait_timeout_s – Maximum wait time before processing a partial batch.

  • split_oom_batch – Whether to split batches that cause OOM errors.

  • num_replicas – Number of deployment replicas for scaling.

  • ray_actor_options – Additional Ray actor options (e.g., {"num_gpus": 1, "num_cpus": 4})

  • deployment_name – Name for the Ray Serve deployment.

  • route_prefix – HTTP route prefix for the deployment.

Returns:

Ray Serve deployment handle that can be used to initialize a BatchServerPredictUnit
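
A hypothetical end-to-end sketch, assuming a ready MLIPPredictUnit named `predict_unit`, a list of AtomicData objects named `systems`, and that the module is importable as fairchem.core.units.mlip_unit._batch_serve:

```python
# End-to-end usage sketch; `predict_unit` and `systems` are assumed to exist.
import ray

from fairchem.core.units.mlip_unit._batch_serve import setup_batch_predict_server

ray.init(ignore_reinit_error=True)

handle = setup_batch_predict_server(
    predict_unit,
    max_batch_size=16,             # cap on systems per batched forward pass
    batch_wait_timeout_s=0.1,      # flush a partial batch after 100 ms
    split_oom_batch=True,          # split batches that cause OOM errors
    ray_actor_options={"num_gpus": 1},
)

# Submit all systems concurrently so Ray Serve can batch them together,
# then collect one prediction dict per system.
responses = [handle.remote(data) for data in systems]
predictions = [response.result() for response in responses]
```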