core.models.equiformer_v2.prediction_heads
Submodules
Classes
Rank2SymmetricTensorHead | A rank 2 symmetric tensor prediction head.
Package Contents
- class core.models.equiformer_v2.prediction_heads.Rank2SymmetricTensorHead(backbone: fairchem.core.models.base.BackboneInterface, output_name: str, decompose: bool = False, edge_level_mlp: bool = False, num_mlp_layers: int = 2, use_source_target_embedding: bool = False, extensive: bool = False, avg_num_nodes: int = 1.0, default_norm_type: str = 'layer_norm_sh')
Bases: torch.nn.Module, fairchem.core.models.base.HeadInterface
A rank 2 symmetric tensor prediction head.
- output_name
name of the output prediction property (e.g., stress)
- sphharm_norm
layer normalization for spherical harmonic edge weights
- xedge_layer_norm
embedding layer norm
- block
rank 2 equivariant symmetric tensor block
- decompose
- use_source_target_embedding
- avg_num_nodes
- forward(data: dict[str, torch.Tensor] | torch.Tensor, emb: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]
- Parameters:
data – data batch
emb – dictionary with embedding object and graph data
Returns: dict of {output property name: predicted value}
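A symmetric rank 2 tensor such as stress can be predicted whole or as separate components. As a hedged illustration, assuming the `decompose` flag refers to splitting the tensor into its isotropic part (L=0, the trace) and its anisotropic part (L=2, the symmetric traceless remainder), the pure-Python sketch below shows that split on a 3x3 matrix. The helpers `decompose_symmetric` and `recompose` are hypothetical and are not part of fairchem; the real head operates on torch tensors produced by the backbone.

```python
"""Sketch: split a symmetric 3x3 tensor into isotropic (L=0) and
anisotropic (L=2, symmetric traceless) parts.

Assumption: this mirrors the components targeted by the `decompose`
option of Rank2SymmetricTensorHead; it is NOT fairchem code.
"""

Matrix = list[list[float]]


def is_symmetric(t: Matrix, tol: float = 1e-12) -> bool:
    """Return True if t equals its transpose within tol."""
    return all(abs(t[i][j] - t[j][i]) <= tol for i in range(3) for j in range(3))


def decompose_symmetric(t: Matrix) -> tuple[float, Matrix]:
    """Return (isotropic scalar, anisotropic traceless tensor).

    t = iso * I + aniso, where iso = trace(t) / 3 carries 1 degree of
    freedom and the symmetric traceless aniso carries the other 5.
    """
    if not is_symmetric(t):
        raise ValueError("expected a symmetric 3x3 tensor")
    iso = (t[0][0] + t[1][1] + t[2][2]) / 3.0
    # Subtract the isotropic part from the diagonal only.
    aniso = [[t[i][j] - (iso if i == j else 0.0) for j in range(3)] for i in range(3)]
    return iso, aniso


def recompose(iso: float, aniso: Matrix) -> Matrix:
    """Inverse of decompose_symmetric: rebuild iso * I + aniso."""
    return [[aniso[i][j] + (iso if i == j else 0.0) for j in range(3)] for i in range(3)]


if __name__ == "__main__":
    # A made-up symmetric "stress-like" tensor in arbitrary units.
    stress = [[3.0, 1.0, 0.0],
              [1.0, 6.0, 2.0],
              [0.0, 2.0, 0.0]]
    iso, aniso = decompose_symmetric(stress)
    print(iso)  # 3.0
    print(aniso[0][0] + aniso[1][1] + aniso[2][2])  # 0.0
    print(recompose(iso, aniso) == stress)  # True
```

The split is lossless: a symmetric 3x3 tensor has 6 independent entries, 1 isotropic plus 5 anisotropic, which match the dimensions of the L=0 and L=2 spherical harmonic representations. That correspondence is why an equivariant model can predict the two parts as separate outputs.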