core.models.escaip.modules.base_block
Classes
BaseGraphNeuralNetworkLayer | Base class for Graph Neural Network layers.
Module Contents
- class core.models.escaip.modules.base_block.BaseGraphNeuralNetworkLayer(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, molecular_graph_cfg: fairchem.core.models.escaip.configs.MolecularGraphConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)
Bases: torch.nn.Module
Base class for Graph Neural Network layers. Used in InputLayer and EfficientGraphAttention.
- source_atomic_embedding
- target_atomic_embedding
- source_direction_embedding
- target_direction_embedding
- edge_distance_embedding
- get_edge_linear(gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)
- get_node_linear(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)
- get_edge_features(x: fairchem.core.models.escaip.custom_types.GraphAttentionData) -> torch.Tensor
- get_node_features(node_features: torch.Tensor, neighbor_list: torch.Tensor) -> torch.Tensor
- aggregate(edge_features, neighbor_mask)
- abstract forward()
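The page lists five embedding attributes and get_edge_features but documents no tensor shapes. The following is a minimal sketch in plain PyTorch of one way such embeddings could be combined into per-edge features; it is not the fairchem implementation. The class name EdgeFeatureSketch, the padded (N nodes, K neighbor slots) layout, the per-node direction input, and every size are hypothetical, and the method takes raw tensors instead of a GraphAttentionData object.

```python
import torch
import torch.nn as nn


class EdgeFeatureSketch(nn.Module):
    """Hypothetical stand-in for the embedding attributes and get_edge_features.

    Assumes a padded neighbor layout: N nodes, each with K neighbor slots.
    All sizes are illustrative defaults, not values from the library configs.
    """

    def __init__(self, max_num_elements=100, atom_dim=8, dir_in=6, dir_dim=4,
                 dist_in=10, dist_dim=16):
        super().__init__()
        # Atomic-number lookup tables for the source (center) and target (neighbor) atom.
        self.source_atomic_embedding = nn.Embedding(max_num_elements, atom_dim)
        self.target_atomic_embedding = nn.Embedding(max_num_elements, atom_dim)
        # Linear maps from an assumed per-node direction expansion.
        self.source_direction_embedding = nn.Linear(dir_in, dir_dim)
        self.target_direction_embedding = nn.Linear(dir_in, dir_dim)
        # Linear map from an assumed per-edge distance expansion (e.g. a radial basis).
        self.edge_distance_embedding = nn.Linear(dist_in, dist_dim)

    def get_edge_features(self, atomic_numbers, neighbor_list,
                          direction_expansion, distance_expansion):
        # atomic_numbers: (N,) long; neighbor_list: (N, K) long node indices
        # direction_expansion: (N, dir_in); distance_expansion: (N, K, dist_in)
        k = neighbor_list.shape[1]
        # Source features belong to the center node, so broadcast them over the K slots.
        src_atom = self.source_atomic_embedding(atomic_numbers).unsqueeze(1).expand(-1, k, -1)
        src_dir = self.source_direction_embedding(direction_expansion).unsqueeze(1).expand(-1, k, -1)
        # Target features are gathered from each neighbor through the index tensor.
        tgt_atom = self.target_atomic_embedding(atomic_numbers[neighbor_list])
        tgt_dir = self.target_direction_embedding(direction_expansion[neighbor_list])
        dist = self.edge_distance_embedding(distance_expansion)
        # Concatenate on the feature axis: (N, K, dist_dim + 2*atom_dim + 2*dir_dim).
        return torch.cat([dist, src_atom, tgt_atom, src_dir, tgt_dir], dim=-1)


m = EdgeFeatureSketch()
z = torch.tensor([1, 6, 8])                       # atomic numbers of three atoms
nbrs = torch.tensor([[1, 2], [0, 2], [0, 1]])     # K = 2 neighbors per atom
out = m.get_edge_features(z, nbrs, torch.randn(3, 6), torch.randn(3, 2, 10))
print(out.shape)  # torch.Size([3, 2, 40])
```

Concatenation keeps each component in a fixed slice of the feature axis, which is why a single edge projection (compare get_edge_linear) can follow it.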
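get_node_features and aggregate take a neighbor_list and a neighbor_mask, which suggests a padded neighbor layout. Assuming that layout (N nodes, K neighbor slots, padding marked False in the mask), the sketch below pairs each neighbor's features with the center node's and reduces over real neighbors with a masked mean. These are hypothetical free functions, not the library methods, and the mean is an assumption: the documented signature does not say whether aggregate sums or averages.

```python
import torch


def get_node_features(node_features: torch.Tensor, neighbor_list: torch.Tensor) -> torch.Tensor:
    # node_features: (N, H); neighbor_list: (N, K) indices of each node's neighbors.
    sender = node_features[neighbor_list]                     # (N, K, H): neighbor features
    receiver = node_features.unsqueeze(1).expand_as(sender)   # (N, K, H): center node, repeated
    return torch.cat([sender, receiver], dim=-1)              # (N, K, 2H)


def aggregate(edge_features: torch.Tensor, neighbor_mask: torch.Tensor) -> torch.Tensor:
    # edge_features: (N, K, H); neighbor_mask: (N, K) bool, True marks a real neighbor.
    mask = neighbor_mask.unsqueeze(-1).to(edge_features.dtype)  # (N, K, 1)
    count = mask.sum(dim=1).clamp(min=1.0)                      # (N, 1); avoids 0/0 for isolated nodes
    return (edge_features * mask).sum(dim=1) / count            # masked mean over slots: (N, H)


x = torch.arange(6.0).reshape(3, 2)            # three nodes, H = 2
nbrs = torch.tensor([[1, 2], [0, 0], [0, 1]])  # padded slots reuse index 0
mask = torch.tensor([[True, True], [True, False], [False, False]])
print(get_node_features(x, nbrs).shape)        # torch.Size([3, 2, 4])
print(aggregate(x[nbrs], mask))                # rows: [3, 4], [0, 1], [0, 0]
```

Padded slots may point at any valid index because the mask zeroes them out before the sum; a node with no real neighbors reduces to zeros rather than NaN.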
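Because forward() is abstract, each concrete layer (the page names InputLayer and EfficientGraphAttention as users of this base) is expected to supply its own. A hypothetical sketch of that pattern follows: the base class exposes factory helpers analogous to get_edge_linear and get_node_linear, here taking plain integer sizes and a dropout rate instead of the config objects, and a toy subclass builds its projection once in __init__ so it is registered as a submodule. The class names, the SiLU activation, and the layer composition are illustrative choices only.

```python
import torch
import torch.nn as nn


class BaseLayerSketch(nn.Module):
    """Hypothetical base class: shared factory helpers, forward left to subclasses."""

    def __init__(self, edge_in: int, hidden: int, dropout: float = 0.0):
        super().__init__()
        self.edge_in, self.hidden, self.dropout = edge_in, hidden, dropout

    def get_edge_linear(self) -> nn.Module:
        # Projects concatenated edge features to the hidden size (activation is an assumption).
        return nn.Sequential(nn.Linear(self.edge_in, self.hidden), nn.SiLU(), nn.Dropout(self.dropout))

    def get_node_linear(self) -> nn.Module:
        # Projects [sender, receiver] pairs (2 * hidden) back to the hidden size.
        return nn.Sequential(nn.Linear(2 * self.hidden, self.hidden), nn.SiLU(), nn.Dropout(self.dropout))

    def forward(self, *args, **kwargs):
        # Abstract: each concrete layer defines its own computation.
        raise NotImplementedError("subclasses must implement forward()")


class ToyLayer(BaseLayerSketch):
    """Hypothetical subclass that embeds edges and averages them per node."""

    def __init__(self, edge_in: int, hidden: int, dropout: float = 0.0):
        super().__init__(edge_in, hidden, dropout)
        self.edge_linear = self.get_edge_linear()  # built once, registered as a submodule

    def forward(self, edge_features: torch.Tensor, neighbor_mask: torch.Tensor) -> torch.Tensor:
        h = self.edge_linear(edge_features)                 # (N, K, hidden)
        mask = neighbor_mask.unsqueeze(-1).to(h.dtype)      # (N, K, 1)
        return (h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)


layer = ToyLayer(edge_in=5, hidden=7)
out = layer(torch.randn(3, 4, 5), torch.ones(3, 4, dtype=torch.bool))
print(out.shape)  # torch.Size([3, 7])
```

Returning fresh modules from the helpers, rather than storing them on the base, lets each subclass decide which projections it actually owns, so unused ones add no parameters.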