core.datasets.atomic_data

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Modified from the torch_geometric Data class.

Attributes

Classes

Functions

size_repr(→ str)

get_neighbors_pymatgen(atoms, cutoff, max_neigh)

Performs a nearest-neighbor search and returns the edge index, distances, and cell offsets.

reshape_features(c_index, n_index, n_distance, offsets)

Stacks the center and neighbor indices and reshapes the distances; takes NumPy arrays and returns torch tensors.

atomicdata_list_to_batch(→ AtomicData)

All data points must be single graphs and have the same set of keys.

tensor_or_int_to_tensor(x[, dtype])

Module Contents

core.datasets.atomic_data.IndexType
core.datasets.atomic_data._REQUIRED_KEYS = ['pos', 'atomic_numbers', 'cell', 'pbc', 'natoms', 'edge_index', 'cell_offsets', 'nedges',...
core.datasets.atomic_data._OPTIONAL_KEYS = ['energy', 'forces', 'stress']
core.datasets.atomic_data.size_repr(key: str, item: torch.Tensor, indent=0) → str
core.datasets.atomic_data.get_neighbors_pymatgen(atoms: ase.Atoms, cutoff, max_neigh)

Performs a nearest-neighbor search and returns the edge index, distances, and cell offsets.
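To make the returned quantities concrete, here is a brute-force, plain-Python sketch of a periodic neighbor search. It is not the pymatgen-backed implementation: `brute_force_neighbors` is a hypothetical helper, plain lists stand in for NumPy arrays, and only periodic images one cell away in each direction are checked, so it assumes positions wrapped into the cell and a cutoff smaller than the cell's perpendicular widths. The returned lists are ordered like the arguments of `reshape_features`; keeping each atom's `max_neigh` nearest neighbors mirrors the `max_neigh` parameter, but the tie-breaking order is an assumption.

```python
import itertools
import math


def brute_force_neighbors(positions, cell, cutoff, max_neigh=None):
    """Hypothetical stand-in for a periodic neighbor search.

    positions: list of [x, y, z] Cartesian coordinates (wrapped into the cell)
    cell:      3x3 list whose rows are the lattice vectors
    Returns parallel lists (c_index, n_index, n_distance, offsets).
    Assumes cutoff < every perpendicular cell width, so images more than
    one cell away can never be within the cutoff.
    """
    edges = []
    for i, pi in enumerate(positions):
        found = []
        for j, pj in enumerate(positions):
            for off in itertools.product((-1, 0, 1), repeat=3):
                if i == j and off == (0, 0, 0):
                    continue  # skip the atom itself in the home cell
                # Cartesian translation of this periodic image
                shift = [sum(off[k] * cell[k][ax] for k in range(3)) for ax in range(3)]
                image = [pj[ax] + shift[ax] for ax in range(3)]
                dist = math.dist(pi, image)
                if dist <= cutoff:
                    found.append((dist, j, off))
        found.sort()  # nearest first, so truncation keeps the closest neighbors
        if max_neigh is not None:
            found = found[:max_neigh]
        edges.extend((i, j, dist, off) for dist, j, off in found)
    c_index = [e[0] for e in edges]
    n_index = [e[1] for e in edges]
    n_distance = [e[2] for e in edges]
    offsets = [list(e[3]) for e in edges]
    return c_index, n_index, n_distance, offsets
```

For two atoms 1.5 apart along x in a cubic cell of side 3.0 with `cutoff=1.6`, each atom gets two neighbors: its partner in the home cell and the partner's periodic image one cell over, which is why the cell offsets must be returned alongside the indices.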

core.datasets.atomic_data.reshape_features(c_index: numpy.ndarray, n_index: numpy.ndarray, n_distance: numpy.ndarray, offsets: numpy.ndarray)

Stacks the center and neighbor indices and reshapes the distances. Takes NumPy arrays and returns torch tensors.
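The stacking step can be sketched in plain Python, with nested lists standing in for the torch tensors the real function returns. The helper name `reshape_features_sketch`, the row order of the stacked edge list (neighbors in row 0, centers in row 1), and the tolerance filter that drops near-zero-length edges are assumptions for illustration, not confirmed behavior of `reshape_features`.

```python
def reshape_features_sketch(c_index, n_index, n_distance, offsets, tol=1e-8):
    """Hypothetical plain-Python analogue of reshape_features."""
    # Stack the two index arrays into a 2 x E edge list.
    # Assumed row order: neighbors (sources) first, centers (targets) second.
    edge_index = [list(n_index), list(c_index)]
    distances = [float(x) for x in n_distance]   # flat length-E list
    cell_offsets = [list(o) for o in offsets]    # E x 3 integer offsets
    # Assumed cleanup: drop (near-)zero-length edges such as self-pairs.
    keep = [k for k, dist in enumerate(distances) if dist >= tol]
    edge_index = [[row[k] for k in keep] for row in edge_index]
    distances = [distances[k] for k in keep]
    cell_offsets = [cell_offsets[k] for k in keep]
    return edge_index, distances, cell_offsets
```

The 2 x E layout is what lets every edge be addressed by a single column index, so distances and cell offsets stay aligned with `edge_index` after any filtering.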

class core.datasets.atomic_data.AtomicData(pos, atomic_numbers, cell, pbc, natoms, edge_index, cell_offsets, nedges, charge, spin, fixed, tags, energy=None, forces=None, stress=None, batch=None, sid=None)
pos: torch.Tensor
atomic_numbers: torch.Tensor
cell: torch.Tensor
pbc: torch.Tensor
natoms: torch.Tensor
charge: torch.Tensor
spin: torch.Tensor
edge_index: torch.Tensor
cell_offsets: torch.Tensor
nedges: torch.Tensor
fixed: torch.Tensor
tags: torch.Tensor
energy: torch.Tensor
forces: torch.Tensor
stress: torch.Tensor
batch: torch.Tensor
sid: list[str]
__keys__
__slices__ = None
__cumsum__ = None
__cat_dims__ = None
__natoms_list__ = None
assign_batch_stats(slices, cumsum, cat_dims, natoms_list)
get_batch_stats()
validate()
classmethod from_ase(input_atoms: ase.Atoms, r_edges: bool = False, radius: float = 6.0, max_neigh: float | None = None, sid: str | None = None, molecule_cell_size: float | None = None, r_energy: bool = True, r_forces: bool = True, r_stress: bool = True, r_data_keys=None) → AtomicData
to_ase_single() → ase.Atoms
to_ase() → list[ase.Atoms]
classmethod from_dict(dictionary)

Creates a data object from a python dictionary.

to_dict()
values()
property num_nodes: int

Returns or sets the number of nodes in the graph.

property num_edges: int

Returns the number of edges in the graph.

property num_graphs: int

Returns the number of graphs in the batch.
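A minimal sketch of how these three counts relate to the stored attributes, assuming `pos` holds one row per atom across all graphs, `edge_index` is a 2 x E edge list, and `natoms` holds one atom count per graph. `CountsSketch` is hypothetical and uses lists in place of tensors; the real properties may compute the same values differently (for example, from `natoms` and `nedges`).

```python
from dataclasses import dataclass


@dataclass
class CountsSketch:
    """Hypothetical container using nested lists instead of tensors."""

    pos: list          # one [x, y, z] row per atom, all graphs concatenated
    edge_index: list   # two rows of equal length E
    natoms: list       # one atom count per graph

    @property
    def num_nodes(self) -> int:
        return len(self.pos)

    @property
    def num_edges(self) -> int:
        return len(self.edge_index[0])

    @property
    def num_graphs(self) -> int:
        return len(self.natoms)
```

For a single (unbatched) graph, `natoms` has one entry, so `num_graphs` is 1.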

__len__()
get(key, default)
__getitem__(idx)
__setitem__(key: str, value: torch.Tensor)

Sets the attribute key to value.

__setattr__(key: str, value: torch.Tensor)
__delitem__(key: str)

Deletes the attribute key.

keys()
__contains__(key)

Returns True if the attribute key is present in the data.

__iter__()

Iterates over all present attributes in the data, yielding their attribute names and content.

__call__(*keys)

Iterates over all attributes *keys in the data, yielding their attribute names and content. If *keys is not given, this method iterates over all present attributes.
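The containment and iteration protocol above can be sketched with a small stand-in class. `KeyedStore` is hypothetical; treating attributes set to `None` as absent and yielding keys in sorted order follow the torch_geometric `Data` convention this class was adapted from, and are assumptions here rather than confirmed details of `AtomicData`.

```python
class KeyedStore:
    """Hypothetical minimal container mimicking keys/__contains__/__iter__/__call__."""

    def __init__(self, **attrs):
        for key, value in attrs.items():
            setattr(self, key, value)

    def keys(self):
        # "Present" attributes: set on the instance and not None.
        return [k for k, v in vars(self).items() if v is not None]

    def __contains__(self, key):
        return key in self.keys()

    def __iter__(self):
        # All present attributes, as (name, content) pairs.
        for key in sorted(self.keys()):
            yield key, getattr(self, key)

    def __call__(self, *keys):
        # Only the requested keys (if given), silently skipping absent ones.
        selected = keys if keys else sorted(self.keys())
        for key in selected:
            if key in self:
                yield key, getattr(self, key)
```

Because `energy=None` counts as absent, a structure without labels still iterates cleanly, and asking for it via `__call__` simply yields nothing for that key.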

__cat_dim__(key, value) → int

Returns the dimension along which the value of attribute key is concatenated when creating batches.

Note

This method is for internal use only, and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.

__inc__(key, value) → int

Returns the incremental count to cumulatively increase the value of the next attribute of key when creating batches.

Note

This method is for internal use only, and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.
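The two hooks work together during batching. Below is a minimal plain-Python sketch, assuming the torch_geometric convention this class was adapted from: keys containing `index` (such as `edge_index`, stored as 2 x E) are concatenated along the last dimension and shifted by the number of nodes in the preceding graphs, while all other keys are concatenated along dimension 0 with no shift. `cat_dim`, `inc`, and `collate` are hypothetical helpers, and nested lists stand in for tensors.

```python
def cat_dim(key: str) -> int:
    # Assumed rule: index-like keys (2 x E) join along the last dim, others along dim 0.
    return -1 if "index" in key else 0


def inc(key: str, num_nodes: int) -> int:
    # Assumed rule: index-like keys are shifted by each graph's node count.
    return num_nodes if "index" in key else 0


def collate(key, values, natoms):
    """Batch per-graph values for one key (nested lists standing in for tensors)."""
    if cat_dim(key) == -1:
        offset = 0
        out = [[] for _ in values[0]]  # one output row per input row (2 for edge_index)
        for v, n in zip(values, natoms):
            for row_out, row in zip(out, v):
                row_out.extend(x + offset for x in row)
            offset += inc(key, n)  # later graphs' indices start after this graph's nodes
        return out
    # Dimension 0: append each graph's rows unchanged.
    return [row for v in values for row in v]
```

Without the shift, the second graph's edges `[[0, 2], [2, 0]]` would point at atoms of the first graph; with it they become `[[2, 4], [4, 2]]`, which address that graph's own atoms in the concatenated `pos`.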

__apply__(item, func)
apply(func)

Applies the function func to all tensor attributes.

contiguous()

Ensures a contiguous memory layout for all tensor attributes.

to(device, **kwargs)

Performs tensor dtype and/or device conversion for all tensor attributes.

cpu()

Copies all tensor attributes to CPU memory.

cuda(device=None, non_blocking=False)

Copies all tensor attributes to GPU memory.

clone()

Performs a deep-copy of the data object.

__repr__()
get_example(idx: int) → AtomicData

Reconstructs the AtomicData object at index idx from a batched AtomicData object.
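Reconstruction inverts the batching step: slice out the graph's range along the concatenation dimension, then subtract the index shift that was added during collation. The sketch below does this for a batched `edge_index`, assuming per-graph edge boundaries (`slices`) and cumulative node offsets (`cumsum`) of the kind passed to `assign_batch_stats`; the helper name `get_example_sketch` and the exact layout of these structures are assumptions.

```python
def get_example_sketch(batched, slices, cumsum, idx):
    """Hypothetical: recover graph `idx` from a batched 2 x E edge list.

    slices: edge boundaries, e.g. [0, 2, 4] -> graph g owns columns slices[g]:slices[g + 1]
    cumsum: node offsets added during batching, e.g. [0, 2, 5]
    """
    start, end = slices[idx], slices[idx + 1]
    shift = cumsum[idx]
    # Take this graph's columns and undo the node-index shift.
    return [[x - shift for x in row[start:end]] for row in batched]
```

Calling the helper for every graph index recovers all per-graph edge lists, which is the idea behind turning a whole batch back into a list of objects.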

index_select(idx: IndexType) → list[AtomicData]
batch_to_atomicdata_list() → list[AtomicData]

Reconstructs the list of AtomicData objects from the batch object. The batch object must have been created via atomicdata_list_to_batch() in order to be able to reconstruct the initial objects.

core.datasets.atomic_data.atomicdata_list_to_batch(data_list: list[AtomicData], exclude_keys: list | None = None) → AtomicData

Collates a list of AtomicData objects into a single batched AtomicData object. All data points must be single graphs and have the same set of keys. TODO: exclude keys?
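Besides concatenating each attribute, batching needs some bookkeeping. The sketch below covers the parts that follow from the requirement above and from the torch_geometric conventions this module adapts: it rejects inputs whose key sets differ, builds a per-atom `batch` vector mapping each atom to its graph, and records cumulative node boundaries. `batch_bookkeeping` is a hypothetical helper; plain lists and sets stand in for tensors and `AtomicData` objects.

```python
def batch_bookkeeping(key_sets, natoms):
    """Hypothetical helper: validate keys and build batch/slice bookkeeping.

    key_sets: one set of attribute names per data point
    natoms:   one atom count per data point (each a single graph)
    """
    first = key_sets[0]
    for i, keys in enumerate(key_sets[1:], start=1):
        if keys != first:
            raise ValueError(
                f"data point {i} has keys {sorted(keys)}, expected {sorted(first)}"
            )
    # Graph id for every atom, e.g. natoms [2, 3] -> [0, 0, 1, 1, 1].
    batch = [g for g, n in enumerate(natoms) for _ in range(n)]
    # Cumulative node boundaries, e.g. [0, 2, 5]; graph g owns atoms
    # node_slices[g]:node_slices[g + 1] in the concatenated arrays.
    node_slices = [0]
    for n in natoms:
        node_slices.append(node_slices[-1] + n)
    return batch, node_slices
```

Requiring identical key sets keeps every concatenated attribute aligned: if one structure carried `energy` and another did not, the batched `energy` would have fewer entries than there are graphs.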

core.datasets.atomic_data.tensor_or_int_to_tensor(x, dtype=torch.int)