Output gradient
- metatrain.utils.output_gradient.compute_gradient(target: Tensor, inputs: List[Tensor], is_training: bool) → List[Tensor]
Calculates the gradient of a target tensor with respect to a list of input tensors.
target must be a single torch.Tensor object. If target contains multiple values, the gradient of the sum of all values is calculated.
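The sum behavior described above can be illustrated with plain PyTorch. This is a minimal sketch, not metatrain's implementation: `gradient_of_sum` is a hypothetical stand-in that mirrors the documented signature, and mapping `is_training` to the `retain_graph` and `create_graph` arguments of `torch.autograd.grad` is an assumption made for illustration only.

```python
from typing import List

import torch


def gradient_of_sum(
    target: torch.Tensor, inputs: List[torch.Tensor], is_training: bool
) -> List[torch.Tensor]:
    # Hypothetical stand-in for illustration; NOT metatrain's code.
    # A multi-valued target is reduced to a scalar by summing before
    # differentiating, so one gradient per input tensor is returned.
    grads = torch.autograd.grad(
        outputs=target.sum(),
        inputs=inputs,
        # Assumption: during training the graph is kept so the gradients
        # themselves can be differentiated later.
        retain_graph=is_training,
        create_graph=is_training,
    )
    return list(grads)


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0], requires_grad=True)

# A target with five values: x**2 (three values) and 3*y (two values).
target = torch.cat([x**2, 3.0 * y])

grad_x, grad_y = gradient_of_sum(target, [x, y], is_training=False)
# grad_x is d(sum(target))/dx = 2*x, and grad_y is d(sum(target))/dy = 3.
```

Each returned gradient has the same shape as the corresponding input tensor, because the target is reduced to a single scalar before differentiation.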