core.units.mlip_unit._batch_serve#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes#
| BatchPredictServer | Ray Serve deployment that batches incoming inference requests. |
Functions#
| setup_batch_predict_server | Set up and deploy a BatchPredictServer for batched inference. |
Module Contents#
- class core.units.mlip_unit._batch_serve.BatchPredictServer(predict_unit_ref, max_batch_size: int, batch_wait_timeout_s: float, split_oom_batch: bool = True)#
Ray Serve deployment that batches incoming inference requests.
- predict_unit#
- split_oom_batch#
- configure_batching(max_batch_size: int = 32, batch_wait_timeout_s: float = 0.05)#
- get_predict_unit_attribute(attribute_name: str) → Any#
- async predict(data_list: list[fairchem.core.datasets.atomic_data.AtomicData], undo_element_references: bool = True) → list[dict]#
Process a batch of AtomicData objects.
- Parameters:
data_list – List of AtomicData objects (automatically batched by Ray Serve)
- Returns:
List of prediction dictionaries, one per input
- async __call__(data: fairchem.core.datasets.atomic_data.AtomicData, undo_element_references: bool = True) → dict#
Main entry point for inference requests.
- Parameters:
data – Single AtomicData object
- Returns:
Prediction dictionary for this system
- _split_predictions(predictions: dict, batch: fairchem.core.datasets.atomic_data.AtomicData) → list[dict]#
Split batched predictions back into individual system predictions.
- Parameters:
predictions – Dictionary of batched prediction tensors
batch – The batched AtomicData used for inference
- Returns:
List of prediction dictionaries, one per system
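The per-request __call__ and the batched predict method follow Ray Serve's dynamic request batching pattern: many concurrent single-system calls are collected into one list, run through the model together, and split back into per-caller results. The sketch below only illustrates that pattern under stated assumptions; the SketchBatchServer class and its echo "prediction" are placeholders, not fairchem code.

```python
# Minimal sketch of Ray Serve dynamic request batching, the pattern
# BatchPredictServer is built around. The deployment and its echo
# "prediction" are placeholders, not the actual implementation.
from ray import serve


@serve.deployment
class SketchBatchServer:
    @serve.batch(max_batch_size=32, batch_wait_timeout_s=0.05)
    async def predict(self, data_list: list) -> list[dict]:
        # Ray Serve gathers concurrent requests into `data_list`;
        # exactly one result must be returned per input, in order.
        return [{"echo": item} for item in data_list]

    async def __call__(self, data) -> dict:
        # Each caller submits a single item; batching happens
        # transparently inside the decorated `predict` method.
        return await self.predict(data)
```

Under this pattern, configure_batching presumably adjusts the same two knobs (max_batch_size and batch_wait_timeout_s) at runtime, while split_oom_batch controls whether a batch that triggers an out-of-memory error is split and retried in smaller pieces.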
- core.units.mlip_unit._batch_serve.setup_batch_predict_server(predict_unit: fairchem.core.units.mlip_unit.MLIPPredictUnit, max_batch_size: int = 32, batch_wait_timeout_s: float = 0.1, split_oom_batch: bool = True, num_replicas: int = 1, ray_actor_options: dict | None = None, deployment_name: str = 'predict-server', route_prefix: str = '/predict') → ray.serve.handle.DeploymentHandle#
Set up and deploy a BatchPredictServer for batched inference.
- Parameters:
predict_unit – An MLIPPredictUnit instance to use for batched inference
max_batch_size – Maximum number of systems per batch.
batch_wait_timeout_s – Maximum wait time before processing partial batch.
split_oom_batch – Whether to split batches that cause OOM errors.
num_replicas – Number of deployment replicas for scaling.
ray_actor_options – Additional Ray actor options (e.g., {"num_gpus": 1, "num_cpus": 4}).
deployment_name – Name for the Ray Serve deployment.
route_prefix – HTTP route prefix for the deployment.
- Returns:
Ray Serve deployment handle that can be used to initialize a BatchServerPredictUnit
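For context, here is a hedged end-to-end usage sketch. Only setup_batch_predict_server and its parameters come from the signature above; the pretrained_mlip.get_predict_unit helper, the "uma-s-1" checkpoint name, and AtomicData.from_ase are assumptions chosen for illustration and may need to be swapped for your own predict unit and data pipeline.

```python
# Hedged usage sketch: deploy the batched server and send a few requests
# through the returned DeploymentHandle. The pretrained_mlip helper, the
# checkpoint name, and AtomicData.from_ase are assumptions, not part of
# this module's API.
from ase.build import bulk

from fairchem.core import pretrained_mlip  # assumed helper for building an MLIPPredictUnit
from fairchem.core.datasets.atomic_data import AtomicData
from fairchem.core.units.mlip_unit._batch_serve import setup_batch_predict_server

# Build (or load) an MLIPPredictUnit; "uma-s-1" is a placeholder checkpoint name.
predict_unit = pretrained_mlip.get_predict_unit("uma-s-1", device="cuda")

handle = setup_batch_predict_server(
    predict_unit,
    max_batch_size=16,
    batch_wait_timeout_s=0.05,
    num_replicas=1,
    ray_actor_options={"num_gpus": 1},
)

# Each call submits a single system; concurrent requests are batched server-side.
atoms = bulk("Cu", "fcc", a=3.6)
data = AtomicData.from_ase(atoms)  # assumed constructor; adapt to your own AtomicData pipeline
responses = [handle.remote(data) for _ in range(8)]
predictions = [r.result() for r in responses]  # each result is a per-system prediction dict
print(predictions[0].keys())
```

As the Returns note indicates, the handle can also be wrapped in a BatchServerPredictUnit so that downstream code keeps the familiar predict-unit interface.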