Extracting Hidden Layer Features from PET

PET computes rich intermediate representations at every stage of its forward pass — atom-type embeddings, embeddings of the neighborhood geometries, transformer outputs, and more. This tutorial shows how to retrieve any of those tensors using the mtt::feature:: prefix in the outputs dictionary passed to PET.forward(). Note that this only applies to non-exported models, i.e. checkpoints.

All outputs are returned as TensorMap objects, so they carry full sample metadata (atom or pair indices, cell-shift vectors) and are immediately compatible with the rest of the metatensor/metatomic ecosystem.

The tutorial is structured as follows:

  1. Standard outputs: backbone features, last-layer features, and energy.

  2. Unprocessed backbone and last-layer features: per-atom and per-pair tensors as used internally by the model.

  3. Raw featurizer inputs: distances, displacement vectors, and element-type indices, tied back to the water geometry.

  4. Discovering available paths: how to enumerate every capturable module path.

  5. Deep-dive into a GNN layer: capturing transformer internals and analysing how features evolve.

Setup: model and system

We initialise an untrained PET model (random weights) and build a single water molecule. The absolute values of the features will not be meaningful, but the shapes, sample labels, and structure of the returned TensorMaps are identical to those of a trained model.

Here we use the default hypers for the PET model, but make some modifications to better illustrate the dimensions of the captured layer outputs.

import ase
import torch
from metatomic.torch import ModelOutput, NeighborListOptions, systems_to_torch

from metatrain.pet import PET
from metatrain.utils.architectures import get_default_hypers
from metatrain.utils.data.dataset import DatasetInfo
from metatrain.utils.data.target_info import get_energy_target_info
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists


dtype = torch.float64

hypers = get_default_hypers("pet")["model"]

# Modify some hypers related to the token sizes
hypers["d_pet"] = 128
hypers["d_node"] = 256
hypers["d_head"] = 64

# Print the key model hypers (including the overrides above) for reference.
print("Key PET hypers:")
print(f"  d_pet: {hypers['d_pet']}")
print(f"  d_node: {hypers['d_node']}")
print(f"  d_head: {hypers['d_head']}")
print(f"  num_gnn_layers: {hypers['num_gnn_layers']}")
print(f"  num_attention_layers: {hypers['num_attention_layers']}")
print(f"  featurizer_type: {hypers['featurizer_type']}")

dataset_info = DatasetInfo(
    length_unit="angstrom",
    atomic_types=[1, 6, 7, 8, 16],
    targets={"energy": get_energy_target_info("energy", dict(unit="eV"))},
)

model = PET(hypers=hypers, dataset_info=dataset_info).to(dtype)

# A single water molecule: O at the origin, H along x, H along y.
frames = [
    ase.Atoms(
        ["O", "H", "H"],
        positions=[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.5, 0.0]],
    )
]
systems = systems_to_torch(frames)

nl_options = NeighborListOptions(cutoff=4.5, full_list=True, strict=True)
systems = [
    get_system_with_neighbor_lists(system, [nl_options]).to(dtype) for system in systems
]
Key PET hypers:
  d_pet: 128
  d_node: 256
  d_head: 64
  num_gnn_layers: 2
  num_attention_layers: 2
  featurizer_type: feedforward

1. Standard model outputs

PET exposes three standard outputs: the backbone features ("feature"), the last-layer features before the readout ("mtt::aux::energy_last_layer_features"), and the predicted energy. These are the normal production outputs. Setting per_atom=False aggregates atom features into a single per-structure vector.

outputs = {
    "feature": ModelOutput(per_atom=True),
    "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=True),
    "energy": ModelOutput(per_atom=False),
}
predictions = model(systems, outputs)

print(outputs, predictions)

backbone_block = predictions["feature"].block()
head_block = predictions["mtt::aux::energy_last_layer_features"].block()
energy_block = predictions["energy"].block()

print("backbone features shape:    ", backbone_block.values.shape)
print("last-layer features shape:  ", head_block.values.shape)
print("energy shape:               ", energy_block.values.shape)
print("backbone samples:           ", backbone_block.samples.names)
{'feature': <torch.ScriptObject object at 0x162b7ca0>, 'mtt::aux::energy_last_layer_features': <torch.ScriptObject object at 0x11e305e0>, 'energy': <torch.ScriptObject object at 0x11cf2170>} {'feature': TensorMap with 1 blocks
keys: _
      0, 'mtt::aux::energy_last_layer_features': TensorMap with 1 blocks
keys: _
      0, 'energy': TensorMap with 1 blocks
keys: _
      0}
backbone features shape:     torch.Size([3, 384])
last-layer features shape:   torch.Size([3, 128])
energy shape:                torch.Size([1, 1])
backbone samples:            ['system', 'atom']

The sample labels confirm that the backbone and last layer features are per-atom (["system", "atom"]), while the energy is per-structure (["system"]).

Note

These standard outputs go through PET’s full aggregation pipeline: edge contributions are summed over neighbors with cutoff weights applied, and the result is concatenated to the node stream. They are great for feature analysis but do not give direct access to the raw intermediate tensors used inside the GNN layers.

The backbone features have size 384 because they concatenate the node features (d_node=256) with the edge features (d_pet=128). The last-layer features have size 128 because they concatenate the node and edge last-layer features (d_head=64 each).
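
The aggregation described in the note above can be traced with a small, model-free sketch. It uses plain Python lists with placeholder values; the cutoff weights (0.5 per edge) and the node-first concatenation order are assumptions made for illustration, not values or conventions read from PET.

```python
# Feature widths taken from the hypers used in this tutorial.
D_NODE, D_PET = 256, 128

# Directed (center, neighbor) pairs of a 3-atom molecule, full neighbor list.
edges = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
n_atoms = 3

# Placeholder features: constant rows, only meant to track shapes.
node_features = [[1.0] * D_NODE for _ in range(n_atoms)]
edge_features = [[1.0] * D_PET for _ in edges]

# Hypothetical cutoff weights, one per directed edge.
cutoff_weights = [0.5] * len(edges)

# Weighted sum of the edge rows over the neighbors of each center atom.
aggregated = [[0.0] * D_PET for _ in range(n_atoms)]
for (center, _neighbor), row, weight in zip(edges, edge_features, cutoff_weights):
    for k, value in enumerate(row):
        aggregated[center][k] += weight * value

# Concatenate the node stream and the aggregated edge stream per atom
# (node-first order is an assumption of this sketch).
backbone = [node_row + edge_row for node_row, edge_row in zip(node_features, aggregated)]
backbone_shape = (len(backbone), len(backbone[0]))
print(backbone_shape)
```

With three atoms the result has one row per atom and d_node + d_pet = 384 columns, matching the shape printed for the "feature" output.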

2. Unprocessed backbone and head features

To access the raw per-atom and per-pair tensors as they are used inside the model — before neighbor-aggregation — use the mtt::feature:: prefix with the module path of the corresponding readout layer.

The integer index ("0" below) selects the readout layer. In the default "feedforward" featurizer mode all GNN outputs are combined into a single readout, so only index 0 exists. In "residual" mode each GNN layer has its own readout and the index selects which one. Here we use PET in the default feedforward mode.
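
The indexing rule can be sketched as a small helper that lists the readout-level keys for a given configuration. The function name readout_capture_keys is hypothetical, and the assumption that "residual" mode has exactly one readout per GNN layer (indices 0 to num_gnn_layers - 1) is inferred from the description above rather than checked against the PET source.

```python
def readout_capture_keys(featurizer_type, num_gnn_layers, target="energy"):
    """List the mtt::feature:: keys of the readout-level tensors.

    Hypothetical helper. Assumes "residual" mode has one readout per GNN layer.
    """
    if featurizer_type == "feedforward":
        n_readouts = 1
    elif featurizer_type == "residual":
        n_readouts = num_gnn_layers
    else:
        raise ValueError(f"unknown featurizer_type: {featurizer_type!r}")

    keys = []
    for index in range(n_readouts):
        keys.append(f"mtt::feature::node_backbone.{index}")
        keys.append(f"mtt::feature::edge_backbone.{index}")
        keys.append(f"mtt::feature::node_heads.{target}.{index}")
        keys.append(f"mtt::feature::edge_heads.{target}.{index}")
    return keys


feedforward_keys = readout_capture_keys("feedforward", num_gnn_layers=2)
residual_keys = readout_capture_keys("residual", num_gnn_layers=2)
print(len(feedforward_keys), len(residual_keys))
```

For the feedforward configuration used here, the helper returns exactly the four keys requested in the next code block.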

outputs = {
    "mtt::feature::node_backbone.0": ModelOutput(per_atom=True),
    "mtt::feature::edge_backbone.0": ModelOutput(per_atom=True),
    "mtt::feature::node_heads.energy.0": ModelOutput(per_atom=True),
    "mtt::feature::edge_heads.energy.0": ModelOutput(per_atom=True),
}
predictions = model(systems, outputs)

node_bb = predictions["mtt::feature::node_backbone.0"].block()
edge_bb = predictions["mtt::feature::edge_backbone.0"].block()

print("node_backbone.0  shape:   ", node_bb.values.shape)
print("node_backbone.0  samples: ", node_bb.samples.names)
print()
print("edge_backbone.0  shape:   ", edge_bb.values.shape)
print("edge_backbone.0  samples: ", edge_bb.samples.names)
node_backbone.0  shape:    torch.Size([3, 256])
node_backbone.0  samples:  ['system', 'atom']

edge_backbone.0  shape:    torch.Size([6, 128])
edge_backbone.0  samples:  ['system', 'first_atom', 'second_atom', 'cell_shift_a', 'cell_shift_b', 'cell_shift_c']

This highlights the two kinds of tensors returned by mtt::feature:::

Node-like tensors have shape (n_atoms, d) and samples ["system", "atom"] — one row per atom.

Edge-like tensors have shape (n_edges, d) and samples ["system", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"] — one row per directed pair. The cell-shift columns follow the standard metatensor neighbor-list convention.

3. Raw featurizer inputs

The very first quantities computed in a PET forward pass are the raw geometry descriptors fed into the featurizer. These can be captured directly:

  • edge_distances: scalar distance for each directed pair, shape (n_edges, 1)

  • edge_vectors: displacement vector rⱼ − rᵢ for each directed pair i → j, shape (n_edges, 3)

  • element_indices_nodes: element-type index for each centre atom, shape (n_atoms, 1)

  • element_indices_neighbors: element-type index for each neighbor atom, shape (n_edges, 1)

outputs = {
    "mtt::feature::edge_distances": ModelOutput(per_atom=True),
    "mtt::feature::edge_vectors": ModelOutput(per_atom=True),
    "mtt::feature::element_indices_nodes": ModelOutput(per_atom=True),
    "mtt::feature::element_indices_neighbors": ModelOutput(per_atom=True),
}
predictions = model(systems, outputs)

distances = predictions["mtt::feature::edge_distances"].block()
vectors = predictions["mtt::feature::edge_vectors"].block()
node_types = predictions["mtt::feature::element_indices_nodes"].block()
nbr_types = predictions["mtt::feature::element_indices_neighbors"].block()

Let’s print the distances and tie them back to the geometry of the water molecule created at the start of this example. Our molecule has:

  • O–H₁: 1.000 Å (displacement [−1, 0, 0] from H₁ to O)

  • O–H₂: 0.500 Å (displacement [0, −0.5, 0] from H₂ to O)

  • H₁–H₂: √(1² + 0.5²) ≈ 1.118 Å

With a full neighbor list each pair appears twice (both directions), giving 6 directed edges in total.
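
As a cross-check that does not depend on the model, the same directed pairs can be enumerated by hand from the atomic positions. This is a minimal sketch for a non-periodic molecule (no cell shifts); the brute-force double loop stands in for the real neighbor-list code and is only meant for three atoms.

```python
import math

# Water geometry from the setup section: O, H1, H2 (in Å).
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.5, 0.0)]
cutoff = 4.5

# Full neighbor list: every ordered pair (i, j) with i != j inside the cutoff.
edges = []
for i, r_i in enumerate(positions):
    for j, r_j in enumerate(positions):
        if i == j:
            continue
        # Displacement from atom i to atom j, i.e. r_j - r_i.
        vector = tuple(b - a for a, b in zip(r_i, r_j))
        distance = math.sqrt(sum(c * c for c in vector))
        if distance < cutoff:
            edges.append((i, j, vector, distance))

for i, j, vector, distance in edges:
    print(f"{i} -> {j}: {distance:.4f} Å  {vector}")
```

The six entries should agree, pair by pair, with the distances and displacement vectors captured from the model.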

print("Edge distances (Å):")
for i in range(distances.values.shape[0]):
    s = distances.samples[i]
    d = distances.values[i, 0].item()
    print(f"  {int(s['first_atom'])} → {int(s['second_atom'])}:  {d:.4f} Å")
Edge distances (Å):
  0 → 1:  1.0000 Å
  0 → 2:  0.5000 Å
  1 → 0:  1.0000 Å
  1 → 2:  1.1180 Å
  2 → 0:  0.5000 Å
  2 → 1:  1.1180 Å
print("\nDisplacement vectors (Å):")
for i in range(vectors.values.shape[0]):
    s = vectors.samples[i]
    v = vectors.values[i].tolist()
    print(
        f"  {int(s['first_atom'])} → {int(s['second_atom'])}:  "
        f"[{v[0]:+.3f}, {v[1]:+.3f}, {v[2]:+.3f}]"
    )
Displacement vectors (Å):
  0 → 1:  [+1.000, +0.000, +0.000]
  0 → 2:  [+0.000, +0.500, +0.000]
  1 → 0:  [-1.000, +0.000, +0.000]
  1 → 2:  [-1.000, +0.500, +0.000]
  2 → 0:  [+0.000, -0.500, +0.000]
  2 → 1:  [+1.000, -0.500, +0.000]
# Element indices correspond to the position of each atomic number in the
# sorted list of unique elements across the dataset.  Here atomic_types =
# [1, 6, 7, 8, 16], so H (Z=1) → index 0 and O (Z=8) → index 3.
print("\nElement indices per atom (0=H, 3=O):")
print(" ", node_types.values.long().squeeze(-1).tolist())
Element indices per atom (0=H, 3=O):
  [3, 0, 0]
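
The index mapping described in the comment above can be reproduced in a few lines of plain Python. The per-edge neighbor indices are derived here under the assumption that edges are ordered as in the printed distance table; they are an illustration, not captured model output.

```python
# Atomic types declared in DatasetInfo for this tutorial.
atomic_types = [1, 6, 7, 8, 16]

# Position of each atomic number in the sorted list of types.
type_to_index = {z: i for i, z in enumerate(sorted(atomic_types))}

# Water molecule in the order used above: O, H, H.
atomic_numbers = [8, 1, 1]
element_indices = [type_to_index[z] for z in atomic_numbers]
print(element_indices)

# Directed (first_atom, second_atom) pairs, ordered as in the distance table.
edges = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

# Element index of the neighbor (second atom) of each directed edge.
neighbor_indices = [element_indices[j] for _, j in edges]
print(neighbor_indices)
```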

4. Discovering available paths

Every sub-module visible in print(model) can be captured. The helper below formats the full list of valid mtt::feature:: keys in one go.

def print_all_module_paths(model):
    _skip = ("additive_models", "scaler", "long_range_featurizer")

    print(f"{'Module path':<65} {'Module type'}")
    print("-" * 95)
    for name, module in model.named_modules():
        if name and not any(name.startswith(p) for p in _skip):
            print(f"  mtt::feature::{name:<50}  {type(module).__name__}")


print_all_module_paths(model)
Module path                                                       Module type
-----------------------------------------------------------------------------------------------
  mtt::feature::gnn_layers                                          ModuleList
  mtt::feature::gnn_layers.0                                        CartesianTransformer
  mtt::feature::gnn_layers.0.trans                                  Transformer
  mtt::feature::gnn_layers.0.trans.layers                           ModuleList
  mtt::feature::gnn_layers.0.trans.layers.0                         TransformerLayer
  mtt::feature::gnn_layers.0.trans.layers.0.attention               AttentionBlock
  mtt::feature::gnn_layers.0.trans.layers.0.attention.input_linear  Linear
  mtt::feature::gnn_layers.0.trans.layers.0.attention.output_linear  Linear
  mtt::feature::gnn_layers.0.trans.layers.0.norm_attention          RMSNorm
  mtt::feature::gnn_layers.0.trans.layers.0.norm_mlp                RMSNorm
  mtt::feature::gnn_layers.0.trans.layers.0.mlp                     FeedForward
  mtt::feature::gnn_layers.0.trans.layers.0.mlp.w_in                Linear
  mtt::feature::gnn_layers.0.trans.layers.0.mlp.w_out               Linear
  mtt::feature::gnn_layers.0.trans.layers.0.mlp.activation          Identity
  mtt::feature::gnn_layers.0.trans.layers.0.center_contraction      Linear
  mtt::feature::gnn_layers.0.trans.layers.0.center_expansion        Linear
  mtt::feature::gnn_layers.0.trans.layers.0.norm_center_features    RMSNorm
  mtt::feature::gnn_layers.0.trans.layers.0.center_mlp              FeedForward
  mtt::feature::gnn_layers.0.trans.layers.0.center_mlp.w_in         Linear
  mtt::feature::gnn_layers.0.trans.layers.0.center_mlp.w_out        Linear
  mtt::feature::gnn_layers.0.trans.layers.0.center_mlp.activation   Identity
  mtt::feature::gnn_layers.0.trans.layers.1                         TransformerLayer
  mtt::feature::gnn_layers.0.trans.layers.1.attention               AttentionBlock
  mtt::feature::gnn_layers.0.trans.layers.1.attention.input_linear  Linear
  mtt::feature::gnn_layers.0.trans.layers.1.attention.output_linear  Linear
  mtt::feature::gnn_layers.0.trans.layers.1.norm_attention          RMSNorm
  mtt::feature::gnn_layers.0.trans.layers.1.norm_mlp                RMSNorm
  mtt::feature::gnn_layers.0.trans.layers.1.mlp                     FeedForward
  mtt::feature::gnn_layers.0.trans.layers.1.mlp.w_in                Linear
  mtt::feature::gnn_layers.0.trans.layers.1.mlp.w_out               Linear
  mtt::feature::gnn_layers.0.trans.layers.1.mlp.activation          Identity
  mtt::feature::gnn_layers.0.trans.layers.1.center_contraction      Linear
  mtt::feature::gnn_layers.0.trans.layers.1.center_expansion        Linear
  mtt::feature::gnn_layers.0.trans.layers.1.norm_center_features    RMSNorm
  mtt::feature::gnn_layers.0.trans.layers.1.center_mlp              FeedForward
  mtt::feature::gnn_layers.0.trans.layers.1.center_mlp.w_in         Linear
  mtt::feature::gnn_layers.0.trans.layers.1.center_mlp.w_out        Linear
  mtt::feature::gnn_layers.0.trans.layers.1.center_mlp.activation   Identity
  mtt::feature::gnn_layers.0.edge_embedder                          Linear
  mtt::feature::gnn_layers.0.compress                               Sequential
  mtt::feature::gnn_layers.0.compress.0                             Linear
  mtt::feature::gnn_layers.0.compress.1                             SiLU
  mtt::feature::gnn_layers.0.compress.2                             Linear
  mtt::feature::gnn_layers.0.neighbor_embedder                      DummyModule
  mtt::feature::gnn_layers.1                                        CartesianTransformer
  mtt::feature::gnn_layers.1.trans                                  Transformer
  mtt::feature::gnn_layers.1.trans.layers                           ModuleList
  mtt::feature::gnn_layers.1.trans.layers.0                         TransformerLayer
  mtt::feature::gnn_layers.1.trans.layers.0.attention               AttentionBlock
  mtt::feature::gnn_layers.1.trans.layers.0.attention.input_linear  Linear
  mtt::feature::gnn_layers.1.trans.layers.0.attention.output_linear  Linear
  mtt::feature::gnn_layers.1.trans.layers.0.norm_attention          RMSNorm
  mtt::feature::gnn_layers.1.trans.layers.0.norm_mlp                RMSNorm
  mtt::feature::gnn_layers.1.trans.layers.0.mlp                     FeedForward
  mtt::feature::gnn_layers.1.trans.layers.0.mlp.w_in                Linear
  mtt::feature::gnn_layers.1.trans.layers.0.mlp.w_out               Linear
  mtt::feature::gnn_layers.1.trans.layers.0.mlp.activation          Identity
  mtt::feature::gnn_layers.1.trans.layers.0.center_contraction      Linear
  mtt::feature::gnn_layers.1.trans.layers.0.center_expansion        Linear
  mtt::feature::gnn_layers.1.trans.layers.0.norm_center_features    RMSNorm
  mtt::feature::gnn_layers.1.trans.layers.0.center_mlp              FeedForward
  mtt::feature::gnn_layers.1.trans.layers.0.center_mlp.w_in         Linear
  mtt::feature::gnn_layers.1.trans.layers.0.center_mlp.w_out        Linear
  mtt::feature::gnn_layers.1.trans.layers.0.center_mlp.activation   Identity
  mtt::feature::gnn_layers.1.trans.layers.1                         TransformerLayer
  mtt::feature::gnn_layers.1.trans.layers.1.attention               AttentionBlock
  mtt::feature::gnn_layers.1.trans.layers.1.attention.input_linear  Linear
  mtt::feature::gnn_layers.1.trans.layers.1.attention.output_linear  Linear
  mtt::feature::gnn_layers.1.trans.layers.1.norm_attention          RMSNorm
  mtt::feature::gnn_layers.1.trans.layers.1.norm_mlp                RMSNorm
  mtt::feature::gnn_layers.1.trans.layers.1.mlp                     FeedForward
  mtt::feature::gnn_layers.1.trans.layers.1.mlp.w_in                Linear
  mtt::feature::gnn_layers.1.trans.layers.1.mlp.w_out               Linear
  mtt::feature::gnn_layers.1.trans.layers.1.mlp.activation          Identity
  mtt::feature::gnn_layers.1.trans.layers.1.center_contraction      Linear
  mtt::feature::gnn_layers.1.trans.layers.1.center_expansion        Linear
  mtt::feature::gnn_layers.1.trans.layers.1.norm_center_features    RMSNorm
  mtt::feature::gnn_layers.1.trans.layers.1.center_mlp              FeedForward
  mtt::feature::gnn_layers.1.trans.layers.1.center_mlp.w_in         Linear
  mtt::feature::gnn_layers.1.trans.layers.1.center_mlp.w_out        Linear
  mtt::feature::gnn_layers.1.trans.layers.1.center_mlp.activation   Identity
  mtt::feature::gnn_layers.1.edge_embedder                          Linear
  mtt::feature::gnn_layers.1.compress                               Sequential
  mtt::feature::gnn_layers.1.compress.0                             Linear
  mtt::feature::gnn_layers.1.compress.1                             SiLU
  mtt::feature::gnn_layers.1.compress.2                             Linear
  mtt::feature::gnn_layers.1.neighbor_embedder                      Embedding
  mtt::feature::combination_norms                                   ModuleList
  mtt::feature::combination_norms.0                                 LayerNorm
  mtt::feature::combination_norms.1                                 LayerNorm
  mtt::feature::combination_mlps                                    ModuleList
  mtt::feature::combination_mlps.0                                  Sequential
  mtt::feature::combination_mlps.0.0                                Linear
  mtt::feature::combination_mlps.0.1                                SiLU
  mtt::feature::combination_mlps.0.2                                Linear
  mtt::feature::combination_mlps.1                                  Sequential
  mtt::feature::combination_mlps.1.0                                Linear
  mtt::feature::combination_mlps.1.1                                SiLU
  mtt::feature::combination_mlps.1.2                                Linear
  mtt::feature::node_embedders                                      ModuleList
  mtt::feature::node_embedders.0                                    Embedding
  mtt::feature::edge_embedder                                       Embedding
  mtt::feature::node_heads                                          ModuleDict
  mtt::feature::node_heads.energy                                   ModuleList
  mtt::feature::node_heads.energy.0                                 Sequential
  mtt::feature::node_heads.energy.0.0                               Linear
  mtt::feature::node_heads.energy.0.1                               SiLU
  mtt::feature::node_heads.energy.0.2                               Linear
  mtt::feature::node_heads.energy.0.3                               SiLU
  mtt::feature::edge_heads                                          ModuleDict
  mtt::feature::edge_heads.energy                                   ModuleList
  mtt::feature::edge_heads.energy.0                                 Sequential
  mtt::feature::edge_heads.energy.0.0                               Linear
  mtt::feature::edge_heads.energy.0.1                               SiLU
  mtt::feature::edge_heads.energy.0.2                               Linear
  mtt::feature::edge_heads.energy.0.3                               SiLU
  mtt::feature::node_last_layers                                    ModuleDict
  mtt::feature::node_last_layers.energy                             ModuleList
  mtt::feature::node_last_layers.energy.0                           ModuleDict
  mtt::feature::node_last_layers.energy.0.energy___0                Linear
  mtt::feature::edge_last_layers                                    ModuleDict
  mtt::feature::edge_last_layers.energy                             ModuleList
  mtt::feature::edge_last_layers.energy.0                           ModuleDict
  mtt::feature::edge_last_layers.energy.0.energy___0                Linear
  mtt::feature::gnn_layers_post_mp_node                             ModuleList
  mtt::feature::gnn_layers_post_mp_node.0                           Identity
  mtt::feature::gnn_layers_post_mp_node.1                           Identity
  mtt::feature::gnn_layers_post_mp_edge                             ModuleList
  mtt::feature::gnn_layers_post_mp_edge.0                           Identity
  mtt::feature::gnn_layers_post_mp_edge.1                           Identity
  mtt::feature::node_backbone                                       ModuleList
  mtt::feature::node_backbone.0                                     Identity
  mtt::feature::edge_backbone                                       ModuleList
  mtt::feature::edge_backbone.0                                     Identity

Tip

Modules that return a tuple of (node_features, edge_features) — such as CartesianTransformer and TransformerLayer — require a _node or _edge suffix to select one element of the tuple:

"mtt::feature::gnn_layers.0_node"   # node output of GNN layer 0
"mtt::feature::gnn_layers.0_edge"   # edge output of GNN layer 0

The suffix is only needed when no module with that exact name exists; an AttributeError is raised if neither the exact path nor the suffix-stripped variant can be found.
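
The lookup rule in this tip can be sketched as follows. The function resolve_capture_path and the hand-written set of module names are hypothetical; this illustrates the described behaviour and is not the code PET uses internally.

```python
PREFIX = "mtt::feature::"


def resolve_capture_path(key, module_names):
    """Return (module_path, selector) for a capture key.

    selector is None for an exact match, or "node"/"edge" when a suffix
    was stripped to find the module. Hypothetical helper for illustration.
    """
    if not key.startswith(PREFIX):
        raise ValueError(f"{key!r} does not start with {PREFIX!r}")
    path = key[len(PREFIX):]

    # 1. An exact module name always wins.
    if path in module_names:
        return path, None

    # 2. Otherwise try stripping a _node or _edge suffix.
    for suffix in ("_node", "_edge"):
        if path.endswith(suffix):
            stripped = path[: -len(suffix)]
            if stripped in module_names:
                return stripped, suffix[1:]

    # 3. Neither variant exists.
    raise AttributeError(f"no module found for {key!r}")


# A few names copied from the listing above.
names = {"edge_embedder", "gnn_layers.0", "gnn_layers.0.edge_embedder"}

exact = resolve_capture_path("mtt::feature::edge_embedder", names)
suffixed = resolve_capture_path("mtt::feature::gnn_layers.0_node", names)
print(exact, suffixed)
```

Checking the exact name first means a module whose own name happens to end in _node or _edge would still be reachable without ambiguity.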

5. Deep-dive into a GNN layer

We can hook any intermediate sub-module, including layers inside the first CartesianTransformer. Below we capture six tensors from different depths of the first GNN layer to see how the representations evolve from raw embeddings through the transformer stack.

outputs = {
    # Initial edge-type embedding (before the GNN)
    "mtt::feature::edge_embedder": ModelOutput(per_atom=True),
    # Edge embedding re-computed inside the first CartesianTransformer
    "mtt::feature::gnn_layers.0.edge_embedder": ModelOutput(per_atom=True),
    # Node and edge output of the first TransformerLayer
    "mtt::feature::gnn_layers.0.trans.layers.0_node": ModelOutput(per_atom=True),
    "mtt::feature::gnn_layers.0.trans.layers.0_edge": ModelOutput(per_atom=True),
    # MLP sub-module inside the same TransformerLayer (edge-like)
    "mtt::feature::gnn_layers.0.trans.layers.0.mlp": ModelOutput(per_atom=True),
    # Full node output of the first CartesianTransformer
    "mtt::feature::gnn_layers.0_node": ModelOutput(per_atom=True),
}
predictions = model(systems, outputs)

Print the shape and nature (node/edge) of each captured tensor.

for key, tmap in predictions.items():
    block = tmap.block()
    kind = "node" if block.samples.names == ["system", "atom"] else "edge"
    print(
        f"{key[len('mtt::feature::') :]:45s}  {str(block.values.shape):20s}  ({kind})"
    )
edge_embedder                                  torch.Size([6, 128])  (edge)
gnn_layers.0.edge_embedder                     torch.Size([6, 128])  (edge)
gnn_layers.0.trans.layers.0.mlp                torch.Size([6, 128])  (edge)
gnn_layers.0.trans.layers.0_node               torch.Size([3, 256])  (node)
gnn_layers.0.trans.layers.0_edge               torch.Size([6, 128])  (edge)
gnn_layers.0_node                              torch.Size([3, 256])  (node)
