import yaml

import torch


class SumMod(torch.nn.Module):
    """Reduce an entire input to a single sum.

    The class defines no ``__init__`` and registers no parameters or buffers.
    Its only behavior is ``forward``, which delegates to ``torch.sum`` over
    every element of the input.
    """

    # NOTE(review): ``yaml`` is imported at the top of this file but not used
    # by anything visible here -- confirm before removing.

    def forward(self, inp):
        """Return the sum of all elements of ``inp``.

        Args:
            inp: Tensor to reduce; no ``dim`` is given, so every element
                is summed.

        Returns:
            The result of ``torch.sum(inp)`` (a 0-dim tensor for tensor input).
        """
        total = torch.sum(inp)
        return total