core.datasets.atomic_data
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
Modified from the torch_geometric Data class.
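Attributes
- IndexType
- _REQUIRED_KEYS
- _OPTIONAL_KEYS
Classes
- AtomicData
Functions
- size_repr
- get_neighbors_pymatgen: Performs a nearest-neighbor search and returns the edge index, distances, and cell offsets.
- reshape_features: Stacks the center and neighbor indices and reshapes the distances. Takes NumPy arrays and returns torch tensors.
- atomicdata_list_to_batch: All data points must be single graphs and have the same set of keys.
- tensor_or_int_to_tensor
Module Contents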
- core.datasets.atomic_data.IndexType
- core.datasets.atomic_data._REQUIRED_KEYS = ['pos', 'atomic_numbers', 'cell', 'pbc', 'natoms', 'edge_index', 'cell_offsets', 'nedges',...
- core.datasets.atomic_data._OPTIONAL_KEYS = ['energy', 'forces', 'stress']
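As a rough illustration of how the two key lists could be used, the sketch below checks a collection of attribute names against them. It is hypothetical: classify_keys is not part of this module, and because _REQUIRED_KEYS is truncated in this listing, the tuple below contains only the eight required keys that are actually shown.

```python
# Only the required keys visible in this listing; the real list is truncated
# ("...") and contains more entries that are deliberately not guessed here.
REQUIRED_KEYS_SHOWN = (
    "pos", "atomic_numbers", "cell", "pbc",
    "natoms", "edge_index", "cell_offsets", "nedges",
)
OPTIONAL_KEYS = ("energy", "forces", "stress")


def classify_keys(keys, required=REQUIRED_KEYS_SHOWN, optional=OPTIONAL_KEYS):
    """Hypothetical helper: list missing required keys and present optional keys."""
    present = set(keys)
    missing = [k for k in required if k not in present]
    optional_present = [k for k in optional if k in present]
    return missing, optional_present
```

For example, a sample holding only pos, atomic_numbers, and energy would be reported as missing the other six shown required keys, with energy as its only optional target.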
- core.datasets.atomic_data.size_repr(key: str, item: torch.Tensor, indent=0) -> str
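size_repr has no docstring here. Since the module is modified from torch_geometric, a reasonable guess is that it renders an attribute as a compact key=[shape] string for use in __repr__. The function below is a hypothetical stand-in (size_repr_sketch) written under that assumption; it accepts anything exposing .shape, so it works with torch tensors or NumPy arrays without importing either.

```python
def size_repr_sketch(key, item, indent=0):
    """Hypothetical stand-in for size_repr: summarize an attribute by its size.

    Items with a .shape (torch tensors, NumPy arrays) print as key=[d0, d1, ...];
    lists and tuples print their length; anything else prints its value.
    """
    pad = " " * indent
    if hasattr(item, "shape"):
        dims = ", ".join(str(d) for d in item.shape)
        return f"{pad}{key}=[{dims}]"
    if isinstance(item, (list, tuple)):
        return f"{pad}{key}=[{len(item)}]"
    return f"{pad}{key}={item}"
```

With this sketch, a (5, 3) position tensor indented by two spaces would render as "  pos=[5, 3]".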
- core.datasets.atomic_data.get_neighbors_pymatgen(atoms: ase.Atoms, cutoff, max_neigh)
Performs a nearest-neighbor search and returns the edge index, distances, and cell offsets.
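The function name suggests the search itself is delegated to pymatgen. The sketch below swaps in a plain brute-force search instead, only to show what the outputs mean: for each center atom it scans the 27 surrounding periodic images, keeps pairs within cutoff, and truncates to the max_neigh closest. The function name and the int-valued max_neigh are assumptions, and the +/-1 image range is valid only when positions lie inside the cell and every perpendicular cell width exceeds cutoff.

```python
import itertools

import numpy as np


def brute_force_neighbors(positions, cell, cutoff, max_neigh):
    """Hypothetical brute-force periodic neighbor search (not pymatgen's algorithm).

    Returns center indices, neighbor indices, distances, and integer cell
    offsets (one row of 3 image shifts per edge).
    """
    positions = np.asarray(positions, dtype=float)
    cell = np.asarray(cell, dtype=float)
    rows = []  # (center, neighbor, distance, offset)
    for offset in itertools.product((-1, 0, 1), repeat=3):
        shift = np.asarray(offset) @ cell  # Cartesian translation of this image
        for i, ri in enumerate(positions):
            for j, rj in enumerate(positions):
                if i == j and offset == (0, 0, 0):
                    continue  # skip the atom itself in the home cell
                d = np.linalg.norm(rj + shift - ri)
                if d <= cutoff:
                    rows.append((i, j, d, offset))
    # Keep at most max_neigh closest neighbors per center atom.
    kept = []
    for i in range(len(positions)):
        mine = sorted((r for r in rows if r[0] == i), key=lambda r: r[2])
        kept.extend(mine[:max_neigh])
    c_index = np.array([r[0] for r in kept], dtype=int)
    n_index = np.array([r[1] for r in kept], dtype=int)
    n_distance = np.array([r[2] for r in kept], dtype=float)
    offsets = np.array([r[3] for r in kept], dtype=int).reshape(-1, 3)
    return c_index, n_index, n_distance, offsets
```

A single atom in a cubic cell of side 3.0 with cutoff 3.1 finds exactly its six face-adjacent periodic images, each at distance 3.0. Brute force is O(N^2) per image, which is why a library neighbor list is preferable in practice.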
- core.datasets.atomic_data.reshape_features(c_index: numpy.ndarray, n_index: numpy.ndarray, n_distance: numpy.ndarray, offsets: numpy.ndarray)
Stacks the center and neighbor indices and reshapes the distances. Takes NumPy arrays and returns torch tensors.
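A NumPy-only sketch of the reshaping step, returning arrays instead of torch tensors so it runs without PyTorch (torch.from_numpy can wrap the results). The row order of edge_index is an assumption: here row 0 holds neighbor (source) indices and row 1 holds center (target) indices, a common message-passing convention that the docstring does not confirm.

```python
import numpy as np


def reshape_features_sketch(c_index, n_index, n_distance, offsets):
    """Hypothetical NumPy stand-in for reshape_features (returns arrays, not tensors).

    Assumed layout: edge_index row 0 = neighbor (source), row 1 = center (target).
    """
    edge_index = np.vstack((n_index, c_index)).astype(np.int64)  # shape (2, E)
    edge_distances = np.asarray(n_distance, dtype=np.float32).reshape(-1)  # shape (E,)
    cell_offsets = np.asarray(offsets, dtype=np.int64).reshape(-1, 3)  # shape (E, 3)
    return edge_index, edge_distances, cell_offsets
```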
- class core.datasets.atomic_data.AtomicData(pos, atomic_numbers, cell, pbc, natoms, edge_index, cell_offsets, nedges, charge, spin, fixed, tags, energy=None, forces=None, stress=None, batch=None, sid=None)
- pos: torch.Tensor
- atomic_numbers: torch.Tensor
- cell: torch.Tensor
- pbc: torch.Tensor
- natoms: torch.Tensor
- charge: torch.Tensor
- spin: torch.Tensor
- edge_index: torch.Tensor
- cell_offsets: torch.Tensor
- nedges: torch.Tensor
- fixed: torch.Tensor
- tags: torch.Tensor
- energy: torch.Tensor
- forces: torch.Tensor
- stress: torch.Tensor
- batch: torch.Tensor
- sid: list[str]
- __keys__
- __slices__ = None
- __cumsum__ = None
- __cat_dims__ = None
- __natoms_list__ = None
- assign_batch_stats(slices, cumsum, cat_dims, natoms_list)
- get_batch_stats()
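__slices__ and __cumsum__ are undocumented here. Assuming they follow torch_geometric Batch semantics, slices record where each graph's rows (or columns) start and end in a concatenated attribute, while cumsum records how far each graph's node indices were shifted. The hypothetical batch_stats_sketch below computes both for pos and edge_index from per-graph atom and edge counts.

```python
from itertools import accumulate


def batch_stats_sketch(natoms_list, nedges_list):
    """Hypothetical: slice boundaries and index shifts for a batch of graphs."""
    slices = {
        "pos": [0, *accumulate(natoms_list)],  # node rows owned by each graph
        "edge_index": [0, *accumulate(nedges_list)],  # edge columns owned by each graph
    }
    cumsum = {
        # Node-index shift applied to each graph's edge_index when batching.
        "edge_index": [0, *accumulate(natoms_list)],
    }
    return slices, cumsum
```

For two graphs with 3 and 2 atoms and 4 and 1 edges, graph 1's atoms occupy rows 3 to 5 of pos and its edge indices were shifted by 3.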
- validate()
- classmethod from_ase(input_atoms: ase.Atoms, r_edges: bool = False, radius: float = 6.0, max_neigh: float | None = None, sid: str | None = None, molecule_cell_size: float | None = None, r_energy: bool = True, r_forces: bool = True, r_stress: bool = True, r_data_keys=None) -> AtomicData
- to_ase_single() -> ase.Atoms
- to_ase() -> list[ase.Atoms]
- classmethod from_dict(dictionary)
Creates a data object from a Python dictionary.
- to_dict()
- values()
- property num_nodes: int
Returns or sets the number of nodes in the graph.
- property num_edges: int
Returns the number of edges in the graph.
- property num_graphs: int
Returns the number of graphs in the batch.
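Assuming natoms stores one atom count per graph and edge_index has two rows (sources and targets), the three counting properties reduce to a sum and two lengths. graph_counts is a hypothetical helper on plain lists, not the class's implementation.

```python
def graph_counts(natoms, edge_index):
    """Hypothetical: node, edge, and graph counts for a (possibly batched) sample.

    natoms: per-graph atom counts; edge_index: [sources, targets] of equal length.
    """
    num_nodes = sum(natoms)  # atoms across all graphs
    num_edges = len(edge_index[0])  # one column per edge
    num_graphs = len(natoms)  # one count per graph
    return num_nodes, num_edges, num_graphs
```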
- __len__()
- get(key, default)
- __getitem__(idx)
- __setitem__(key: str, value: torch.Tensor)
Sets the attribute key to value.
- __setattr__(key: str, value: torch.Tensor)
- __delitem__(key: str)
Deletes the attribute key.
- keys()
- __contains__(key)
Returns True if the attribute key is present in the data.
- __iter__()
Iterates over all present attributes in the data, yielding their attribute names and content.
- __call__(*keys)
Iterates over all attributes *keys in the data, yielding their attribute names and content. If *keys is not given, this method iterates over all present attributes.
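The mapping-style methods above can be sketched with a small hypothetical class, AttrStore. Following torch_geometric's Data convention (an assumption for AtomicData), an attribute whose value is None counts as absent, so an unset field such as energy=None does not show up in keys(), in, or iteration.

```python
class AttrStore:
    """Hypothetical sketch of the dict-like protocol; not AtomicData itself.

    Attributes set to None are treated as absent (torch_geometric-style assumption).
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getitem__(self, key):
        return getattr(self, key, None)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __delitem__(self, key):
        if key in self.__dict__:
            delattr(self, key)

    def keys(self):
        # Only attributes with a non-None value count as present.
        return [k for k, v in self.__dict__.items() if v is not None]

    def __contains__(self, key):
        return key in self.keys()

    def __iter__(self):
        # Yield (name, value) pairs for all present attributes, sorted by name.
        for key in sorted(self.keys()):
            yield key, self[key]

    def __call__(self, *keys):
        # Without arguments, behave like __iter__; otherwise yield only the
        # requested keys that are present, in the order given.
        for key in sorted(self.keys()) if not keys else keys:
            if key in self:
                yield key, self[key]
```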
- __cat_dim__(key, value) -> int
Returns the dimension along which value of attribute key will be concatenated when creating batches.
Note: This method is for internal use only and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.
- __inc__(key, value) -> int
Returns the incremental count by which the value of the next attribute key is cumulatively increased when creating batches.
Note: This method is for internal use only and should only be overridden if the batch concatenation process is corrupted for a specific data attribute.
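To show how the two hooks cooperate during batching, the sketch below hard-codes plausible rules rather than the class's real ones: edge_index (shape (2, E)) is joined along its last axis and shifted by the number of atoms already batched, while every other attribute is stacked along axis 0 with no shift. cat_dim_sketch, inc_sketch, and collate_key are hypothetical names, and NumPy arrays stand in for tensors.

```python
import numpy as np


def cat_dim_sketch(key):
    # Hypothetical rule: edge_index is (2, E), so graphs are joined along the last axis.
    return -1 if key == "edge_index" else 0


def inc_sketch(key, natoms):
    # Hypothetical rule: node ids in edge_index shift by the atoms of each earlier graph.
    return natoms if key == "edge_index" else 0


def collate_key(key, values, natoms_list):
    """Concatenate one attribute across graphs using the two rules above."""
    parts, offset = [], 0
    for value, n in zip(values, natoms_list):
        parts.append(np.asarray(value) + offset)
        offset += inc_sketch(key, n)  # stays 0 for keys without an increment
    return np.concatenate(parts, axis=cat_dim_sketch(key))
```

Batching two 2-atom graphs with edges [[0, 1], [1, 0]] and [[0], [1]] yields [[0, 1, 2], [1, 0, 3]]: the second graph's node ids now point at rows 2 and 3 of the batched pos.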
- __apply__(item, func)
- apply(func)
Applies the function func to all tensor attributes.
- contiguous()
Ensures a contiguous memory layout for all tensor attributes.
- to(device, **kwargs)
Performs tensor dtype and/or device conversion for all tensor attributes.
- cpu()
Copies all tensor attributes to CPU memory.
- cuda(device=None, non_blocking=False)
Copies all tensor attributes to GPU memory.
- clone()
Performs a deep copy of the data object.
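In torch_geometric's Data, methods such as contiguous, to, cpu, and cuda are built on apply with a suitable func; assuming AtomicData follows the same pattern, a minimal hypothetical container shows the idea. NumPy arrays stand in for torch tensors: only array-valued attributes are transformed, non-tensor fields such as a list of sid strings pass through untouched, and clone is a deep copy.

```python
import copy

import numpy as np


class TensorBag:
    """Hypothetical container; NumPy arrays stand in for torch tensors."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)

    def apply(self, func):
        # Transform only array-valued attributes; leave lists, strings, etc. alone.
        for key, value in list(self.__dict__.items()):
            if isinstance(value, np.ndarray):
                self.__dict__[key] = func(value)
        return self

    def contiguous(self):
        # Analogue of a contiguous() method expressed through apply.
        return self.apply(np.ascontiguousarray)

    def clone(self):
        # Deep copy, so later in-place or apply() changes do not leak back.
        return copy.deepcopy(self)
```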
- __repr__()
- get_example(idx: int) -> AtomicData
Reconstructs the AtomicData object at index idx from a batched AtomicData object.
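Reconstructing one graph from a batch amounts to slicing each attribute with the stored boundaries and undoing the index shift applied to edge_index. The sketch below assumes torch_geometric-style slices and cumsum dictionaries and uses plain nested lists; get_example_sketch is a hypothetical name.

```python
def get_example_sketch(batched, slices, cumsum, idx):
    """Hypothetical: cut graph idx out of batched attributes and undo index shifts."""
    example = {}
    for key, value in batched.items():
        start, end = slices[key][idx], slices[key][idx + 1]
        if key == "edge_index":
            shift = cumsum[key][idx]
            # edge_index is two rows; take this graph's columns, then remove the shift.
            example[key] = [[v - shift for v in row[start:end]] for row in value]
        else:
            example[key] = value[start:end]  # node-level rows for this graph
    return example
```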
- index_select(idx: IndexType) -> list[AtomicData]
- batch_to_atomicdata_list() -> list[AtomicData]
Reconstructs the list of AtomicData objects from the batch object. The batch object must have been created via atomicdata_list_to_batch() in order to reconstruct the initial objects.
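IndexType's definition is not shown in this listing. Assuming it covers the usual cases (a slice, a single int, a boolean mask, or a sequence of ints, as torch_geometric's IndexType does), index_select and batch_to_atomicdata_list can both be read as "normalize the index, then reconstruct each selected graph". normalize_index and unbatch are hypothetical names; unbatch takes the per-graph reconstruction function as an argument.

```python
def normalize_index(idx, num_graphs):
    """Hypothetical: turn a slice, int, bool mask, or int sequence into graph ids."""
    if isinstance(idx, slice):
        return list(range(num_graphs))[idx]
    if isinstance(idx, int):
        return [idx]
    idx = list(idx)
    if idx and all(isinstance(i, bool) for i in idx):
        # Boolean mask: keep the positions marked True.
        return [i for i, keep in enumerate(idx) if keep]
    return [int(i) for i in idx]


def unbatch(get_example, num_graphs, idx=slice(None)):
    """Apply a per-graph reconstruction function to every selected graph.

    The default slice(None) selects all graphs, like batch_to_atomicdata_list.
    """
    return [get_example(i) for i in normalize_index(idx, num_graphs)]
```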
- core.datasets.atomic_data.atomicdata_list_to_batch(data_list: list[AtomicData], exclude_keys: list | None = None) -> AtomicData
All data points must be single graphs and have the same set of keys. TODO: exclude keys?
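Two pieces of the batching contract can be sketched without tensors: checking the documented precondition that every sample has the same set of keys, and building a per-atom batch vector that records which graph each atom came from. That the batch attribute holds such a graph-membership vector is an assumption based on the torch_geometric convention; both helpers are hypothetical.

```python
def mismatched_samples(key_sets):
    """Indices of samples whose key set differs from the first sample's."""
    if not key_sets:
        return []
    first = set(key_sets[0])
    return [i for i, keys in enumerate(key_sets) if set(keys) != first]


def build_batch_vector(natoms_list):
    """Assign every atom the index of the graph it belongs to."""
    batch = []
    for graph_id, n in enumerate(natoms_list):
        batch.extend([graph_id] * n)
    return batch
```

A caller would raise an error whenever mismatched_samples returns a non-empty list, before concatenating anything.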
- core.datasets.atomic_data.tensor_or_int_to_tensor(x, dtype=torch.int)
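The helper's behavior is not documented. One plausible reading of its name: pass tensors through (cast to dtype) and wrap a bare Python int into a tensor. The NumPy analogue below, to_array, is hypothetical and makes one arbitrary choice, wrapping an int as a 1-element array rather than a 0-d one; the real function may differ. torch.int is an alias for torch.int32, so np.int32 is used as the default here.

```python
import numpy as np


def to_array(x, dtype=np.int32):
    """Hypothetical NumPy analogue of tensor_or_int_to_tensor.

    A Python int becomes a 1-element array (an assumed choice); anything
    array-like is converted and cast to dtype.
    """
    if isinstance(x, int):
        return np.array([x], dtype=dtype)
    return np.asarray(x, dtype=dtype)
```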