TensorBlock#
-
using metatensor_torch::TorchTensorBlock = torch::intrusive_ptr<TensorBlockHolder>#
TorchScript will always manipulate
TensorBlockHolder
through a torch::intrusive_ptr
-
class TensorBlockHolder : public CustomClassHolder#
Wrapper around
metatensor::TensorBlock
for integration with TorchScript. Python/TorchScript code will typically manipulate
torch::intrusive_ptr<TensorBlockHolder>
(i.e. TorchTensorBlock
) instead of instances of TensorBlockHolder
.
Public Functions
-
TensorBlockHolder(torch::Tensor data, TorchLabels samples, std::vector<TorchLabels> components, TorchLabels properties)#
Create a new TensorBlockHolder with the given data and metadata.
-
TensorBlockHolder(metatensor::TensorBlock block, torch::IValue parent)#
Create a torch TensorBlockHolder from a pre-existing
metatensor::TensorBlock
. If the block is a view inside another
TorchTensorBlock
or TorchTensorMap
, then parent
should point to the corresponding object, making sure a reference to it is kept around.
-
TorchTensorBlock copy() const#
Make a copy of this
TensorBlockHolder
, including all the data contained inside
-
torch::Tensor values() const#
Get a view of the values in this block.
-
TorchLabels labels(uintptr_t axis) const#
Get the labels in this block associated with either
"values"
or one gradient (by setting values_gradients
to the gradient parameter); in the given axis
.
-
inline TorchLabels samples() const#
Access the sample
Labels
for this block. The entries in these labels describe the first dimension of the
values()
array.
-
inline std::vector<TorchLabels> components() const#
Access the component
Labels
for this block. The entries in these labels describe intermediate dimensions of the
values()
array.
-
inline TorchLabels properties() const#
Access the property
Labels
for this block. The entries in these labels describe the last dimension of the
values()
array. The properties are guaranteed to be the same for values and gradients in the same block.
-
void add_gradient(const std::string &parameter, TorchTensorBlock gradient)#
Add a set of gradients with respect to
parameter
in this block.
Parameters:
parameter – add gradients with respect to this
parameter
(e.g. "positions"
, "cell"
, …)
gradient – a
TorchTensorBlock
whose values contain the gradients with respect to the parameter
. The labels of the gradient TorchTensorBlock
should be organized as follows: its samples
must contain "sample"
as the first label, which establishes a correspondence with the samples
of the original TorchTensorBlock
; its components must contain at least the same components as the original TorchTensorBlock
, with any additional component coming before those; its properties must match those of the original TorchTensorBlock
.
-
inline std::vector<std::string> gradients_list() const#
Get a list of all gradients defined in this block.
-
bool has_gradient(const std::string &parameter) const#
Check if a given gradient is defined in this TensorBlock.
-
inline torch::Device device() const#
Get the device for the values stored in this
TensorBlock
-
inline torch::Dtype scalar_type() const#
Get the dtype for the values stored in this
TensorBlock
-
TorchTensorBlock to(torch::optional<torch::Dtype> dtype = torch::nullopt, torch::optional<torch::Device> device = torch::nullopt) const#
Move all arrays in this block to the given
dtype
and device
.
-
TorchTensorBlock to_positional(torch::IValue positional_1, torch::IValue positional_2, torch::optional<torch::Dtype> dtype, torch::optional<torch::Device> device, torch::optional<std::string> arrays) const#
Wrapper of the
to
function to enable using it with positional parameters from Python; for example to(dtype)
, to(device)
, to(dtype, device=device)
, to(dtype, device)
, to(device, dtype)
, etc. arrays
is left as a keyword argument since it is mainly here for compatibility with the pure Python backend, and only "torch"
is supported.
-
std::string repr() const#
Implementation of __repr__/__str__ for Python.
-
inline const metatensor::TensorBlock &as_metatensor() const#
Get the underlying metatensor TensorBlock.
Public Static Functions
-
static TorchTensorBlock gradient(TorchTensorBlock self, const std::string &parameter)#
Get a gradient from this TensorBlock.
-
static std::vector<std::tuple<std::string, TorchTensorBlock>> gradients(TorchTensorBlock self)#
Get all gradients and associated parameters in this block.
-
TensorBlockHolder(torch::Tensor data, TorchLabels samples, std::vector<TorchLabels> components, TorchLabels properties)#