Source code for metatomic_ase._symmetry

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from ._calculator import MetatomicCalculator


import ase  # isort: skip
import ase.calculators.calculator  # isort: skip


[docs] class SymmetrizedCalculator(ase.calculators.calculator.Calculator): r""" Take a MetatomicCalculator and average its predictions to make it (approximately) equivariant. Only predictions for energy, forces and stress are supported. The default is to average over a quadrature of the orthogonal group O(3) composed this way: - Lebedev quadrature of the unit sphere (S^2) - Equispaced sampling of the unit circle (S^1) - Both proper and improper rotations are taken into account by including the inversion operation (if ``include_inversion=True``) :param base_calculator: the MetatomicCalculator to be symmetrized :param l_max: the maximum spherical harmonic degree that the model is expected to be able to represent. This is used to choose the quadrature order. If ``0``, no rotational averaging will be performed (it can be useful to average only over the space group, see ``apply_group_symmetry``). :param batch_size: number of rotated systems to evaluate at once. If ``None``, all systems will be evaluated at once (this can lead to high memory usage). :param include_inversion: if ``True``, the inversion operation will be included in the averaging. This is required to average over the full orthogonal group O(3). :param apply_space_group_symmetry: if ``True``, the results will be averaged over discrete space group of rotations for the input system. The group operations are computed with `spglib <https://github.com/spglib/spglib>`_, and the average is performed after the O(3) averaging (if any). This has no effect for non-periodic systems. :param store_rotational_std: if ``True``, the results will contain the standard deviation over the different rotations for each property (e.g., ``energy_std``). """ implemented_properties = ["energy", "energies", "forces", "stress", "stresses"] def __init__( self, base_calculator: MetatomicCalculator, *, l_max: int = 3, batch_size: Optional[int] = None, include_inversion: bool = True, apply_space_group_symmetry: bool = False, store_rotational_std: bool = False, ) -> None: try: from scipy.integrate import lebedev_rule # noqa: F401 except ImportError as e: raise ImportError( "scipy is required to use the `SymmetrizedCalculator`, please install " "it with `pip install scipy` or `conda install scipy`" ) from e super().__init__() self.base_calculator = base_calculator if l_max > 131: raise ValueError( f"l_max={l_max} is too large, the maximum supported value is 131" ) self.l_max = l_max self.include_inversion = include_inversion if l_max > 0: lebedev_order, n_inplane_rotations = _choose_quadrature(l_max) self.quadrature_rotations, self.quadrature_weights = _get_quadrature( lebedev_order, n_inplane_rotations, include_inversion ) else: # no quadrature if include_inversion: # identity and inversion self.quadrature_rotations = np.array([np.eye(3), -np.eye(3)]) self.quadrature_weights = np.array([0.5, 0.5]) else: # only the identity self.quadrature_rotations = np.array([np.eye(3)]) self.quadrature_weights = np.array([1.0]) self.batch_size = ( batch_size if batch_size is not None else len(self.quadrature_rotations) ) self.store_rotational_std = store_rotational_std self.apply_space_group_symmetry = apply_space_group_symmetry
[docs] def calculate( self, atoms: ase.Atoms, properties: List[str], system_changes: List[str] ) -> None: """ Perform the calculation for the given atoms and properties. :param atoms: the :py:class:`ase.Atoms` on which to perform the calculation :param properties: list of properties to compute, among ``energy``, ``forces``, and ``stress`` :param system_changes: list of changes to the system since the last call to ``calculate`` """ super().calculate(atoms, properties, system_changes) self.base_calculator.calculate(atoms, properties, system_changes) compute_forces_and_stresses = "forces" in properties or "stress" in properties per_atom = "energies" in properties if len(self.quadrature_rotations) > 0: rotated_atoms_list = _rotate_atoms(atoms, self.quadrature_rotations) batches = [ rotated_atoms_list[i : i + self.batch_size] for i in range(0, len(rotated_atoms_list), self.batch_size) ] results: Dict[str, np.ndarray] = {} for batch in batches: try: batch_results = self.base_calculator.compute_energy( batch, compute_forces_and_stresses, per_atom=per_atom, ) for key, value in batch_results.items(): results.setdefault(key, []) results[key].extend( [value] if isinstance(value, float) else value ) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( "Out of memory error encountered during rotational averaging. " "Please reduce the batch size or use lower rotational " "averaging parameters. This can be done by setting the " "`batch_size` and `l_max` parameters while initializing the " "calculator." ) from e self.results.update( _compute_rotational_average( results, self.quadrature_rotations, self.quadrature_weights, self.store_rotational_std, ) ) if self.apply_space_group_symmetry: # Apply the discrete space group of the system a posteriori Q_list, P_list = _get_group_operations(atoms) self.results.update(_average_over_group(self.results, Q_list, P_list))
def _choose_quadrature(L_max: int) -> Tuple[int, int]: """ Choose a Lebedev quadrature order and number of in-plane rotations to integrate spherical harmonics up to degree ``L_max``. :param L_max: maximum spherical harmonic degree :return: (lebedev_order, n_inplane_rotations) """ # fmt: off available = [ 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 35, 41, 47, 53, 59, 65, 71, 77, 83, 89, 95, 101, 107, 113, 119, 125, 131, ] # fmt: on # pick smallest order >= L_max n = min(o for o in available if o >= L_max) # minimal gamma count K = 2 * L_max + 1 return n, K def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: """ Create a list of copies of ``atoms``, rotated by each of the given ``rotations``. :param atoms: the :py:class:`ase.Atoms` to be rotated :param rotations: (N, 3, 3) array of orthogonal matrices :return: list of N :py:class:`ase.Atoms`, each rotated by the corresponding matrix """ rotated_atoms_list = [] has_cell = atoms.cell is not None and atoms.cell.rank > 0 for rot in rotations: new_atoms = atoms.copy() new_atoms.positions = new_atoms.positions @ rot.T if has_cell: new_atoms.set_cell( new_atoms.cell.array @ rot.T, scale_atoms=False, apply_constraint=False ) new_atoms.wrap() rotated_atoms_list.append(new_atoms) return rotated_atoms_list def _get_quadrature(lebedev_order: int, n_rotations: int, include_inversion: bool): """ Lebedev(S^2) x uniform angle quadrature on SO(3). If include_inversion=True, extend to O(3) by adding inversion * R. :param lebedev_order: order of the Lebedev quadrature on the unit sphere :param n_rotations: number of in-plane rotations per Lebedev node :param include_inversion: if ``True``, include the inversion operation in the quadrature :return: (N, 3, 3) array of orthogonal matrices, and (N,) array of weights associated to each matrix """ from scipy.integrate import lebedev_rule # Lebedev nodes (X: (3, M)) X, w = lebedev_rule(lebedev_order) # w sums to 4*pi x, y, z = X alpha = np.arctan2(y, x) # (M,) beta = np.arccos(z) # (M,) # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) Rot = _rotations_from_angles(alpha, beta, gamma) R_so3 = Rot.as_matrix() # (N, 3, 3) # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) if not include_inversion: return R_so3, w_so3 # Extend to O(3) by appending inversion * R P = -np.eye(3) R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) w_o3 = np.concatenate([0.5 * w_so3, 0.5 * w_so3], axis=0) return R_o3, w_o3 def _rotations_from_angles(alpha, beta, gamma): from scipy.spatial.transform import Rotation # Build all combinations (alpha_i, beta_i, gamma_j) A = np.repeat(alpha, gamma.size).reshape(-1, 1) # (N, 1) B = np.repeat(beta, gamma.size).reshape(-1, 1) # (N, 1) G = np.tile(gamma, alpha.size).reshape(-1, 1) # (N, 1) # Compose ZYZ rotations in SO(3) Rot = ( Rotation.from_euler("z", A) * Rotation.from_euler("y", B) * Rotation.from_euler("z", G) ) return Rot def _compute_rotational_average(results, rotations, weights, store_std): R = rotations B = R.shape[0] w = weights w = w / w.sum() def _wreshape(x): return w.reshape((B,) + (1,) * (x.ndim - 1)) def _wmean(x): return np.sum(_wreshape(x) * x, axis=0) def _wstd(x): mu = _wmean(x) return np.sqrt(np.sum(_wreshape(x) * (x - mu) ** 2, axis=0)) out = {} # Energy (B,) if "energy" in results: E = np.asarray(results["energy"], dtype=float) # (B,) out["energy"] = _wmean(E) # () if store_std: out["energy_rot_std"] = _wstd(E) # () if "energies" in results: E = np.asarray(results["energies"], dtype=float) # (B,N) out["energies"] = _wmean(E) # (N,) if store_std: out["energies_rot_std"] = _wstd(E) # (N,) # Forces (B,N,3) from rotated structures: back-rotate with F' R if "forces" in results: F = np.asarray(results["forces"], dtype=float) # (B,N,3) F_back = F @ R # F' R out["forces"] = _wmean(F_back) # (N,3) if store_std: out["forces_rot_std"] = _wstd(F_back) # (N,3) # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R if "stress" in results: S = np.asarray(results["stress"], dtype=float) # (B,3,3) RT = np.swapaxes(R, 1, 2) S_back = RT @ S @ R # R^T S' R out["stress"] = _wmean(S_back) # (3,3) if store_std: out["stress_rot_std"] = _wstd(S_back) # (3,3) if "stresses" in results: S = np.asarray(results["stresses"], dtype=float) # (B,N,3,3) RT = np.swapaxes(R, 1, 2) S_back = RT[:, None, :, :] @ S @ R[:, None, :, :] # R^T S' R out["stresses"] = _wmean(S_back) # (N,3,3) if store_std: out["stresses_rot_std"] = _wstd(S_back) # (N,3,3) return out def _get_group_operations( atoms: ase.Atoms, symprec: float = 1e-6, angle_tolerance: float = -1.0 ) -> Tuple[List[np.ndarray], List[np.ndarray]]: """ Extract point-group rotations Q_g (Cartesian, 3x3) and the corresponding atom-index permutations P_g (N x N) induced by the space-group operations. Returns Q_list, Cartesian rotation matrices of the point group, and P_list, permutation matrices mapping original indexing -> indexing after (R,t), :param atoms: input structure :param symprec: tolerance for symmetry finding :param angle_tolerance: tolerance for symmetry finding (in degrees). If less than 0, a value depending on ``symprec`` will be chosen automatically by spglib. :return: List of rotation matrices and permutation matrices. """ try: import spglib except ImportError as e: raise ImportError( "spglib is required to use the SymmetrizedCalculator with " "`apply_group_symmetry=True`. Please install it with " "`pip install spglib` or `conda install -c conda-forge spglib`" ) from e if not (atoms.pbc.all()): # No periodic boundary conditions: no symmetry return [], [] # Lattice with column vectors a1,a2,a3 (spglib expects (cell, frac, Z)) A = atoms.cell.array.T # (3,3) frac = atoms.get_scaled_positions() # (N,3) in [0,1) numbers = atoms.numbers N = len(atoms) data = spglib.get_symmetry_dataset( (atoms.cell.array, frac, numbers), symprec=symprec, angle_tolerance=angle_tolerance, _throw=True, ) if data is None: # No symmetry found return [], [] R_frac = data.rotations # (n_ops, 3,3), integer t_frac = data.translations # (n_ops, 3) Z = numbers # Match fractional coords modulo 1 within a tolerance, respecting chemical species def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): d = np.abs(frac_ref - x_new) # (N,3) d = np.minimum(d, 1.0 - d) # periodic distance # Mask by identical species mask = Z_ref == Z_i if not np.any(mask): raise RuntimeError("No matching species found while building permutation.") # Choose argmin over max-norm within species idx = np.where(mask)[0] j = idx[np.argmin(np.max(d[idx], axis=1))] # Sanity check if np.max(d[j]) > tol: raise RuntimeError( ( f"Sanity check failed in _match_index: max distance {np.max(d[j])} " f"exceeds tolerance {tol}." ) ) return j Q_list, P_list = [], [] seen = set() Ainv = np.linalg.inv(A) for Rf, tf in zip(R_frac, t_frac, strict=True): # Cartesian rotation: Q = A Rf A^{-1} Q = A @ Rf @ Ainv # Deduplicate rotations (point group) by rounding key = tuple(np.round(Q.flatten(), 12)) if key in seen: continue seen.add(key) # Build the permutation P from i to j P = np.zeros((N, N), dtype=int) new_frac = (frac @ Rf.T + tf) % 1.0 # images after (Rf,tf) for i in range(N): j = _match_index(new_frac[i], frac, Z, Z[i]) P[j, i] = 1 # column i maps to row j Q_list.append(Q.astype(float)) P_list.append(P) return Q_list, P_list def _average_over_group( results: dict, Q_list: List[np.ndarray], P_list: List[np.ndarray] ) -> dict: """ Apply the point-group projector in output space. :param results: Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or 'stress' (3,3). These are predictions for the current structure in the reference frame. :param Q_list: Rotation matrices of the point group, from :py:func:`_get_group_operations` :param P_list: Permutation matrices of the point group, from :py:func:`_get_group_operations` :return out: Projected quantities. """ m = len(Q_list) if m == 0: return results # nothing to do out = {} # Energy: unchanged by the projector (scalar) if "energy" in results: out["energy"] = float(results["energy"]) # Forces: (N,3) row-vectors; projector: (1/|G|) \sum_g P_g^T F Q_g if "forces" in results: F = np.asarray(results["forces"], float) if F.ndim != 2 or F.shape[1] != 3: raise ValueError(f"'forces' must be (N,3), got {F.shape}") acc = np.zeros_like(F) for Q, P in zip(Q_list, P_list, strict=True): acc += P.T @ (F @ Q) out["forces"] = acc / m # Stress: (3,3); projector: (1/|G|) \sum_g Q_g^T S Q_g if "stress" in results: S = np.asarray(results["stress"], float) if S.shape != (3, 3): raise ValueError(f"'stress' must be (3,3), got {S.shape}") # S = 0.5 * (S + S.T) # symmetrize just in case acc = np.zeros_like(S) for Q in Q_list: acc += Q.T @ S @ Q S_pg = acc / m out["stress"] = S_pg return out