core.common.data_parallel
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
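Classes
- OCPCollater
- StatefulDistributedSampler: A DistributedSampler with finer-grained state that uses both the training iteration and the epoch when shuffling data.
- BalancedBatchSampler: Wraps another sampler to yield a mini-batch of indices.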
Functions
- _balanced_partition: Greedily partition the given set by always inserting the largest element into the smallest partition.
- _ensure_supported
Module Contents
- class core.common.data_parallel.OCPCollater(otf_graph: bool = False)
- otf_graph
- __call__(data_list: list[torch_geometric.data.Data]) -> torch_geometric.data.Batch
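OCPCollater.__call__ turns a list of torch_geometric Data objects into a single Batch. The pure-Python sketch below illustrates the general idea behind graph batching: concatenate node data, shift each graph's edge indices by the number of nodes placed before it, and record a per-node graph id. This is what torch_geometric's Batch.from_data_list does for node-indexed attributes; it is not OCPCollater's actual implementation. Graph, BatchedGraph, and collate are hypothetical stand-ins, and the otf_graph flag is not modeled.

```python
from dataclasses import dataclass, field


@dataclass
class Graph:
    """Hypothetical stand-in for torch_geometric.data.Data."""
    pos: list[tuple[float, float, float]]
    edge_index: list[tuple[int, int]]  # (source, target) pairs using local node ids


@dataclass
class BatchedGraph:
    """Hypothetical stand-in for torch_geometric.data.Batch."""
    pos: list[tuple[float, float, float]] = field(default_factory=list)
    edge_index: list[tuple[int, int]] = field(default_factory=list)
    batch: list[int] = field(default_factory=list)  # graph id of every node


def collate(graphs: list[Graph]) -> BatchedGraph:
    """Merge several graphs into one disjoint graph (hypothetical helper)."""
    out = BatchedGraph()
    offset = 0  # number of nodes already placed in the batch
    for graph_id, g in enumerate(graphs):
        out.pos.extend(g.pos)
        # Shift local node ids so each edge still points at its own graph's nodes.
        out.edge_index.extend((s + offset, t + offset) for s, t in g.edge_index)
        out.batch.extend([graph_id] * len(g.pos))
        offset += len(g.pos)
    return out


g1 = Graph(pos=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)], edge_index=[(0, 1), (1, 2)])
g2 = Graph(pos=[(0.0, 1.0, 0.0), (0.0, 2.0, 0.0)], edge_index=[(0, 1)])
batched = collate([g1, g2])
print(batched.edge_index)  # [(0, 1), (1, 2), (3, 4)]
print(batched.batch)       # [0, 0, 0, 1, 1]
```

The batch vector lets per-graph quantities (for example, a total energy per structure) be recovered from node-level outputs after the model runs on the merged graph.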
- core.common.data_parallel._balanced_partition(sizes: numpy.typing.NDArray[numpy.int_], num_parts: int)
Greedily partition the given set by always inserting the largest element into the smallest partition.
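The greedy rule described above is the classic longest-processing-time (LPT) heuristic. A minimal pure-Python sketch, assuming sizes is a sequence of non-negative integers and that the result is one list of element indices per partition; balanced_partition is a hypothetical name, and the real _balanced_partition takes a NumPy array and may differ in tie-breaking and return type.

```python
import heapq
from collections.abc import Sequence


def balanced_partition(sizes: Sequence[int], num_parts: int) -> list[list[int]]:
    """Greedy LPT partition: the largest remaining element goes to the smallest part.

    Returns, for each part, the indices (into `sizes`) assigned to it.
    """
    # Visit elements from largest to smallest; sorted() is stable, so ties keep index order.
    order = sorted(range(len(sizes)), key=lambda k: -sizes[k])
    # Min-heap of (current total size, part id): the top is always the smallest part.
    heap = [(0, part) for part in range(num_parts)]
    parts: list[list[int]] = [[] for _ in range(num_parts)]
    for i in order:
        load, part = heapq.heappop(heap)
        parts[part].append(i)
        heapq.heappush(heap, (load + sizes[i], part))
    return parts


sizes = [5, 4, 3, 3, 2, 1]
parts = balanced_partition(sizes, num_parts=2)
print(parts)                                      # [[0, 3, 5], [1, 2, 4]]
print([sum(sizes[i] for i in p) for p in parts])  # [9, 9]
```

Keeping the partitions in a heap makes each assignment O(log num_parts), so the total cost is dominated by the initial O(n log n) sort.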
- class core.common.data_parallel.StatefulDistributedSampler(dataset, batch_size, **kwargs)
Bases: torch.utils.data.DistributedSampler
A DistributedSampler with finer-grained state that uses both the training iteration and the epoch when shuffling data. PyTorch's DistributedSampler uses only the epoch for shuffling and always starts sampling from the beginning of the data. When training on very large datasets, we often train for only one epoch, so when training resumes, the sampler should continue from the saved training iteration rather than from the start.
- start_iter = 0
- batch_size
- __iter__()
- set_epoch_and_start_iteration(epoch, start_iter)
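The resume behaviour described above can be sketched without PyTorch. This is a minimal sketch under stated assumptions, not the class's actual __iter__. It follows the general shape of torch.utils.data.DistributedSampler (a shuffle seeded by seed + epoch and shared by all ranks, padding to a multiple of num_replicas, then a strided split across ranks), but it uses Python's random module instead of torch.Generator, so the orders differ from PyTorch's. It also assumes that resuming means skipping the first start_iter * batch_size indices of the rank's sequence. resumable_rank_indices is a hypothetical name.

```python
import random


def resumable_rank_indices(
    n: int,
    num_replicas: int,
    rank: int,
    epoch: int,
    start_iter: int,
    batch_size: int,
    seed: int = 0,
) -> list[int]:
    """Indices one rank still has to visit this epoch (hypothetical helper)."""
    indices = list(range(n))
    # Same seed on every rank, so all ranks agree on one global shuffle per epoch.
    random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating the head so every rank gets the same number of samples
    # (assumes num_replicas <= n).
    num_samples = -(-n // num_replicas)  # ceil(n / num_replicas)
    total_size = num_samples * num_replicas
    indices += indices[: total_size - n]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    rank_indices = indices[rank:total_size:num_replicas]
    # Resume: skip the samples already consumed by start_iter full batches.
    return rank_indices[start_iter * batch_size :]


full = resumable_rank_indices(10, num_replicas=2, rank=0, epoch=3, start_iter=0, batch_size=2)
resumed = resumable_rank_indices(10, num_replicas=2, rank=0, epoch=3, start_iter=2, batch_size=2)
print(resumed == full[4:])  # True
```

Because the shuffle depends only on seed and epoch, a restarted job that calls the equivalent of set_epoch_and_start_iteration(epoch, start_iter) reproduces the interrupted epoch's order and simply drops the prefix it has already trained on.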
- core.common.data_parallel._ensure_supported(dataset: Any)
- class core.common.data_parallel.BalancedBatchSampler(dataset: torch.utils.data.Dataset, *, batch_size: int, num_replicas: int, rank: int, device: torch.device, seed: int, mode: bool | Literal['atoms'] = 'atoms', shuffle: bool = True, on_error: Literal['warn_and_balance', 'warn_and_no_balance', 'raise'] = 'raise', drop_last: bool = False)
Bases: torch.utils.data.BatchSampler
Wraps another sampler to yield a mini-batch of indices. (This description, the parameter list, and the example below are inherited from torch.utils.data.BatchSampler.)
- Parameters:
sampler (Sampler or Iterable) – Base sampler. Can be any iterable object.
batch_size (int) – Size of mini-batch.
drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size.
Example
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
- disabled = False
- on_error
- device
- _get_natoms(batch_idx: list[int])
- set_epoch_and_start_iteration(epoch: int, start_iteration: int) -> None
- set_epoch(epoch: int) -> None
- static _dist_enabled()
- __iter__()
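BalancedBatchSampler's parameters (num_replicas, rank, mode='atoms') and its _get_natoms helper suggest that it balances per-rank batches by atom count. The single-process sketch below simulates one plausible step under that assumption: pool the indices every rank drew for the current step, then apply the greedy largest-into-smallest rule described for _balanced_partition to reassign them so each rank receives a similar number of atoms. rebalance_step, balanced_partition, and the input layout are hypothetical. The real class runs across processes and also handles the on_error and disabled cases, which are not modeled here.

```python
import heapq
from collections.abc import Sequence


def balanced_partition(sizes: Sequence[int], num_parts: int) -> list[list[int]]:
    """Greedy partition: the largest remaining element goes to the smallest part."""
    heap = [(0, part) for part in range(num_parts)]  # (total size, part id)
    parts: list[list[int]] = [[] for _ in range(num_parts)]
    for i in sorted(range(len(sizes)), key=lambda k: -sizes[k]):
        load, part = heapq.heappop(heap)
        parts[part].append(i)
        heapq.heappush(heap, (load + sizes[i], part))
    return parts


def rebalance_step(per_rank_batches: list[list[int]], natoms: Sequence[int]) -> list[list[int]]:
    """Reassign one step's samples so every rank gets a similar atom count.

    per_rank_batches[r]: dataset indices rank r drew this step (hypothetical input).
    natoms[i]: number of atoms in dataset sample i.
    """
    pooled = [idx for batch in per_rank_batches for idx in batch]
    sizes = [natoms[idx] for idx in pooled]
    parts = balanced_partition(sizes, num_parts=len(per_rank_batches))
    # Map positions in `pooled` back to dataset indices.
    return [[pooled[j] for j in part] for part in parts]


natoms = [100, 90, 10, 5]  # atoms per dataset sample
before = [[0, 1], [2, 3]]  # rank 0: 190 atoms, rank 1: 15 atoms
after = rebalance_step(before, natoms)
print(after)                                               # [[0, 3], [1, 2]]
print([sum(natoms[i] for i in batch) for batch in after])  # [105, 100]
```

Balancing by atom count rather than by sample count is useful because the cost of a forward pass typically grows with the number of atoms, so equal-sized batches of very different structures can leave some ranks waiting on others.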