core.models.escaip.modules.graph_attention_block
Classes
- EfficientGraphAttentionBlock: Efficient Graph Attention Block module.
- EfficientGraphAttention: Efficient Graph Attention module.
- FeedForwardNetwork: Feed Forward Network module.
Module Contents
- class core.models.escaip.modules.graph_attention_block.EfficientGraphAttentionBlock(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, molecular_graph_cfg: fairchem.core.models.escaip.configs.MolecularGraphConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs, is_last: bool = False)
Bases: torch.nn.Module
Efficient Graph Attention Block module. Reference: Swin Transformer.
- backbone_dtype
- graph_attention
- feedforward
- norm_attn_node
- norm_attn_edge
- norm_ffn_node
- stochastic_depth_attn
- stochastic_depth_ffn
- forward(data: fairchem.core.models.escaip.custom_types.GraphAttentionData, node_features: torch.Tensor, edge_features: torch.Tensor)
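The class description cites the Swin Transformer, and the attribute names (norm_attn_node, norm_ffn_node, stochastic_depth_attn, stochastic_depth_ffn) suggest the usual pre-norm residual layout with stochastic depth. The plain-Python sketch below illustrates that layout for a single node feature vector. It assumes LayerNorm-style normalization and drop-path style stochastic depth; `block_forward` and `drop_path` are hypothetical names, the edge stream handled by norm_attn_edge is omitted, and the actual fairchem `forward` may order or combine these steps differently.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize one feature vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / (var + eps) ** 0.5 for v in x]


def drop_path(branch: Vector, drop_prob: float, training: bool, rng: random.Random) -> Vector:
    """Stochastic depth (hypothetical helper): zero the whole branch with probability drop_prob.

    Kept branches are rescaled by 1 / (1 - drop_prob) so the expected output
    matches inference, where the branch is always used unscaled.
    """
    if not training or drop_prob == 0.0:
        return branch
    if rng.random() < drop_prob:  # random() is in [0, 1), so drop_prob=1.0 always drops
        return [0.0] * len(branch)
    keep = 1.0 - drop_prob
    return [v / keep for v in branch]


def block_forward(
    x: Vector,
    attention: Callable[[Vector], Vector],
    feedforward: Callable[[Vector], Vector],
    drop_prob: float = 0.0,
    training: bool = False,
    rng: Optional[random.Random] = None,
) -> Vector:
    """Assumed pre-norm residual block: attention sub-layer, then feed-forward sub-layer."""
    rng = rng if rng is not None else random.Random(0)
    # x = x + DropPath(Attention(LayerNorm(x)))
    attn_out = drop_path(attention(layer_norm(x)), drop_prob, training, rng)
    x = [a + b for a, b in zip(x, attn_out)]
    # x = x + DropPath(FFN(LayerNorm(x)))
    ffn_out = drop_path(feedforward(layer_norm(x)), drop_prob, training, rng)
    return [a + b for a, b in zip(x, ffn_out)]
```

For example, with zero-valued branches (`zero = lambda v: [0.0] * len(v)`), `block_forward([1.0, 2.0, 3.0], zero, zero)` returns `[1.0, 2.0, 3.0]`. Each residual sub-layer can fall back to the identity path, which is what lets stochastic depth drop a branch during training without breaking the signal.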
- class core.models.escaip.modules.graph_attention_block.EfficientGraphAttention(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, molecular_graph_cfg: fairchem.core.models.escaip.configs.MolecularGraphConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs)
Bases: fairchem.core.models.escaip.modules.base_block.BaseGraphNeuralNetworkLayer
Efficient Graph Attention module.
- backbone_dtype
- repeating_dimensions_list
- rep_dim_len
- use_frequency_embedding
- edge_attr_linear
- edge_attr_norm
- message_norm
- use_message_gate
- attn_in_proj_q
- attn_in_proj_k
- attn_in_proj_v
- attn_out_proj
- attn_num_heads
- attn_dropout
- use_angle_embedding
- use_graph_attention
- forward(data: fairchem.core.models.escaip.custom_types.GraphAttentionData, node_features: torch.Tensor, edge_features: torch.Tensor)
- multi_head_self_attention(input, attn_mask, frequency_vectors=None)
- get_attn_bias(angle_embedding, edge_distance_expansion)
- graph_attention_aggregate(edge_output, neighbor_mask)
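The attribute and method names above (attn_in_proj_q/k/v, attn_num_heads, an attn_mask argument, get_attn_bias, and graph_attention_aggregate with a neighbor_mask) suggest multi-head attention over each node's padded neighbor list, followed by a masked reduction back to one vector per node. The plain-Python sketch below illustrates that pattern under stated assumptions: the tokens are one node's neighbors, the bias is a per-head additive term of the kind get_attn_bias might return, padding slots are excluded with a boolean mask, and aggregation is a masked sum. `multi_head_attention` and `masked_sum_aggregate` are hypothetical helpers, not fairchem functions; the real module may aggregate differently, and its `frequency_vectors` input is not modeled here.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax_masked(scores: List[float], mask: List[bool]) -> List[float]:
    """Softmax restricted to positions where mask is True; masked positions get weight 0."""
    valid = [s for s, keep in zip(scores, mask) if keep]
    if not valid:  # a node with no real neighbors attends to nothing
        return [0.0] * len(scores)
    peak = max(valid)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) if keep else 0.0 for s, keep in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def multi_head_attention(
    q: Matrix,
    k: Matrix,
    v: Matrix,
    mask: List[bool],
    num_heads: int,
    bias: Optional[List[Matrix]] = None,
) -> Matrix:
    """Scaled dot-product attention over one node's neighbor tokens (hypothetical helper).

    q, k, v: [num_tokens][dim] projected features; dim must be divisible by num_heads.
    mask: True for real neighbors, False for padding slots.
    bias: optional additive bias indexed as bias[head][query][key].
    """
    n, dim = len(q), len(q[0])
    if dim % num_heads != 0:
        raise ValueError("dim must be divisible by num_heads")
    head_dim = dim // num_heads
    scale = 1.0 / math.sqrt(head_dim)
    out = [[0.0] * dim for _ in range(n)]
    for h in range(num_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim  # this head's feature slice
        for i in range(n):
            scores = []
            for j in range(n):
                s = scale * sum(q[i][d] * k[j][d] for d in range(lo, hi))
                if bias is not None:
                    s += bias[h][i][j]
                scores.append(s)
            weights = softmax_masked(scores, mask)
            for d in range(lo, hi):
                out[i][d] = sum(w * v[j][d] for j, w in enumerate(weights))
    return out


def masked_sum_aggregate(edge_output: Matrix, neighbor_mask: List[bool]) -> List[float]:
    """Collapse per-neighbor features into one node vector, ignoring padding slots."""
    dim = len(edge_output[0])
    return [
        sum(row[d] for row, keep in zip(edge_output, neighbor_mask) if keep)
        for d in range(dim)
    ]
```

With all-zero queries and keys, every real neighbor gets equal weight, so values `[[1.0, 0.0], [3.0, 0.0], [100.0, 0.0]]` under mask `[True, True, False]` average to `[2.0, 0.0]`; the padded third slot contributes nothing. A large negative bias on a key pushes its weight toward zero, which is one way geometric terms such as angles or distances can steer attention.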
- class core.models.escaip.modules.graph_attention_block.FeedForwardNetwork(global_cfg: fairchem.core.models.escaip.configs.GlobalConfigs, gnn_cfg: fairchem.core.models.escaip.configs.GraphNeuralNetworksConfigs, reg_cfg: fairchem.core.models.escaip.configs.RegularizationConfigs, is_last: bool = False)
Bases: torch.nn.Module
Feed Forward Network module.
- backbone_dtype
- mlp_node
- forward(node_features: torch.Tensor, edge_features: torch.Tensor)
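The only learnable attribute listed for FeedForwardNetwork is mlp_node, which suggests a position-wise MLP applied to node features, as in standard Transformer feed-forward layers. The plain-Python sketch below shows that pattern: one two-layer MLP applied independently to each node's vector. The GELU activation, the hidden size, and the helper names `linear`, `gelu`, and `feed_forward` are assumptions for illustration; how this class handles `edge_features` and the `is_last` flag is not documented here, so the sketch transforms node features only.

```python
import math
from typing import List

Matrix = List[List[float]]


def linear(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """Affine map y = W x + b, with weight stored as [out_features][in_features]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def gelu(x: float) -> float:
    """Exact GELU activation: x * Phi(x), written with the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def feed_forward(
    node_features: Matrix, w1: Matrix, b1: List[float], w2: Matrix, b2: List[float]
) -> Matrix:
    """Apply one two-layer MLP independently to every node's feature vector.

    w1/b1 map the feature size to an (assumed) hidden size; w2/b2 project back.
    """
    out = []
    for x in node_features:
        hidden = [gelu(h) for h in linear(x, w1, b1)]
        out.append(linear(hidden, w2, b2))
    return out
```

In this sketch, nodes never exchange information inside the feed-forward step; any mixing across neighbors would have to come from the attention sub-layer.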