Batched Atomic Simulations with InferenceBatcher

When running many independent ASE calculations (relaxations, molecular dynamics, etc.) on small to medium-sized systems, you can significantly improve GPU utilization by batching model inference calls together. The InferenceBatcher class provides a high-level API to do this with minimal code changes.

The key idea is simple: instead of running each simulation sequentially, InferenceBatcher collects inference requests from multiple concurrent simulations and batches them together for more efficient GPU computation.
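To make the batching idea concrete, here is a toy sketch of the pattern using only the Python standard library. This is not the fairchem implementation; the names `fake_model`, `predict`, and `batcher_loop` are illustrative. Each "simulation" thread enqueues a request and blocks on its reply, while a single consumer drains whatever requests are pending and evaluates them together in one call:

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor

request_q = queue.Queue()


def fake_model(batch):
    """Stand-in for one batched GPU inference call: squares each input."""
    return [x * x for x in batch]


def predict(x):
    """What each simulation calls: enqueue a request and block on the reply."""
    reply_q = queue.Queue(maxsize=1)
    request_q.put((x, reply_q))
    return reply_q.get()


def batcher_loop(n_total):
    """Collect whatever requests are pending and evaluate them together."""
    served = 0
    while served < n_total:
        batch = [request_q.get()]  # wait for at least one request
        while True:  # then drain any others that are already queued
            try:
                batch.append(request_q.get_nowait())
            except queue.Empty:
                break
        outputs = fake_model([x for x, _ in batch])
        for (_, reply_q), out in zip(batch, outputs):
            reply_q.put(out)
        served += len(batch)


n = 8
batcher_thread = threading.Thread(target=batcher_loop, args=(n,), daemon=True)
batcher_thread.start()
with ThreadPoolExecutor(max_workers=n) as pool:
    results = list(pool.map(predict, range(n)))
```

The simulations are oblivious to the batching: they call `predict` as if it were a plain function. `InferenceBatcher` applies the same idea to model inference, with the GPU model in place of `fake_model`.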

Basic Setup

To use InferenceBatcher, you need to:

  1. Create a predict unit as usual

  2. Wrap it with InferenceBatcher

  3. Use batcher.batch_predict_unit instead of the original predict unit in your simulation functions

from fairchem.core import pretrained_mlip
from fairchem.core.calculate import FAIRChemCalculator, InferenceBatcher

# Create a predict unit
predict_unit = pretrained_mlip.get_predict_unit("uma-s-1p1")

# Wrap it with InferenceBatcher
batcher = InferenceBatcher(
    predict_unit, concurrency_backend_options=dict(max_workers=32)
)

Writing Simulation Functions

The only requirement for using InferenceBatcher is to write your simulation logic as a function that takes an Atoms object and a predict unit as arguments:

from ase.build import bulk
from ase.filters import FrechetCellFilter
from ase.optimize import LBFGS


def run_relaxation(atoms, predict_unit):
    """Run a structure relaxation and return the final energy."""
    calc = FAIRChemCalculator(predict_unit, task_name="omat")
    atoms.calc = calc
    opt = LBFGS(FrechetCellFilter(atoms), logfile=None)
    opt.run(fmax=0.02, steps=100)
    return atoms.get_potential_energy()

Running Batched Relaxations

Once you have your simulation function, you can run it in batched mode using the executor’s map or submit methods:

Using executor.map

from functools import partial

import numpy as np
from ase.build import make_supercell

# Create a list of structures to relax
prim_atoms = [
    bulk("Cu"),
    bulk("MgO", "rocksalt", a=4.2),
    bulk("Si", "diamond", a=5.43),
    bulk("NaCl", "rocksalt", a=3.8),
]

atoms_list = [make_supercell(atoms, 3 * np.identity(3)) for atoms in prim_atoms]

for atoms in atoms_list:
    atoms.rattle(0.1)

# Create a partial function with the batch predict unit
run_relaxation_batched = partial(
    run_relaxation, predict_unit=batcher.batch_predict_unit
)

# Run all relaxations in parallel with batched inference
relaxed_energies = list(batcher.executor.map(run_relaxation_batched, atoms_list))

Using executor.submit for more control

If you need more control over the execution or want to process results as they complete:

# Create a new list of structures to relax
atoms_list = [make_supercell(atoms, 3 * np.identity(3)) for atoms in prim_atoms]

for atoms in atoms_list:
    atoms.rattle(0.1)

# Submit all jobs
futures = [
    batcher.executor.submit(run_relaxation, atoms, batcher.batch_predict_unit)
    for atoms in atoms_list
]

# Collect results
relaxed_energies = [future.result() for future in futures]
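If you want results in completion order rather than submission order, the standard-library `concurrent.futures.as_completed` helper works with the futures returned by `batcher.executor.submit`. The sketch below uses a plain `ThreadPoolExecutor` and a stand-in `toy_job` so it is self-contained; in practice you would submit `run_relaxation` to `batcher.executor` instead:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def toy_job(i):
    """Stand-in for run_relaxation(atoms, predict_unit)."""
    return i * 1.5


with ThreadPoolExecutor(max_workers=4) as executor:
    # Map each future back to its input so results can be attributed
    futures = {executor.submit(toy_job, i): i for i in range(4)}
    results = {}
    for future in as_completed(futures):
        results[futures[future]] = future.result()
```

Keying the futures by input index keeps the input-to-result mapping intact even though completion order is nondeterministic.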

Running Batched Molecular Dynamics

The same pattern works for molecular dynamics simulations:

from ase import units
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution


def run_nvt_md(atoms, predict_unit, temperature, traj_fname):
    """Run NVT molecular dynamics simulation."""
    calc = FAIRChemCalculator(predict_unit, task_name="omat")
    atoms.calc = calc
    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature, force_temp=True)
    dyn = Langevin(
        atoms,
        timestep=2 * units.fs,
        temperature_K=temperature,
        friction=0.1,
        trajectory=traj_fname,
        loginterval=5,
    )
    dyn.run(100)


# Run batched MD simulations
run_md_batched = partial(
    run_nvt_md, predict_unit=batcher.batch_predict_unit, temperature=300
)

futures = [
    batcher.executor.submit(run_md_batched, atoms, traj_fname=f"traj_{i}.traj")
    for i, atoms in enumerate(atoms_list)
]

# Wait for all simulations to complete
[future.result() for future in futures]

When to Use InferenceBatcher

InferenceBatcher is most beneficial when: