core.datasets.samplers.max_atom_distributed_sampler#
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Classes#
A custom batch sampler that distributes batches across multiple GPUs to ensure efficient training.
Functions#
Greedily creates batches from a list of samples with varying numbers of atoms.
Module Contents#
- core.datasets.samplers.max_atom_distributed_sampler.get_batches(natoms_list: numpy.array, indices: numpy.array, max_atoms: int, min_atoms: int) tuple[list[list[int]], list[int], int] #
Greedily creates batches from a list of samples with varying numbers of atoms.
Args:
natoms_list: Array of the number of atoms in each sample.
indices: Array of indices of the samples.
max_atoms: Maximum number of atoms allowed in a batch.
Returns:
tuple[list[list[int]], list[int], int]: A tuple containing a list of batches, a list of the total number of atoms in each batch, and the number of samples that were filtered out because they exceeded the maximum number of atoms.
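The greedy strategy can be pictured with a short sketch. The function below is an illustrative approximation of the behavior described above, not the library's exact get_batches implementation; the name greedy_batches and the exact handling of filtered samples are assumptions made for the example.

```python
import numpy as np

def greedy_batches(natoms_list: np.ndarray, indices: np.ndarray, max_atoms: int):
    """Greedy batching by atom count (illustrative sketch, not fairchem's code)."""
    batches, batch_natoms = [], []
    current, current_atoms = [], 0
    n_filtered = 0
    for idx, natoms in zip(indices, natoms_list):
        natoms = int(natoms)
        if natoms > max_atoms:
            n_filtered += 1  # a single sample exceeding the budget is dropped
            continue
        if current_atoms + natoms > max_atoms:  # batch is full, start a new one
            batches.append(current)
            batch_natoms.append(current_atoms)
            current, current_atoms = [], 0
        current.append(int(idx))
        current_atoms += natoms
    if current:  # flush the last, partially filled batch
        batches.append(current)
        batch_natoms.append(current_atoms)
    return batches, batch_natoms, n_filtered
```

For instance, greedy_batches(np.array([50, 40, 30]), np.array([0, 1, 2]), max_atoms=100) would return ([[0, 1], [2]], [90, 30], 0): the first two samples share a batch under the 100-atom budget, and the third starts a new one.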
- class core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler(dataset: fairchem.core.datasets.base_dataset.BaseDataset, max_atoms: int, num_replicas: int, rank: int, seed: int, shuffle: bool = True, drop_last: bool = False, min_atoms: int = 0)#
Bases: torch.utils.data.Sampler[list[int]]
A custom batch sampler that distributes batches across multiple GPUs to ensure efficient training.
Args:
dataset (BaseDataset): The dataset to sample from.
max_atoms (int): The maximum number of atoms allowed in a batch.
num_replicas (int): The number of GPUs to distribute the batches across.
rank (int): The rank of the current GPU.
seed (int): The seed for shuffling the dataset.
shuffle (bool): Whether to shuffle the dataset. Defaults to True.
drop_last (bool): Whether to drop the last batch if its size is less than the maximum allowed size. Defaults to False.
This batch sampler is designed to work with the BaseDataset class and is optimized for distributed training. It takes into account the number of atoms in each sample and ensures that the batches are distributed evenly across GPUs.
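In practice the sampler is typically handed to a DataLoader through its batch_sampler argument. The snippet below is a hedged usage sketch; dataset, world_size, rank, and collate_fn are placeholders that depend on the surrounding training setup and are not defined by this module.

```python
from torch.utils.data import DataLoader

# Placeholders: `dataset` is a fairchem BaseDataset, `world_size`/`rank` would
# come from torch.distributed, and `collate_fn` depends on the model's batch format.
sampler = MaxAtomDistributedBatchSampler(
    dataset=dataset,
    max_atoms=1000,           # total-atom budget per batch
    num_replicas=world_size,  # number of GPUs / processes
    rank=rank,                # rank of this process
    seed=0,
    shuffle=True,
)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)
```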
- dataset#
- max_atoms#
- min_atoms#
- num_replicas#
- rank#
- seed#
- shuffle#
- drop_last#
- epoch = 0#
- start_iter = 0#
- all_batches#
- total_size#
- _prepare_batches() list[int] #
- __len__() int #
- __iter__() Iterator[list[int]] #
- set_epoch_and_start_iteration(epoch: int, start_iter: int) None #
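A minimal sketch of how set_epoch_and_start_iteration might be driven by a training loop, so that shuffling differs per epoch and a resumed run can skip batches it has already consumed; sampler and loader are assumed to come from the usage sketch above, and start_epoch, num_epochs, and start_iter are hypothetical checkpoint values.

```python
for epoch in range(start_epoch, num_epochs):
    # Re-seed the shuffle for this epoch and, on resume, skip already-seen batches.
    sampler.set_epoch_and_start_iteration(epoch, start_iter)
    for batch in loader:
        ...  # training step
    start_iter = 0  # only the first resumed epoch starts part-way through
```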