from typing import List
from . import _dispatch
from ._classes import TensorBlock, TensorMap
from ._utils import _check_same_gradients_raise, _check_same_keys_raise
[docs]
def solve(X: TensorMap, Y: TensorMap) -> TensorMap:
    """Solve a linear system among two :py:class:`TensorMap`.

    Solve the linear equation set ``Y = X * w`` for the unknown ``w``, where
    ``Y``, ``X`` and ``w`` are all :py:class:`TensorMap`. The system is solved
    separately for each key, pairing the block of ``X`` with the block of ``Y``
    stored under the same key.

    ``Y`` and ``X`` must have the same ``keys``, and the values of every
    :py:class:`TensorBlock` of ``X`` must be a square 2D array. The blocks of
    ``Y`` do not have to be square (see the example below, where ``Y`` has a
    single property).

    :param X: a :py:class:`TensorMap` containing the "coefficient" matrices.
    :param Y: a :py:class:`TensorMap` containing the "dependent variable" values.

    :return: a :py:class:`TensorMap` with the same keys as ``X``, and where each
            :py:class:`TensorBlock` has: the ``samples`` equal to the
            ``properties`` of ``Y``; and the ``properties`` equal to the
            ``properties`` of ``X``. The values of each block are therefore
            the transpose of the solution of the corresponding linear system.

    :raises ValueError: if the values of any block of ``X`` are not a square
            2D array.

    >>> import numpy as np
    >>> import metatensor
    >>> from metatensor import TensorBlock, TensorMap, Labels
    >>> np.random.seed(0)
    >>> # We construct two independent variables, each sampled at 100 random points
    >>> X_values = np.random.rand(100, 2)
    >>> true_c = np.array([10.0, 42.0])
    >>> # Build a linear function of the two variables, with coefficients defined
    >>> # in the true_c array, and add some random noise
    >>> y_values = (X_values @ true_c + np.random.normal(size=(100,))).reshape((100, 1))
    >>> covariance = X_values.T @ X_values
    >>> y_regression = X_values.T @ y_values
    >>> X = TensorMap(
    ...     keys=Labels(names=["dummy"], values=np.array([[0]])),
    ...     blocks=[
    ...         TensorBlock(
    ...             samples=Labels.range("sample", 2),
    ...             components=[],
    ...             properties=Labels.range("properties_for_regression", 2),
    ...             values=covariance,
    ...         )
    ...     ],
    ... )
    >>> y = TensorMap(
    ...     keys=Labels(names=["dummy"], values=np.array([[0]])),
    ...     blocks=[
    ...         TensorBlock(
    ...             samples=Labels.range("sample", 2),
    ...             components=[],
    ...             properties=Labels.range("property_to_regress", 1),
    ...             values=y_regression,
    ...         )
    ...     ],
    ... )
    >>> c = metatensor.solve(X, y)
    >>> print(c.block())
    TensorBlock
        samples (1): ['property_to_regress']
        components (): []
        properties (2): ['properties_for_regression']
        gradients: None
    >>> # c should now be close to true_c
    >>> print(c.block().values)
    [[ 9.67680334 42.12534656]]
    """
    _check_same_keys_raise(X, Y, "solve")

    # Validate every block of X up-front, so that no block is solved if any
    # of the coefficient matrices has an invalid shape.
    for candidate in X.blocks():
        dims = candidate.values.shape
        if not (len(dims) == 2 and dims[0] == dims[1]):
            raise ValueError(
                "the values in each block of X should be a square 2D array"
            )

    # Solve block by block, pairing the blocks of X and Y through their key.
    solutions: List[TensorBlock] = []
    for key, coefficients in X.items():
        targets = Y.block(key)
        solutions.append(_solve_block(coefficients, targets))

    return TensorMap(X.keys, solutions)
def _solve_block(X: TensorBlock, Y: TensorBlock) -> TensorBlock:
    """
    Solve a linear system among two :py:class:`TensorBlock`.

    Solve the linear equation set ``X * w = Y`` for the unknown ``w``, where
    ``X``, ``w`` and ``Y`` are all :py:class:`TensorBlock`.

    The values of both blocks are reshaped to 2D arrays with the properties as
    columns, and the values of every gradient of ``X`` (and of the gradient of
    ``Y`` with the same parameter) are appended as additional rows before
    solving.

    :param X: block containing the "coefficient" matrix.
    :param Y: block containing the "dependent variable" values. It must have
        the same samples and the same components as ``X``, in the same order.

    :return: a new :py:class:`TensorBlock` whose values are the transpose of
        the solution, with the properties of ``Y`` as samples, no components,
        and the properties of ``X`` as properties. The block is created
        without passing any gradients.

    :raises ValueError: if ``X`` and ``Y`` do not have the same samples, or do
        not have the same components (in number or in content).
    """
    # TODO handle properties and samples not in the same order?
    if not X.samples == Y.samples:
        raise ValueError(
            "X and Y blocks in `solve` should have the same samples in the same order"
        )

    # `zip` stops at the shortest input, so a mismatch in the *number* of
    # components has to be detected explicitly or it would go unnoticed here.
    if len(X.components) != len(Y.components):
        raise ValueError(
            "X and Y blocks in `solve` should have the same components "
            "in the same order"
        )

    for X_component, Y_component in zip(X.components, Y.components):
        if X_component != Y_component:
            raise ValueError(
                "X and Y blocks in `solve` should have the same components "
                "in the same order"
            )

    # reshape components together with the samples
    X_n_properties = X.values.shape[-1]
    X_values = X.values.reshape(-1, X_n_properties)
    Y_n_properties = Y.values.shape[-1]
    Y_values = Y.values.reshape(-1, Y_n_properties)

    _check_same_gradients_raise(X, Y, fname="solve")

    # Stack the gradient values below the values, using the same row layout
    # for both sides of the equation.
    # NOTE(review): these extra rows give X_values more rows than columns;
    # confirm that `_dispatch.solve` accepts such non-square systems.
    for parameter, X_gradient in X.gradients():
        X_gradient_values = X_gradient.values.reshape(-1, X_n_properties)
        X_values = _dispatch.concatenate((X_values, X_gradient_values), axis=0)

        Y_gradient = Y.gradient(parameter)
        Y_gradient_values = Y_gradient.values.reshape(-1, Y_n_properties)
        Y_values = _dispatch.concatenate((Y_values, Y_gradient_values), axis=0)

    weights = _dispatch.solve(X_values, Y_values)

    # The solution has one row per property of X and one column per property
    # of Y; it is stored transposed, with the properties of Y as samples.
    return TensorBlock(
        values=weights.T,
        samples=Y.properties,
        components=[],
        properties=X.properties,
    )