Output gradient
- metatrain.utils.output_gradient.compute_gradient(target: Tensor, inputs: List[Tensor], is_training: bool) → List[Tensor]
Calculates the gradient of a target tensor with respect to a list of input tensors.
target must be a single torch.Tensor object. If target contains multiple values, the gradient will be calculated with respect to the sum of all values.
- Parameters:
target (Tensor) – The tensor for which the gradient is to be computed.
inputs (List[Tensor]) – A list of tensors with respect to which the gradient is computed.
is_training (bool) – A boolean indicating whether the model is in training mode. If True, the computation graph is retained for further gradient computations. If False, the graph is not retained, which saves memory.
- Returns:
A list of tensors representing the gradients of the target with respect to each input.
- Return type:
List[Tensor]
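To make the documented semantics concrete, here is a minimal sketch of a function with the same signature, written with plain PyTorch autograd. The name compute_gradient_sketch is hypothetical, and this is not necessarily how metatrain implements compute_gradient. Seeding torch.autograd.grad with torch.ones_like(target) yields the gradient of the sum of all target values, and passing retain_graph=is_training keeps the computation graph alive only in training mode.

```python
from typing import List

import torch


def compute_gradient_sketch(
    target: torch.Tensor, inputs: List[torch.Tensor], is_training: bool
) -> List[torch.Tensor]:
    # Hypothetical re-implementation for illustration only.
    # Using ones as the upstream gradient is equivalent to differentiating
    # target.sum(), so a multi-valued target is reduced by summation.
    grad_outputs = [torch.ones_like(target)]
    gradients = torch.autograd.grad(
        outputs=[target],
        inputs=inputs,
        grad_outputs=grad_outputs,
        # Keep the graph only when training, so that a later backward pass
        # (e.g. through a loss on target) still works; otherwise free it.
        retain_graph=is_training,
    )
    return list(gradients)


# Usage: toy per-row "energies" of shape (2,), differentiated w.r.t. positions.
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 2.0, 0.0]], requires_grad=True)
energies = (positions**2).sum(dim=1)
(grad,) = compute_gradient_sketch(energies, [positions], is_training=True)
# d/dx sum(x^2) = 2x, so grad equals 2 * positions.
print(torch.allclose(grad, 2 * positions.detach()))

# The graph was retained, so backpropagating through energies still works.
energies.sum().backward()
```

With is_training=False, a later backward call through the same graph raises a RuntimeError because autograd has already freed the intermediate buffers; that is the memory saving described for the is_training parameter.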