core.modules.evaluator

Copyright (c) Meta, Inc. and its affiliates.

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

Attributes

NONE_SLICE

Classes

Evaluator

Functions

metrics_dict(→ Callable)

Wrap the return value of a metrics function.

cosine_similarity(prediction, target[, key])

mae(→ torch.Tensor)

mse(→ torch.Tensor)

per_atom_mae(→ torch.Tensor)

per_atom_mse(→ torch.Tensor)

magnitude_error(→ torch.Tensor)

forcesx_mae(prediction, target[, key])

forcesx_mse(prediction, target[, key])

forcesy_mae(prediction, target[, key])

forcesy_mse(prediction, target[, key])

forcesz_mae(prediction, target[, key])

forcesz_mse(prediction, target[, key])

energy_forces_within_threshold(→ dict[str, float | int])

energy_within_threshold(→ dict[str, float | int])

average_distance_within_threshold(→ dict[str, float | int])

min_diff(pred_pos, dft_pos, cell, pbc)

rmse(→ dict[str, float | int])

Module Contents

core.modules.evaluator.NONE_SLICE
class core.modules.evaluator.Evaluator(task: str | None = None, eval_metrics: dict | None = None)
task_metrics: ClassVar[dict[str, str]]
task_primary_metric: ClassVar[dict[str, str | None]]
task
target_metrics
eval(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], prev_metrics: dict | None = None)
update(key, stat, metrics)
core.modules.evaluator.metrics_dict(metric_fun: Callable) → Callable

Wrap the return value of a metrics function.

core.modules.evaluator.cosine_similarity(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE)
core.modules.evaluator.mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.modules.evaluator.mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.modules.evaluator.per_atom_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.modules.evaluator.per_atom_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE) → torch.Tensor
core.modules.evaluator.magnitude_error(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE, p: int = 2) → torch.Tensor
core.modules.evaluator.forcesx_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE)
core.modules.evaluator.forcesx_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = NONE_SLICE)
core.modules.evaluator.forcesy_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None)
core.modules.evaluator.forcesy_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None)
core.modules.evaluator.forcesz_mae(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None)
core.modules.evaluator.forcesz_mse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None)
core.modules.evaluator.energy_forces_within_threshold(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → dict[str, float | int]
core.modules.evaluator.energy_within_threshold(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → dict[str, float | int]
core.modules.evaluator.average_distance_within_threshold(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → dict[str, float | int]
core.modules.evaluator.min_diff(pred_pos: torch.Tensor, dft_pos: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor)
core.modules.evaluator.rmse(prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor], key: collections.abc.Hashable = None) → dict[str, float | int]