core.units.mlip_unit._metrics

Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Attributes

NONE_SLICE

Classes

Metrics

Functions

metrics_dict(→ Callable)

Wrap a metrics function and repackage its return value.

cosine_similarity(prediction, target[, key])

mae(→ torch.Tensor)

mse(→ torch.Tensor)

rmse(→ torch.Tensor)

per_atom_mae(→ torch.Tensor)

per_atom_mse(→ torch.Tensor)

magnitude_error(→ torch.Tensor)

forcesx_mae(→ Metrics)

forcesx_mse(→ Metrics)

forcesy_mae(→ Metrics)

forcesy_mse(→ Metrics)

forcesz_mae(→ Metrics)

forcesz_mse(→ Metrics)

energy_forces_within_threshold(→ Metrics)

energy_within_threshold(→ Metrics)

average_distance_within_threshold(→ Metrics)

min_diff(pred_pos, dft_pos, cell, pbc)

Calculate the minimum difference between predicted and target positions, taking periodic boundary conditions into account.

get_metrics_fn(→ Callable)

Module Contents

core.units.mlip_unit._metrics.NONE_SLICE
class core.units.mlip_unit._metrics.Metrics
metric: float = 0.0
total: float = 0.0
numel: int = 0
__iadd__(other)
core.units.mlip_unit._metrics.metrics_dict(metric_fun: Callable) → Callable

Wrap a metrics function and repackage its return value.

core.units.mlip_unit._metrics.cosine_similarity(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE)
core.units.mlip_unit._metrics.mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.units.mlip_unit._metrics.mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.units.mlip_unit._metrics.rmse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → torch.Tensor
core.units.mlip_unit._metrics.per_atom_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.units.mlip_unit._metrics.per_atom_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.units.mlip_unit._metrics.magnitude_error(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE, p: int = 2) → torch.Tensor
core.units.mlip_unit._metrics.forcesx_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → Metrics
core.units.mlip_unit._metrics.forcesx_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → Metrics
core.units.mlip_unit._metrics.forcesy_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → Metrics
core.units.mlip_unit._metrics.forcesy_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → Metrics
core.units.mlip_unit._metrics.forcesz_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → Metrics
core.units.mlip_unit._metrics.forcesz_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → Metrics
core.units.mlip_unit._metrics.energy_forces_within_threshold(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → Metrics
core.units.mlip_unit._metrics.energy_within_threshold(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → Metrics
core.units.mlip_unit._metrics.average_distance_within_threshold(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → Metrics
core.units.mlip_unit._metrics.min_diff(pred_pos: torch.Tensor, dft_pos: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor)

Calculate the minimum difference between predicted and target positions, taking periodic boundary conditions into account.

core.units.mlip_unit._metrics.get_metrics_fn(function_name: str) → Callable