Source code for metatensor.operations.one_hot

from . import _dispatch
from ._backend import Labels, torch_jit_script


@torch_jit_script
def one_hot(labels: Labels, dimension: Labels):
    """Build a one-hot encoding of a single column of a ``Labels`` object.

    The column to encode is chosen by the (single) name of ``dimension``, and
    the entries of ``dimension`` enumerate every value that column may take.
    Each entry of ``labels`` becomes one row of the result, with a 1 in the
    column whose index is the position of that entry's value inside
    ``dimension`` and 0 everywhere else.

    :param labels: ``Labels`` whose column named after ``dimension`` will be
        one-hot encoded.
    :param dimension: ``Labels`` with exactly one name. That name selects the
        column of ``labels`` to encode, and its entries list all the values
        the encoding can represent.

    :raises ValueError: if ``dimension`` does not have exactly one name, or if
        a value found in the selected column of ``labels`` is not an entry of
        ``dimension``.

    :return: a two-dimensional ``numpy.ndarray`` or ``torch.Tensor`` with one
        row per entry of ``labels`` and ``len(dimension)`` columns, holding 1 at
        the column matching each entry's value and 0 elsewhere.

    >>> import numpy as np
    >>> import metatensor
    >>> from metatensor import Labels
    >>> # Let's say we have 6 atoms, whose chemical identities
    >>> # are C, H, H, H, C, H:
    >>> original_labels = Labels(
    ...     names=["atom", "types"],
    ...     values=np.array([[0, 6], [1, 1], [2, 1], [3, 1], [4, 6], [5, 1]]),
    ... )
    >>> # Set up a Labels object with all possible elements,
    >>> # including, for example, also O:
    >>> possible_labels = Labels(names=["types"], values=np.array([[1], [6], [8]]))
    >>> # Get the one-hot encoded labels:
    >>> one_hot_encoding = metatensor.one_hot(original_labels, possible_labels)
    >>> print(one_hot_encoding)
    [[0 1 0]
     [1 0 0]
     [1 0 0]
     [1 0 0]
     [0 1 0]
     [1 0 0]]
    """
    n_names = len(dimension.names)
    if n_names != 1:
        raise ValueError(
            "only one label dimension can be extracted as one-hot "
            "encoding. The `dimension` labels contains "
            f"{n_names} names"
        )

    selected_name = dimension.names[0]

    # Allocated from `dimension.values` so the backend (numpy/torch) and
    # device follow the inputs; one slot per entry of `labels`.
    positions = _dispatch.zeros_like(dimension.values, [len(labels)])
    selected_column = labels.column(selected_name)

    n_entries = selected_column.shape[0]
    for row in range(n_entries):
        # Keep a length-1 slice (not a scalar) so it can be looked up as an
        # entry of the single-name `dimension` labels.
        value = selected_column[None, row]
        found = dimension.position(value)
        if found is None:
            raise ValueError(
                f"{selected_name}={value[0]} is present in the labels, but was not found in "
                "the dimension"
            )
        positions[row] = found

    # Row `k` of the identity matrix is the one-hot vector for position `k`,
    # so gathering rows by `positions` yields the full encoding.
    identity = _dispatch.eye_like(dimension.values, len(dimension))
    return identity[_dispatch.to_index_array(positions)]