TensorBlock
-
using metatensor_torch::TorchTensorBlock = torch::intrusive_ptr<TensorBlockHolder>
TorchScript will always manipulate TensorBlockHolder through a torch::intrusive_ptr.
-
class TensorBlockHolder : public CustomClassHolder
Wrapper around metatensor::TensorBlock for integration with TorchScript.
Python/TorchScript code will typically manipulate torch::intrusive_ptr<TensorBlockHolder> (i.e. TorchTensorBlock) instead of instances of TensorBlockHolder.
Public Functions
-
TensorBlockHolder(torch::Tensor data, TorchLabels samples, std::vector<TorchLabels> components, TorchLabels properties)
Create a new TensorBlockHolder with the given data and metadata.
-
TensorBlockHolder(metatensor::TensorBlock block, torch::IValue parent)
Create a torch TensorBlockHolder from a pre-existing metatensor::TensorBlock.
If the block is a view inside another TorchTensorBlock or TorchTensorMap, then parent should point to the corresponding object, making sure a reference to it is kept around.
-
TorchTensorBlock copy() const
Make a copy of this TensorBlockHolder, including all the data contained inside.
-
TorchTensorBlock to(torch::Device device)
Return a new TorchTensorBlock where all data and related labels are on the requested device.
-
torch::Tensor values()
Get a view of the values in this block.
-
TorchLabels labels(uintptr_t axis) const
Get the labels in this block associated with either "values" or one gradient (by setting values_gradients to the gradient parameter), in the given axis.
-
inline TorchLabels samples() const
Access the sample Labels for this block.
The entries in these labels describe the first dimension of the values() array.
-
inline std::vector<TorchLabels> components() const
Access the component Labels for this block.
The entries in these labels describe intermediate dimensions of the values() array.
-
inline TorchLabels properties() const
Access the property Labels for this block.
The entries in these labels describe the last dimension of the values() array. The properties are guaranteed to be the same for values and gradients in the same block.
-
void add_gradient(const std::string &parameter, TorchTensorBlock gradient)
Add a set of gradients with respect to parameter in this block.
Parameters:
parameter – add gradients with respect to this parameter (e.g. "positions", "cell", …)
gradient – a TorchTensorBlock whose values contain the gradients with respect to the parameter. The labels of the gradient TorchTensorBlock should be organized as follows: its samples must contain "sample" as the first label, which establishes a correspondence with the samples of the original TorchTensorBlock; its components must contain at least the same components as the original TorchTensorBlock, with any additional component coming before those; its properties must match those of the original TorchTensorBlock.
-
inline std::vector<std::string> gradients_list() const
Get a list of all gradients defined in this block.
-
bool has_gradient(const std::string &parameter) const
Check if a given gradient is defined in this TensorBlock.
-
std::string repr() const
Implementation of repr/__str__ for Python.
-
inline const metatensor::TensorBlock &as_metatensor() const
Get the underlying metatensor TensorBlock.
Public Static Functions
-
static TorchTensorBlock gradient(TorchTensorBlock self, const std::string &parameter)
Get a gradient from this TensorBlock.
-
static std::vector<std::tuple<std::string, TorchTensorBlock>> gradients(TorchTensorBlock self)
Get all gradients and associated parameters in this block.
-
TensorBlockHolder(torch::Tensor data, TorchLabels samples, std::vector<TorchLabels> components, TorchLabels properties)