TensorBlock

using metatensor_torch::TorchTensorBlock = torch::intrusive_ptr<TensorBlockHolder>

TorchScript will always manipulate TensorBlockHolder through a torch::intrusive_ptr.

class TensorBlockHolder : public CustomClassHolder

Wrapper around metatensor::TensorBlock for integration with TorchScript.

Python/TorchScript code will typically manipulate torch::intrusive_ptr<TensorBlockHolder> (i.e. TorchTensorBlock) instead of instances of TensorBlockHolder.

Public Functions

TensorBlockHolder(torch::Tensor data, TorchLabels samples, std::vector<TorchLabels> components, TorchLabels properties)

Create a new TensorBlockHolder with the given data and metadata.

TensorBlockHolder(metatensor::TensorBlock block, torch::IValue parent)

Create a torch TensorBlockHolder from a pre-existing metatensor::TensorBlock.

If the block is a view inside another TorchTensorBlock or TorchTensorMap, then parent should point to the corresponding object, making sure a reference to it is kept around.
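The keep-alive role of parent can be illustrated with a small standalone sketch. It does not use metatensor_torch: ParentSketch, ViewSketch and make_view are hypothetical names, and a std::shared_ptr stands in for the reference that the real class stores in a torch::IValue.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for an object owning some data (for example a parent
// block or tensor map). Not part of metatensor_torch.
struct ParentSketch {
    std::vector<double> data;
};

// Hypothetical stand-in for a block that is only a view into its parent.
struct ViewSketch {
    const double* data;
    std::size_t size;
    // Owning reference to the parent: as long as the view exists, the parent
    // (and the memory `data` points into) stays alive.
    std::shared_ptr<ParentSketch> parent;
};

// Create a view over `size` elements of the parent, starting at `start`.
ViewSketch make_view(std::shared_ptr<ParentSketch> parent, std::size_t start, std::size_t size) {
    assert(start + size <= parent->data.size());
    const double* data = parent->data.data() + start;
    return ViewSketch{data, size, std::move(parent)};
}
```

With this layout, dropping every other reference to the parent does not invalidate the view, which is the guarantee the parent argument is meant to provide.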

TorchTensorBlock copy() const

Make a copy of this TensorBlockHolder, including all the data contained inside.

torch::Tensor values() const

Get a view of the values in this block.

TorchLabels labels(uintptr_t axis) const

Get the labels describing the given axis of the values in this block.

inline TorchLabels samples() const

Access the sample Labels for this block.

The entries in these labels describe the first dimension of the values() array.

inline int64_t len() const

Get the length of this block, i.e. the number of samples.

inline at::IntArrayRef shape() const

Get the shape of the values Tensor.

inline std::vector<TorchLabels> components() const

Access the component Labels for this block.

The entries in these labels describe intermediate dimensions of the values() array.

inline TorchLabels properties() const

Access the property Labels for this block.

The entries in these labels describe the last dimension of the values() array. The properties are guaranteed to be the same for values and gradients in the same block.
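The correspondence between the three kinds of labels and the dimensions of values() can be sketched as follows. This is a standalone illustration working on plain entry counts; expected_shape is a hypothetical helper, not a function of metatensor_torch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of metatensor_torch: compute the shape the
// values array is expected to have from the number of entries in each set of
// labels.
std::vector<int64_t> expected_shape(
    int64_t n_samples,
    const std::vector<int64_t>& n_components,
    int64_t n_properties
) {
    std::vector<int64_t> shape;
    shape.reserve(n_components.size() + 2);
    shape.push_back(n_samples);     // first dimension: samples
    for (int64_t n : n_components) {
        shape.push_back(n);         // intermediate dimensions: components
    }
    shape.push_back(n_properties);  // last dimension: properties
    return shape;
}
```

For example, 5 samples, a single set of 3 components and 7 properties correspond to values of shape [5, 3, 7]; the first entry of that shape is what len() reports.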

void add_gradient(const std::string &parameter, TorchTensorBlock gradient)

Add a set of gradients with respect to the given parameter to this block.

Parameters:
  • parameter – add gradients with respect to this parameter (e.g. "positions", "cell", …)

  • gradient – a TorchTensorBlock whose values contain the gradients with respect to the parameter. The labels of the gradient TorchTensorBlock should be organized as follows: its samples must contain "sample" as the first label, which establishes a correspondence with the samples of the original TorchTensorBlock; its components must contain at least the same components as the original TorchTensorBlock, with any additional component coming before those; its properties must match those of the original TorchTensorBlock.
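The three layout rules for the gradient labels can be written down as a small standalone check. This is only a sketch of the rules as stated above, using toy types: LabelsSketch, BlockMetadataSketch and gradient_layout_is_valid are hypothetical names, and the validation performed by the library itself may differ in its details.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Toy stand-in for a set of labels: dimension names plus integer entries.
struct LabelsSketch {
    std::vector<std::string> names;
    std::vector<std::vector<int32_t>> entries;
};

bool operator==(const LabelsSketch& a, const LabelsSketch& b) {
    return a.names == b.names && a.entries == b.entries;
}

// Toy stand-in for the metadata of a block (values are left out).
struct BlockMetadataSketch {
    LabelsSketch samples;
    std::vector<LabelsSketch> components;
    LabelsSketch properties;
};

// Check the layout rules for a gradient attached to `block`.
bool gradient_layout_is_valid(const BlockMetadataSketch& block, const BlockMetadataSketch& gradient) {
    // Rule 1: the first dimension of the gradient samples is "sample".
    if (gradient.samples.names.empty() || gradient.samples.names[0] != "sample") {
        return false;
    }
    // Rule 2: the components of the block are the trailing components of the
    // gradient; any additional component comes before them.
    if (gradient.components.size() < block.components.size()) {
        return false;
    }
    std::size_t offset = gradient.components.size() - block.components.size();
    for (std::size_t i = 0; i < block.components.size(); i++) {
        if (!(gradient.components[offset + i] == block.components[i])) {
            return false;
        }
    }
    // Rule 3: the properties are the same for values and gradient.
    return gradient.properties == block.properties;
}
```

A gradient with respect to "positions" would typically add one extra component (the Cartesian direction) in front of the components of the original block, which the second rule accepts.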

inline std::vector<std::string> gradients_list() const

Get a list of all gradients defined in this block.

bool has_gradient(const std::string &parameter) const

Check if a given gradient is defined in this TensorBlock.

inline torch::Device device() const

Get the device for the values stored in this TensorBlock.

inline torch::Dtype scalar_type() const

Get the dtype for the values stored in this TensorBlock.

TorchTensorBlock to(torch::optional<torch::Dtype> dtype = torch::nullopt, torch::optional<torch::Device> device = torch::nullopt) const

Move all arrays in this block to the given dtype and device.

TorchTensorBlock to_positional(torch::IValue positional_1, torch::IValue positional_2, torch::optional<torch::Dtype> dtype, torch::optional<torch::Device> device, torch::optional<std::string> arrays) const

Wrapper around the to function that enables calling it with positional parameters from Python; for example to(dtype), to(device), to(dtype, device=device), to(dtype, device), to(device, dtype), etc.

arrays is left as a keyword argument since it is mainly here for compatibility with the pure Python backend, and only "torch" is supported.
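One way such a wrapper could sort positional arguments into dtype and device is sketched below. This is not the library's implementation: all the names are hypothetical, a std::variant stands in for torch::IValue, and rejecting a dtype or device given twice is an assumption of this sketch.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>
#include <variant>

// Hypothetical stand-ins for torch::Dtype and torch::Device.
enum class DtypeSketch { Float32, Float64 };
struct DeviceSketch { std::string name; };

// A positional argument is either missing, a dtype or a device.
using PositionalSketch = std::variant<std::monostate, DtypeSketch, DeviceSketch>;

struct ResolvedSketch {
    std::optional<DtypeSketch> dtype;
    std::optional<std::string> device;
};

// Fold one positional argument into the resolved dtype/device pair.
static void absorb(const PositionalSketch& arg, ResolvedSketch& out) {
    if (const auto* dtype = std::get_if<DtypeSketch>(&arg)) {
        if (out.dtype) {
            throw std::invalid_argument("dtype given more than once");
        }
        out.dtype = *dtype;
    } else if (const auto* device = std::get_if<DeviceSketch>(&arg)) {
        if (out.device) {
            throw std::invalid_argument("device given more than once");
        }
        out.device = device->name;
    }
    // std::monostate: the positional argument was not provided
}

// Combine up to two positional arguments (in any order) with the keyword
// arguments into a single dtype/device pair.
ResolvedSketch resolve_to_arguments(
    const PositionalSketch& positional_1,
    const PositionalSketch& positional_2,
    std::optional<DtypeSketch> dtype,
    std::optional<std::string> device
) {
    ResolvedSketch out{dtype, device};
    absorb(positional_1, out);
    absorb(positional_2, out);
    return out;
}
```

Because each positional argument is classified by its type rather than its position, to(dtype, device) and to(device, dtype) resolve to the same pair in this sketch.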

std::string repr() const

Implementation of __repr__/__str__ for Python.

inline const metatensor::TensorBlock &as_metatensor() const

Get the underlying metatensor TensorBlock.

void save(const std::string &path) const

Serialize and save a TensorBlock to the given path.

torch::Tensor save_buffer() const

Serialize and save a TensorBlock to an in-memory buffer (represented as a torch::Tensor of bytes).

Public Static Functions

static TorchTensorBlock gradient(TorchTensorBlock self, const std::string &parameter)

Get a gradient from this TensorBlock.

static std::vector<std::tuple<std::string, TorchTensorBlock>> gradients(TorchTensorBlock self)

Get all gradients and associated parameters in this block.

static TorchTensorBlock load(const std::string &path)

Load a serialized TensorBlock from the given path.

static TorchTensorBlock load_buffer(torch::Tensor buffer)

Load a serialized TensorBlock from an in-memory buffer (represented as a torch::Tensor of bytes).