core.common.data_parallel

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Classes

OCPCollater

StatefulDistributedSampler

A more fine-grained stateful sampler that uses both the training iteration and the epoch for shuffling data.

BalancedBatchSampler

Wraps another sampler to yield a mini-batch of indices.

Functions

_balanced_partition(sizes, num_parts)

Greedily partition the given set by always inserting the largest element into the smallest partition.

_ensure_supported(dataset)

Module Contents

class core.common.data_parallel.OCPCollater(otf_graph: bool = False)
otf_graph
__call__(data_list: list[torch_geometric.data.Data]) -> torch_geometric.data.Batch

core.common.data_parallel._balanced_partition(sizes: numpy.typing.NDArray[numpy.int_], num_parts: int)

Greedily partition the given set by always inserting the largest element into the smallest partition.
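The greedy rule above is the classic "largest item into the least-loaded bin" heuristic. The sketch below shows one way to express it in pure Python with a min-heap of partition totals. The name `balanced_partition_sketch` and the choice to return one list of indices per partition are assumptions for illustration; this is not the module's actual `_balanced_partition`.

```python
import heapq
from collections.abc import Sequence


def balanced_partition_sketch(sizes: Sequence[int], num_parts: int) -> list[list[int]]:
    """Hypothetical greedy partition: largest item into the currently smallest part.

    Returns one list of indices into `sizes` per part.
    """
    # Visit items from largest to smallest; sorted() is stable, so ties keep index order.
    order = sorted(range(len(sizes)), key=lambda i: -sizes[i])
    # Min-heap of (running total, part id) so the smallest part is always on top.
    heap = [(0, part) for part in range(num_parts)]
    parts: list[list[int]] = [[] for _ in range(num_parts)]
    for idx in order:
        total, part = heapq.heappop(heap)
        parts[part].append(idx)
        heapq.heappush(heap, (total + sizes[idx], part))
    return parts


sizes = [10, 7, 5, 4, 3, 1]
parts = balanced_partition_sketch(sizes, num_parts=2)
print(parts)                                       # [[0, 3, 5], [1, 2, 4]]
print([sum(sizes[i] for i in p) for p in parts])   # [15, 15]
```

Sorting once and keeping a heap of running totals makes the procedure O(n log n); like any greedy heuristic, it does not guarantee an optimal split, but it is usually close.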

class core.common.data_parallel.StatefulDistributedSampler(dataset, batch_size, **kwargs)

Bases: torch.utils.data.DistributedSampler

A sampler with more fine-grained state than PyTorch's DistributedSampler: it uses both the training iteration and the epoch when shuffling data. PyTorch's DistributedSampler uses only the epoch for shuffling and always starts sampling from the beginning of the epoch. When training on very large datasets, we often train for a single epoch, so when training resumes we want the sampler to resume from the saved training iteration rather than from the start of the epoch.
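The resume behaviour described above can be illustrated with a small pure-Python stand-in. This is a hedged sketch, not the class's implementation: the name `ResumableShardSampler`, the use of `random.Random(seed + epoch)` in place of PyTorch's generator, and the omission of DistributedSampler's padding are all assumptions. It only mirrors the `start_iter`, `batch_size`, and `set_epoch_and_start_iteration(epoch, start_iter)` members documented for this class.

```python
import random


class ResumableShardSampler:
    """Hypothetical stand-in for a resumable distributed sampler (not the real class).

    Shuffles with seed + epoch, shards round-robin across replicas, and skips the
    samples this rank already consumed before a checkpoint.
    """

    def __init__(self, dataset_len: int, batch_size: int, num_replicas: int = 1,
                 rank: int = 0, seed: int = 0, shuffle: bool = True) -> None:
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.shuffle = shuffle
        self.epoch = 0
        self.start_iter = 0

    def set_epoch_and_start_iteration(self, epoch: int, start_iter: int) -> None:
        self.epoch = epoch
        self.start_iter = start_iter

    def __iter__(self):
        indices = list(range(self.dataset_len))
        if self.shuffle:
            # Same seed + epoch gives the same permutation on every rank and every resume.
            random.Random(self.seed + self.epoch).shuffle(indices)
        # Round-robin shard (DistributedSampler also pads to an even split; omitted here).
        shard = indices[self.rank::self.num_replicas]
        # Skip the start_iter batches this rank already consumed before the checkpoint.
        return iter(shard[self.start_iter * self.batch_size:])


# Rank 0 of 2: a run resumed at iteration 1 continues where the first run left off.
sampler = ResumableShardSampler(dataset_len=10, batch_size=2, num_replicas=2, rank=0, seed=42)
sampler.set_epoch_and_start_iteration(epoch=3, start_iter=0)
full_epoch = list(sampler)
sampler.set_epoch_and_start_iteration(epoch=3, start_iter=1)
resumed = list(sampler)
print(resumed == full_epoch[2:])  # True
```

Because the permutation depends only on the seed and epoch, skipping `start_iter * batch_size` samples places the resumed run on the batch that follows the last one consumed before the checkpoint, instead of replaying the epoch from the start.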

start_iter = 0
batch_size
__iter__()
set_epoch_and_start_iteration(epoch, start_iter)

core.common.data_parallel._ensure_supported(dataset: Any)

class core.common.data_parallel.BalancedBatchSampler(dataset: torch.utils.data.Dataset, *, batch_size: int, num_replicas: int, rank: int, device: torch.device, seed: int, mode: bool | Literal['atoms'] = 'atoms', shuffle: bool = True, on_error: Literal['warn_and_balance', 'warn_and_no_balance', 'raise'] = 'raise', drop_last: bool = False)

Bases: torch.utils.data.BatchSampler

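Wraps another sampler to yield a mini-batch of indices. The parameter list and example below are inherited from the torch.utils.data.BatchSampler docstring.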

Parameters:
  • sampler (Sampler or Iterable) – Base sampler. Can be any iterable object.

  • batch_size (int) – Size of mini-batch.

  • drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size.

Example

>>> from torch.utils.data import BatchSampler, SequentialSampler
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

disabled = False
on_error
sampler
device
_get_natoms(batch_idx: list[int])
set_epoch_and_start_iteration(epoch: int, start_iteration: int) -> None
set_epoch(epoch: int) -> None
static _dist_enabled()
__iter__()
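This page does not spell out how BalancedBatchSampler balances work, but its `mode='atoms'`, `num_replicas`, and `rank` parameters, its `_get_natoms` helper, and the module's `_balanced_partition` suggest one plausible scheme: pool the indices that all ranks drew for the current step, then redistribute them so each rank gets a similar total atom count. The single-process sketch below simulates that idea under that assumption. `rebalance_batches`, `greedy_partition`, and the `natoms` mapping are hypothetical names, and in a real multi-process run the pooling step would need a collective such as `torch.distributed.all_gather_object`.

```python
import heapq
from collections.abc import Mapping, Sequence


def greedy_partition(sizes: Sequence[int], num_parts: int) -> list[list[int]]:
    """Hypothetical helper: largest item first, always into the lightest part."""
    order = sorted(range(len(sizes)), key=lambda i: -sizes[i])
    heap = [(0, part) for part in range(num_parts)]
    parts: list[list[int]] = [[] for _ in range(num_parts)]
    for idx in order:
        total, part = heapq.heappop(heap)
        parts[part].append(idx)
        heapq.heappush(heap, (total + sizes[idx], part))
    return parts


def rebalance_batches(per_rank_batches: Sequence[Sequence[int]],
                      natoms: Mapping[int, int]) -> list[list[int]]:
    """Hypothetical: pool every rank's batch, then split it by atom count.

    per_rank_batches simulates the gathered per-rank indices; natoms stands in
    for a per-sample atom-count lookup such as _get_natoms.
    """
    pooled = [idx for batch in per_rank_batches for idx in batch]
    sizes = [natoms[idx] for idx in pooled]
    parts = greedy_partition(sizes, num_parts=len(per_rank_batches))
    # Map positions in the pooled list back to dataset indices.
    return [[pooled[pos] for pos in part] for part in parts]


natoms = {0: 100, 1: 90, 2: 10, 3: 5}
balanced = rebalance_batches([[0, 1], [2, 3]], natoms)
print(balanced)                                        # [[0, 3], [1, 2]]
print([sum(natoms[i] for i in b) for b in balanced])   # [105, 100]
```

In this toy case, rank 0 would otherwise process 190 atoms and rank 1 only 15; after rebalancing the load is 105 versus 100, and each rank would keep the sublist at position `rank`. Balancing by atoms can leave ranks with different numbers of samples per step.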