Testing Utilities
The base class for tests
- class metatrain.utils.testing.ArchitectureTests
Bases: object
This is the base class for all architecture tests.
It doesn’t implement any tests itself, but provides fixtures and helper functions that are generally useful for testing architectures.
Child classes can override everything, including fixtures, to make the tests suit their needs. Note that some fixtures defined here depend on other fixtures, but when overriding them, you can completely change their signature.
- architecture: str
Name of the architecture to be tested.
Based on this, the test suite will find the model and trainer classes as well as the hyperparameters.
- dataset_path() str
Fixture that provides a path to a dataset file for testing.
- Returns:
The path to the dataset file.
- Return type:
str
- dataset_targets(dataset_path: str) dict[str, dict]
Fixture that provides the target hyperparameters for the dataset used in testing.
- get_dataset(dataset_targets: dict[str, dict], dataset_path: str) tuple[Dataset, dict[str, TargetInfo], DatasetInfo]
Helper function to load the dataset used in testing.
- device(request: FixtureRequest) device
Fixture to provide the torch device for testing.
- Parameters:
request (FixtureRequest) – The pytest request fixture.
- Returns:
The torch device to be used.
- Return type:
device
- dtype(request: FixtureRequest) dtype
Fixture to provide the model data type for testing.
- Parameters:
request (FixtureRequest) – The pytest request fixture.
- Returns:
The torch data type to be used.
- Return type:
dtype
- dataset_info() DatasetInfo
Fixture that provides a basic DatasetInfo with an energy target for testing.
- Returns:
A DatasetInfo instance with an energy target.
- Return type:
DatasetInfo
- per_atom(request: FixtureRequest) bool
Fixture to test both per-atom and per-system targets.
- Parameters:
request (FixtureRequest) – The pytest request fixture.
- Returns:
Whether the target is per-atom or not.
- Return type:
bool
- dataset_info_scalar(per_atom: bool) DatasetInfo
Fixture that provides a basic DatasetInfo with a scalar target for testing.
- Parameters:
per_atom (bool) – Whether the target is per-atom or not.
- Returns:
A DatasetInfo instance with a scalar target.
- Return type:
DatasetInfo
- dataset_info_vector(per_atom: bool) DatasetInfo
Fixture that provides a basic DatasetInfo with a vector target for testing.
- Parameters:
per_atom (bool) – Whether the target is per-atom or not.
- Returns:
A DatasetInfo instance with a vector target.
- Return type:
DatasetInfo
- o3_lambda(request: FixtureRequest) int
Fixture to provide different O(3) lambda values for testing spherical tensors.
- Parameters:
request (FixtureRequest) – The pytest request fixture.
- Returns:
The O(3) lambda value.
- Return type:
int
- o3_sigma(request: FixtureRequest) int
Fixture to provide different O(3) sigma values for testing spherical tensors.
- Parameters:
request (FixtureRequest) – The pytest request fixture.
- Returns:
The O(3) sigma value.
- Return type:
int
- dataset_info_spherical(o3_lambda: int, o3_sigma: int) DatasetInfo
Fixture that provides a basic DatasetInfo with a spherical target for testing.
- Parameters:
o3_lambda (int) – The O(3) lambda value of the spherical target.
o3_sigma (int) – The O(3) sigma value of the spherical target.
- Returns:
A DatasetInfo instance with a spherical target.
- Return type:
DatasetInfo
- dataset_info_multispherical(per_atom: bool) DatasetInfo
Fixture that provides a basic DatasetInfo with multiple spherical targets for testing.
- Parameters:
per_atom (bool) – Whether the target is per-atom or not.
- Returns:
A DatasetInfo instance with multiple spherical targets.
- Return type:
DatasetInfo
- property trainer_cls: type[TrainerInterface]
The trainer class to be tested.
- default_hypers() dict
Fixture that provides the default hyperparameters for testing.
- Returns:
The default hyperparameters for the architecture.
- Return type:
dict
Test suites for each functionality
- class metatrain.utils.testing.AutogradTests
Bases: ArchitectureTests
Tests that autograd works correctly for a given model.
- cuda_nondet_tolerance = 0.0
Some operations in your model might be nondeterministic in cuBLAS.
This can result in small differences between two gradient computations on the same inputs. This number sets the nondeterministic tolerance for gradcheck and gradgradcheck when running on CUDA.
- test_autograd_cell(device: device, model_hypers: dict, dataset_info: DatasetInfo) None
Tests that autograd can compute gradients with respect to the cell.
It checks both first and second derivatives.
It uses torch.autograd.gradcheck and torch.autograd.gradgradcheck for this purpose.
- Parameters:
device (device) – The device to run the test on.
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
- test_autograd_positions(device: device, model_hypers: dict, dataset_info: DatasetInfo) None
Tests that autograd can compute gradients with respect to positions.
It checks both first and second derivatives.
It uses torch.autograd.gradcheck and torch.autograd.gradgradcheck for this purpose.
- Parameters:
device (device) – The device to run the test on.
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
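torch.autograd.gradcheck compares the gradients computed by autograd against finite-difference estimates; gradgradcheck does the same one derivative order higher. As an illustration of that idea only (not of how the test suite or PyTorch implement it), the sketch below checks the analytical position gradient of a hypothetical toy pair energy against central differences in plain Python. The toy energy, the helper names, and the tolerances (chosen to resemble gradcheck's defaults) are assumptions.

```python
import math
from typing import Callable, List

Positions = List[List[float]]


def toy_energy(positions: Positions) -> float:
    """Hypothetical toy energy: harmonic springs of rest length 1 between all pairs."""
    energy = 0.0
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            r = math.dist(positions[i], positions[j])
            energy += (r - 1.0) ** 2
    return energy


def toy_energy_gradient(positions: Positions) -> Positions:
    """Analytical dE/dr_i, playing the role of the autograd result."""
    grad = [[0.0, 0.0, 0.0] for _ in positions]
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            r = math.dist(positions[i], positions[j])
            prefactor = 2.0 * (r - 1.0) / r
            for k in range(3):
                component = prefactor * (positions[i][k] - positions[j][k])
                grad[i][k] += component
                grad[j][k] -= component
    return grad


def numerical_gradient(
    energy_fn: Callable[[Positions], float], positions: Positions, eps: float = 1e-6
) -> Positions:
    """Central finite differences, one Cartesian component at a time."""
    grad = []
    for i in range(len(positions)):
        row = []
        for k in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][k] += eps
            minus[i][k] -= eps
            row.append((energy_fn(plus) - energy_fn(minus)) / (2.0 * eps))
        grad.append(row)
    return grad


def gradients_match(
    analytical: Positions, numerical: Positions, atol: float = 1e-5, rtol: float = 1e-3
) -> bool:
    """Elementwise |a - n| <= atol + rtol * |n|, as in an allclose check."""
    return all(
        abs(a - n) <= atol + rtol * abs(n)
        for row_a, row_n in zip(analytical, numerical)
        for a, n in zip(row_a, row_n)
    )


positions = [[0.0, 0.0, 0.0], [1.2, 0.1, 0.0], [0.3, 0.9, 0.4]]
analytic = toy_energy_gradient(positions)
print(gradients_match(analytic, numerical_gradient(toy_energy, positions)))  # True
```

A deliberately wrong gradient (for example, one scaled by 2) fails the same comparison, which is the kind of mismatch the autograd tests are designed to catch.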
- class metatrain.utils.testing.CheckpointTests
Bases: ArchitectureTests
Test suite for model and trainer checkpoints.
This test suite verifies that the checkpoints for the architecture follow the expected behavior of metatrain checkpoints.
- incompatible_trainer_checkpoints: list[str] = []
A list of checkpoint paths that are known to be incompatible with the current trainer version when restarting.
This should be overridden in subclasses.
- model_trainer(dataset_path: str, dataset_targets: dict, minimal_model_hypers: dict, default_hypers: dict) tuple[ModelInterface, TrainerInterface]
Fixture that returns a trained model and trainer.
The model and trainer are used in the test suite to verify checkpoint functionality.
- Parameters:
dataset_path (str) – The path to the dataset file to train on.
dataset_targets (dict) – The targets that the dataset contains.
minimal_model_hypers (dict) – Hyperparameters to initialize the model. These should give the smallest possible model to use as little disk space as possible when saving checkpoints.
default_hypers (dict) – Default hyperparameters to initialize the trainer.
- Returns:
A tuple containing the trained model and the trainer.
- Return type:
tuple[ModelInterface, TrainerInterface]
- test_checkpoint_did_not_change(monkeypatch: Any, tmp_path: str, model_trainer: tuple[ModelInterface, TrainerInterface]) None
Test that the checkpoint did not change.
This test gets the current version of the model and trainer, and loads the checkpoint for that version from the checkpoints/ folder. If that checkpoint is not compatible with the current code, the checkpoint version of either the model or the trainer needs to be bumped.
- Parameters:
monkeypatch (Any) – The pytest monkeypatch fixture.
tmp_path (str) – The pytest tmp_path fixture.
model_trainer (tuple[ModelInterface, TrainerInterface]) – Model and trainer to test.
- Return type:
None
- test_failed_checkpoint_upgrade(cls_type: Literal['model', 'trainer']) None
Tests that an error is raised when trying to upgrade an invalid checkpoint version.
This test creates a checkpoint with an invalid version number and tries to upgrade it using the corresponding class.
If this test fails, it likely means that you are not raising the error in your model/trainer’s upgrade_checkpoint method when the checkpoint version is not recognized. To raise the appropriate error:

    # Inside upgrade_checkpoint, when checkpoint_version is not recognized:
    cls_type = "model"  # or "trainer"
    raise RuntimeError(
        f"Unable to upgrade the checkpoint: the checkpoint is using {cls_type} "
        f"version {checkpoint_version}, while the current {cls_type} version "
        f"is {self.__class__.__checkpoint_version__}."
    )

- Parameters:
cls_type (Literal['model', 'trainer']) – The class type to test.
- Return type:
None
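The error above is meant to be raised from inside upgrade_checkpoint. A minimal, self-contained sketch of where it could sit, assuming a hypothetical model class that stores its version under a hypothetical "model_ckpt_version" key and upgrades one version at a time through a table of step functions (none of these names come from metatrain): any version without a complete upgrade path, including versions newer than the current one, ends with the RuntimeError, and the checkpoint is only modified once the whole path is known to exist.

```python
from typing import Any, Callable, Dict

Checkpoint = Dict[str, Any]


def _upgrade_v1_to_v2(checkpoint: Checkpoint) -> None:
    # Hypothetical upgrade step: a state-dict key was renamed in version 2.
    checkpoint["state_dict"] = checkpoint.pop("model_state_dict")


class HypotheticalModel:
    # Current checkpoint version of this hypothetical model.
    __checkpoint_version__ = 2

    # Maps each old version to the function upgrading it to the next version.
    _upgrade_steps: Dict[int, Callable[[Checkpoint], None]] = {1: _upgrade_v1_to_v2}

    def upgrade_checkpoint(self, checkpoint: Checkpoint) -> Checkpoint:
        cls_type = "model"
        checkpoint_version = checkpoint["model_ckpt_version"]
        current_version = self.__class__.__checkpoint_version__
        if checkpoint_version == current_version:
            return checkpoint  # already up to date

        missing_step = any(
            version not in self._upgrade_steps
            for version in range(checkpoint_version, current_version)
        )
        if checkpoint_version > current_version or missing_step:
            # Unrecognized version: raise the error expected by the test suite.
            raise RuntimeError(
                f"Unable to upgrade the checkpoint: the checkpoint is using {cls_type} "
                f"version {checkpoint_version}, while the current {cls_type} version "
                f"is {self.__class__.__checkpoint_version__}."
            )

        # Apply the upgrade steps one version at a time.
        for version in range(checkpoint_version, current_version):
            self._upgrade_steps[version](checkpoint)
        checkpoint["model_ckpt_version"] = current_version
        return checkpoint
```

For example, HypotheticalModel().upgrade_checkpoint({"model_ckpt_version": 99}) raises the RuntimeError, which is the behavior test_failed_checkpoint_upgrade looks for.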
- test_get_checkpoint(context: Literal['finetune', 'restart', 'export'], caplog: Any, model_trainer: tuple[ModelInterface, TrainerInterface]) None
Test that the checkpoint created by the model.get_checkpoint() function can be loaded back in all possible contexts.
This test can fail either if the model is unable to produce checkpoints, or if the generated checkpoint can’t be loaded back by the model in the specified context.
- Parameters:
context (Literal['finetune', 'restart', 'export']) – The context in which to load the generated checkpoint.
caplog (Any) – The pytest caplog fixture.
model_trainer (tuple[ModelInterface, TrainerInterface]) – Model and trainer to be used for the test.
- Return type:
None
- test_loading_old_checkpoints(default_hypers: dict, model_trainer: tuple[ModelInterface, TrainerInterface], context: Literal['restart', 'finetune', 'export']) None
Tests that checkpoints from previous versions can be loaded.
This test goes through all the checkpoint files in the checkpoints/ folder of the current directory (presumably the architecture’s tests folder) and tries to load them in the current model and trainer.
The test skips trainer checkpoints that are listed in this class’s incompatible_trainer_checkpoints attribute when the context is restart.
- Parameters:
default_hypers (dict) – Default hyperparameters to initialize the trainer.
model_trainer (tuple[ModelInterface, TrainerInterface]) – Model and trainer to be used for loading the checkpoints.
context (Literal['restart', 'finetune', 'export']) – The context in which to load the checkpoint.
- Return type:
None
- class metatrain.utils.testing.ExportedTests
Bases: ArchitectureTests
Test suite for exported models.
- test_to(device: device, dtype: dtype, model_hypers: dict, dataset_info: DatasetInfo) None
Tests that the .to() method of the exported model works.
In other words, it tests that the exported model can be moved to different devices and dtypes.
- Parameters:
device (device) – The device to move the exported model to.
dtype (dtype) – The dtype to move the exported model to.
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
- class metatrain.utils.testing.InputTests
Bases: ArchitectureTests
Test suite to check that the model handles inputs correctly.
- test_fixed_composition_weights(default_hypers: dict) None
Test that the trainer can accept fixed composition weights.
The test checks that when valid fixed composition weights are provided, the architecture options are accepted.
This test is skipped if the architecture’s trainer does not use fixed_composition_weights.
If this test is failing, you need to add the correct type hint to the fixed_composition_weights field of the trainer hypers, i.e., in documentation.py of your architecture:

    from typing_extensions import TypedDict

    from metatrain.utils.additive import FixedCompositionWeights


    class TrainerHypers(TypedDict):
        ...  # Rest of hyperparameters
        fixed_composition_weights: FixedCompositionWeights

with the appropriate documentation and default if applicable.
- Parameters:
default_hypers (dict) – The default hyperparameters for the architecture.
- Return type:
None
- test_fixed_composition_weights_error(default_hypers: dict) None
Test that invalid input is not accepted for fixed_composition_weights.
The test checks that when invalid fixed composition weights are provided, the architecture options raise a validation error.
This test is skipped if the architecture’s trainer does not use fixed_composition_weights.
If this test is failing, you need to add the correct type hint to the fixed_composition_weights field of the trainer hypers, i.e., in documentation.py of your architecture:

    from typing_extensions import TypedDict

    from metatrain.utils.additive import FixedCompositionWeights


    class TrainerHypers(TypedDict):
        ...  # Rest of hyperparameters
        fixed_composition_weights: FixedCompositionWeights

with the appropriate documentation and default if applicable.
- Parameters:
default_hypers (dict) – The default hyperparameters for the architecture.
- Return type:
None
- class metatrain.utils.testing.OutputTests
Bases: ArchitectureTests
Test suite to check that the model can produce different types of outputs.
If a model does not support a given type of output, set the corresponding supports_*_outputs attribute to False to skip the corresponding tests. By default, they are all set to True so that supported outputs are not accidentally left untested.
- is_equivariant_model: bool = True
Whether the model is equivariant (i.e., produces outputs that transform correctly under rotations by design of the architecture).
- n_features() int | None
Fixture that returns the number of features produced by the model.
By default this is set to None, which skips checking the number of features in the output. Override this fixture for your architecture if you want the test suite to check that the number of features in the output is correct.
- Returns:
The number of features produced by the model.
- Return type:
int | None
- n_last_layer_features() int | None
Fixture that returns the number of last-layer features produced by the model.
By default this is set to None, which skips checking the number of last-layer features in the output. Override this fixture for your architecture if you want the test suite to check that the number of last-layer features in the output is correct.
- Returns:
The number of last-layer features produced by the model.
- Return type:
int | None
- single_atom_energy() float | None
Fixture that returns the single-atom energy value.
By default this is set to None, which skips checking the single-atom energy value in the output. Override this fixture for your architecture if you want the test suite to check that the single-atom energy value in the output is correct.
- Returns:
The single atom energy value.
- Return type:
float | None
- supports_last_layer_features: bool = True
Whether the model supports returning last-layer features.
- supports_selected_atoms: bool = True
Whether the model supports the selected_atoms argument in the forward() method.
- test_output_features(model_hypers: dict, dataset_info: DatasetInfo, per_atom: bool, n_features: int | None) None
Tests that the model can output its learned features.
If this test is failing, you are probably not exposing the features output of your model correctly.
This test is skipped if the model does not support features output, i.e., if supports_features is set to False.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
per_atom (bool) – Whether to request per-atom features or not.
n_features (int | None) – Expected number of features. If None, the number of features is not checked.
- Return type:
None
- test_output_last_layer_features(model_hypers: dict, dataset_info: DatasetInfo, per_atom: bool, n_last_layer_features: int | None) None
Tests that the model can output its last-layer features.
If this test is failing, you are probably not exposing the last-layer features output of your model correctly.
This test is skipped if the model does not support last-layer features output, i.e., if supports_last_layer_features is set to False.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
per_atom (bool) – Whether to request per-atom last-layer features or not.
n_last_layer_features (int | None) – Expected number of last-layer features. If None, the number of last-layer features is not checked.
- Return type:
None
- test_output_last_layer_features_selected_atoms(model_hypers: dict, dataset_info: DatasetInfo, dataset_path: str, select_atoms: list[int]) None
Tests that the model can output its last-layer features for selected atoms.
This test is skipped if the model does not support last-layer features or the model does not support the selected_atoms argument of the forward() method, i.e., if either supports_last_layer_features or supports_selected_atoms is set to False.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
dataset_path (str) – Path to a dataset file to read systems from.
select_atoms (list[int]) – List of atom indices to select for the output.
- Return type:
None
- test_output_multispherical(model_hypers: dict, dataset_info_multispherical: DatasetInfo, per_atom: bool) None
Tests that the forward pass works for spherical tensor outputs with multiple irreps.
It also tests that the returned outputs have the expected samples and values shape.
This test is skipped if the model does not support spherical outputs, i.e., if supports_spherical_outputs is set to False.
If this test is failing and test_output_spherical is passing, your model is probably not handling the possibility that spherical outputs can have multiple irreps.
If test_output_spherical is also failing, fix that test first.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info_multispherical (DatasetInfo) – Dataset information with multiple spherical outputs.
per_atom (bool) – Whether the requested outputs are per-atom or not.
- Return type:
None
- test_output_scalar(model_hypers: dict, dataset_info_scalar: DatasetInfo, per_atom: bool) None
Tests that the forward pass works for scalar outputs.
It also tests that the returned outputs have the expected samples and values shape.
This test is skipped if the model does not support scalar outputs, i.e., if supports_scalar_outputs is set to False.
If this test is failing, your model might:
  - not be producing scalar outputs when requested.
  - not be correctly taking into account the per_atom field of the outputs passed to the outputs argument of the forward() method.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info_scalar (DatasetInfo) – Dataset information with scalar outputs.
per_atom (bool) – Whether the requested outputs are per-atom or not.
- Return type:
None
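The "expected samples and values shape" depends on per_atom. A small sketch of what such a check could compute, assuming metatensor-style outputs in which per-system samples are labelled by "system" and per-atom samples by ("system", "atom"), and in which a scalar block's values have shape (n_samples, n_properties); the function name is hypothetical and not part of metatrain.

```python
from typing import List, Tuple

Sample = Tuple[int, ...]


def expected_scalar_layout(
    atoms_per_system: List[int], per_atom: bool, n_properties: int
) -> Tuple[List[str], List[Sample], Tuple[int, int]]:
    """Return (sample names, sample entries, values shape) for a scalar output."""
    if per_atom:
        # One sample per atom, labelled by its system and its index in that system.
        names = ["system", "atom"]
        samples = [
            (system, atom)
            for system, n_atoms in enumerate(atoms_per_system)
            for atom in range(n_atoms)
        ]
    else:
        # One sample per system.
        names = ["system"]
        samples = [(system,) for system in range(len(atoms_per_system))]
    return names, samples, (len(samples), n_properties)


# Two systems with 2 and 3 atoms and a single property, e.g. an energy.
print(expected_scalar_layout([2, 3], per_atom=True, n_properties=1)[2])   # (5, 1)
print(expected_scalar_layout([2, 3], per_atom=False, n_properties=1)[2])  # (2, 1)
```

A model that ignores per_atom would return one of these layouts regardless of the request, which is the second failure mode listed above.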
- test_output_scalar_invariant(model_hypers: dict, dataset_info: DatasetInfo, dataset_path: str) None
Tests that scalar outputs are invariant to rotation.
This test is skipped if the model does not support scalar outputs, or if the model is not equivariant by design, i.e., if either supports_scalar_outputs or is_equivariant_model is set to False.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
dataset_path (str) – Path to a dataset file to read systems from.
- Return type:
None
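The invariance check boils down to comparing a scalar prediction before and after rotating the system. A plain-Python sketch of that comparison, in which a hypothetical distance-based toy energy stands in for the model and a Z-Y-Z Euler rotation is applied to the positions (for periodic systems the cell would have to be rotated as well; that is omitted here):

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def rotation_matrix(alpha: float, beta: float, gamma: float) -> Matrix:
    """Proper rotation R = Rz(alpha) Ry(beta) Rz(gamma) (Z-Y-Z Euler angles)."""

    def rz(t: float) -> Matrix:
        c, s = math.cos(t), math.sin(t)
        return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

    def ry(t: float) -> Matrix:
        c, s = math.cos(t), math.sin(t)
        return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

    return matmul(matmul(rz(alpha), ry(beta)), rz(gamma))


def rotate(rotation: Matrix, vector: Vector) -> Vector:
    return [sum(rotation[i][k] * vector[k] for k in range(3)) for i in range(3)]


def toy_energy(positions: List[Vector]) -> float:
    """Hypothetical model: depends only on interatomic distances, so it is invariant."""
    return sum(
        math.exp(-math.dist(p, q))
        for i, p in enumerate(positions)
        for q in positions[i + 1 :]
    )


positions = [[0.0, 0.0, 0.0], [1.1, 0.2, -0.3], [0.4, 1.3, 0.5]]
rotation = rotation_matrix(0.3, 1.1, -0.7)
rotated = [rotate(rotation, p) for p in positions]
print(math.isclose(toy_energy(positions), toy_energy(rotated), rel_tol=1e-9))  # True
```

A quantity that is not rotation invariant, such as the sum of the x coordinates, changes under the same rotation and would fail this comparison.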
- test_output_spherical(model_hypers: dict, dataset_info_spherical: DatasetInfo, per_atom: bool) None
Tests that the forward pass works for spherical outputs.
It also tests that the returned outputs have the expected samples and values shape.
This test is skipped if the model does not support spherical outputs, i.e., if supports_spherical_outputs is set to False.
If this test is failing, your model might:
  - not be producing spherical outputs when requested.
  - not be correctly taking into account the per_atom field of the outputs passed to the outputs argument of the forward() method.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info_spherical (DatasetInfo) – Dataset information with spherical outputs.
per_atom (bool) – Whether the requested outputs are per-atom or not.
- Return type:
None
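For spherical targets, the values carry an extra component axis of size 2λ + 1, and a target with several irreps (as in test_output_multispherical) is split into one block per (o3_lambda, o3_sigma) pair. The sketch below computes the values shapes such a check might expect, assuming metatensor-style blocks with values of shape (n_samples, 2λ + 1, n_properties) and, for simplicity, the same number of properties in every block; the helper name is hypothetical.

```python
from typing import Dict, List, Tuple

Irrep = Tuple[int, int]  # (o3_lambda, o3_sigma)


def expected_spherical_shapes(
    n_samples: int, irreps: List[Irrep], n_properties: int
) -> Dict[Irrep, Tuple[int, int, int]]:
    """One values shape per irrep block: (samples, 2 * lambda + 1 components, properties)."""
    shapes = {}
    for o3_lambda, o3_sigma in irreps:
        if o3_lambda < 0 or o3_sigma not in (-1, 1):
            raise ValueError(f"invalid irrep ({o3_lambda}, {o3_sigma})")
        shapes[(o3_lambda, o3_sigma)] = (n_samples, 2 * o3_lambda + 1, n_properties)
    return shapes


# Per-atom output for two systems with 2 and 3 atoms: 5 samples.
print(expected_spherical_shapes(5, [(0, 1), (2, 1)], n_properties=4))
# {(0, 1): (5, 1, 4), (2, 1): (5, 5, 4)}
```
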
- test_output_spherical_equivariant_inversion(model_hypers: dict, dataset_info_spherical: DatasetInfo, dataset_path: str, o3_lambda: int, o3_sigma: int) None
Tests that the model is equivariant with respect to inversions.
This test is done on spherical targets (not scalar targets).
This test is skipped if the model does not support spherical outputs, or if the model is not equivariant by design, i.e., if either supports_spherical_outputs or is_equivariant_model is set to False.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info_spherical (DatasetInfo) – Dataset information with spherical outputs.
dataset_path (str) – Path to a dataset file to read systems from.
o3_lambda (int) – The O(3) lambda of the spherical output to test.
o3_sigma (int) – The O(3) sigma of the spherical output to test.
- Return type:
None
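Under inversion (r → -r), a spherical tensor of order λ picks up a sign. Assuming the usual metatensor convention that o3_sigma = 1 marks proper tensors and o3_sigma = -1 improper ones, the expected factor is σ(-1)^λ. The sketch below checks that factor on hypothetical toy outputs built from position vectors: a scalar, a pseudoscalar, a vector, and a pseudovector.

```python
from typing import List

Vector = List[float]


def inversion_factor(o3_lambda: int, o3_sigma: int) -> int:
    """Sign expected under inversion for a tensor of order lambda and parity sigma."""
    return o3_sigma * (-1) ** o3_lambda


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cross(a: Vector, b: Vector) -> Vector:
    return [a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]]


# Hypothetical toy outputs keyed by (o3_lambda, o3_sigma); each returns a list of floats.
TOY_OUTPUTS = {
    (0, 1): lambda r: [dot(r[0], r[1])],  # scalar: unchanged
    (0, -1): lambda r: [dot(r[0], cross(r[1], r[2]))],  # pseudoscalar: flips sign
    (1, 1): lambda r: [a - b for a, b in zip(r[1], r[0])],  # vector: flips sign
    (1, -1): lambda r: cross(r[0], r[1]),  # pseudovector: unchanged
}


def check_inversion(positions: List[Vector], tol: float = 1e-12) -> bool:
    """True if every toy output satisfies f(-r) = sigma * (-1)**lambda * f(r)."""
    inverted = [[-x for x in p] for p in positions]
    for (o3_lambda, o3_sigma), output in TOY_OUTPUTS.items():
        factor = inversion_factor(o3_lambda, o3_sigma)
        before, after = output(positions), output(inverted)
        if any(abs(a - factor * b) > tol for a, b in zip(after, before)):
            return False
    return True


print(check_inversion([[0.1, 0.2, 0.3], [1.0, -0.5, 0.4], [0.7, 0.8, -0.9]]))  # True
```

This is why the inversion test is parametrized over both o3_lambda and o3_sigma: the expected sign differs between proper and improper tensors of the same order.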
- test_output_spherical_equivariant_rotations(model_hypers: dict, dataset_info_spherical: DatasetInfo, dataset_path: str) None
Tests that the model is rotationally equivariant when predicting spherical tensors.
This test is skipped if the model does not support spherical outputs, or if the model is not equivariant by design, i.e., if either supports_spherical_outputs or is_equivariant_model is set to False.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info_spherical (DatasetInfo) – Dataset information with spherical outputs.
dataset_path (str) – Path to a dataset file to read systems from.
- Return type:
None
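For λ = 1 outputs, rotational equivariance means f(R r) = R f(r): rotating the input and then predicting gives the same result as predicting and then rotating the output. A plain-Python sketch of that check, where a hypothetical per-atom vector built from distance-weighted neighbour displacements stands in for the model. It works with Cartesian components; real-spherical-harmonic λ = 1 components differ from them by a fixed reordering, and λ > 1 outputs transform with Wigner D-matrices, which are outside this sketch.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def rotation_matrix(alpha: float, beta: float, gamma: float) -> Matrix:
    """Proper rotation R = Rz(alpha) Ry(beta) Rz(gamma) (Z-Y-Z Euler angles)."""

    def rz(t: float) -> Matrix:
        c, s = math.cos(t), math.sin(t)
        return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

    def ry(t: float) -> Matrix:
        c, s = math.cos(t), math.sin(t)
        return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

    return matmul(matmul(rz(alpha), ry(beta)), rz(gamma))


def rotate(rotation: Matrix, vector: Vector) -> Vector:
    return [sum(rotation[i][k] * vector[k] for k in range(3)) for i in range(3)]


def toy_vector_output(positions: List[Vector]) -> List[Vector]:
    """Hypothetical per-atom vector: distance-weighted sum of neighbour displacements."""
    outputs = []
    for i, r_i in enumerate(positions):
        v = [0.0, 0.0, 0.0]
        for j, r_j in enumerate(positions):
            if i == j:
                continue
            weight = math.exp(-math.dist(r_i, r_j))
            for k in range(3):
                v[k] += weight * (r_j[k] - r_i[k])
        outputs.append(v)
    return outputs


def is_rotation_equivariant(positions: List[Vector], rotation: Matrix, tol: float = 1e-9) -> bool:
    predicted = toy_vector_output([rotate(rotation, p) for p in positions])  # f(R r)
    expected = [rotate(rotation, v) for v in toy_vector_output(positions)]  # R f(r)
    return all(
        abs(a - b) <= tol for pa, pb in zip(predicted, expected) for a, b in zip(pa, pb)
    )


positions = [[0.0, 0.0, 0.0], [1.1, 0.2, -0.3], [0.4, 1.3, 0.5]]
print(is_rotation_equivariant(positions, rotation_matrix(0.3, 1.1, -0.7)))  # True
```
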
- test_output_vector(model_hypers: dict, dataset_info_vector: DatasetInfo, per_atom: bool) None
Tests that the forward pass works for vector outputs.
It also tests that the returned outputs have the expected samples and values shape.
This test is skipped if the model does not support vector outputs, i.e., if supports_vector_outputs is set to False.
If this test is failing, your model might:
  - not be producing vector outputs when requested.
  - not be correctly taking into account the per_atom field of the outputs passed to the outputs argument of the forward() method.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info_vector (DatasetInfo) – Dataset information with vector outputs.
per_atom (bool) – Whether the requested outputs are per-atom or not.
- Return type:
None
- test_prediction_energy_subset_atoms(model_hypers: dict, dataset_info: DatasetInfo) None
Tests that the model can predict on a subset of the atoms in a system.
This test checks that the model supports the selected_atoms argument of the forward() method and handles it correctly, that is, the model only returns outputs for the selected atoms.
This test is skipped if the model does not support the selected_atoms argument of the forward() method, i.e., if supports_selected_atoms is set to False.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
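The selected_atoms behavior can be pictured with a toy model that already has one energy per atom. In the sketch below (plain Python, hypothetical names; in metatomic-style models the real forward() argument is a Labels object with "system" and "atom" dimensions rather than a list of tuples), per-atom outputs are restricted to the selected (system, atom) pairs and per-system outputs sum only over the selected atoms. Reporting 0.0 for systems with no selected atom is a choice made for this sketch only.

```python
from typing import Dict, List, Optional, Tuple

AtomId = Tuple[int, int]  # (system index, atom index)


def toy_forward(
    atomic_energies: List[List[float]],
    per_atom: bool,
    selected_atoms: Optional[List[AtomId]] = None,
) -> Dict:
    """Hypothetical forward pass over precomputed per-atom energies."""
    if selected_atoms is None:
        # No selection: every atom of every system is used.
        selected = [
            (system, atom)
            for system, energies in enumerate(atomic_energies)
            for atom in range(len(energies))
        ]
    else:
        selected = sorted(set(selected_atoms))

    if per_atom:
        # Only the selected atoms appear in the per-atom output.
        return {(s, a): atomic_energies[s][a] for s, a in selected}

    # Per-system output: sum over the selected atoms of each system.
    totals = {system: 0.0 for system in range(len(atomic_energies))}
    for s, a in selected:
        totals[s] += atomic_energies[s][a]
    return totals


energies = [[1.0, 2.0, 3.0], [10.0, 20.0]]
print(toy_forward(energies, per_atom=False))  # {0: 6.0, 1: 30.0}
print(toy_forward(energies, per_atom=False, selected_atoms=[(0, 2), (1, 0)]))  # {0: 3.0, 1: 10.0}
```
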
- test_prediction_energy_subset_elements(model_hypers: dict, dataset_info: DatasetInfo) None
Tests that the model can predict on a subset of the elements it was trained on.
If this test is failing, your model needs each system to contain all the elements that are present in the dataset. If this is the expected behavior for your model, contact the metatrain developers so that a variable to skip this test can be introduced.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
- test_single_atom(model_hypers: dict, dataset_info: DatasetInfo, single_atom_energy: float | None) None
Tests that the model runs fine on a single-atom system.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
single_atom_energy (float | None) – Expected single-atom energy value. If None, the single-atom energy value is not checked.
- Return type:
None
- class metatrain.utils.testing.TorchscriptTests
Bases: ArchitectureTests
Test suite to check that architectures can be jit compiled with TorchScript.
- float_hypers: list[str] = []
List of hyperparameter keys (dot-separated for nested keys) that are floats. A test will set these to integers to test that TorchScript compilation works in that case.
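The dot-separated keys in float_hypers address nested hyperparameters. A sketch of how such keys could be resolved to replace floats with integers, assuming the hypers are plain nested dictionaries; the helper name and the example keys are hypothetical, and the conversion used by the real test (plain int() truncation here) is an assumption.

```python
import copy
from typing import Any, Dict, List


def with_integer_hypers(hypers: Dict[str, Any], float_hypers: List[str]) -> Dict[str, Any]:
    """Return a copy of ``hypers`` where each dot-separated key holds an int."""
    new_hypers = copy.deepcopy(hypers)  # leave the caller's dictionary untouched
    for dotted_key in float_hypers:
        *parents, leaf = dotted_key.split(".")
        node = new_hypers
        for parent in parents:
            node = node[parent]  # walk down the nested dictionaries
        node[leaf] = int(node[leaf])
    return new_hypers


hypers = {"cutoff": 5.0, "soap": {"width": 1.0}, "num_layers": 2}
print(with_integer_hypers(hypers, ["cutoff", "soap.width"]))
# {'cutoff': 5, 'soap': {'width': 1}, 'num_layers': 2}
```

Because 5 == 5.0 in Python, the point of such a test is the type, not the value: TorchScript compiles against the annotated types, so code that only works when a hyperparameter is a float is what test_torchscript_integers is meant to expose.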
- test_torchscript(model_hypers: dict, dataset_info: DatasetInfo) None
Tests that the model can be jitted.
If this test fails, it probably means that some code in the model is not compatible with TorchScript. The exception raised by the test should indicate where the problem is.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
- test_torchscript_integers(model_hypers: dict, dataset_info: DatasetInfo) None
Tests that the model can be jitted when some float hyperparameters are instead supplied as integers.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
- test_torchscript_save_load(tmpdir: Any, model_hypers: dict, dataset_info: DatasetInfo) None
Tests that the model can be jitted, saved and loaded.
- Parameters:
tmpdir (Any) – Temporary directory in which to save the model.
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info (DatasetInfo) – Dataset information to initialize the model.
- Return type:
None
- test_torchscript_spherical(model_hypers: dict, dataset_info_spherical: DatasetInfo) None
Tests that the model can be jitted with spherical targets.
- Parameters:
model_hypers (dict) – Hyperparameters to initialize the model.
dataset_info_spherical (DatasetInfo) – Dataset information to initialize the model (containing spherical targets).
- Return type:
None
- class metatrain.utils.testing.TrainingTests
Bases: ArchitectureTests
Puts architectures to the test in real training scenarios.
- test_continue(monkeypatch: Any, tmp_path: Path, dataset_path: str, dataset_targets: dict[str, dict], default_hypers: dict[str, Any], model_hypers: dict[str, Any]) None
Tests that a model can be checkpointed and loaded to continue the training process.
- Parameters:
monkeypatch (Any) – Pytest fixture to modify the current working directory.
tmp_path (Path) – Temporary path to use for saving checkpoints.
dataset_path (str) – Path to the dataset to use for training.
dataset_targets (dict[str, dict]) – Target hypers for the targets in the dataset.
default_hypers (dict[str, Any]) – Default hyperparameters for the architecture.
model_hypers (dict[str, Any]) – Hyperparameters to initialize the model.
- Return type:
None