# Owner(s): ["module: onnx"]

import torch


# Autograd function that is a replica of the autograd function in
# test_utility_funs.py (test_autograd_module_name)
class CustomFunction(torch.autograd.Function):
    """Custom autograd function that clamps its input at zero.

    The forward pass returns the input with every element below zero
    replaced by zero; the backward pass lets the incoming gradient through
    unchanged except at positions where the forward input was negative,
    where the gradient is set to zero.
    """

    @staticmethod
    def forward(ctx, input):
        """Stash ``input`` on ``ctx`` and return it clamped to a minimum of 0."""
        # The original input is needed in backward() to build the mask.
        ctx.save_for_backward(input)
        clamped = input.clamp(min=0)
        return clamped

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` zeroed where the input was negative."""
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        # Work on a clone so the incoming gradient tensor is not modified.
        masked_grad = grad_output.clone()
        masked_grad[negative_mask] = 0
        return masked_grad