# mypy: allow-untyped-defs
# Owner(s): ["module: unknown"]

import torch


class StaticModule:
    """Thin Python wrapper around a module produced by
    ``torch._C._jit_to_static_module``.

    Every method simply forwards to that underlying object; this class adds
    no logic of its own beyond choosing what to convert at construction time.
    """

    def __init__(self, scripted):
        """Convert ``scripted`` into a static module and keep it.

        Args:
            scripted: an object that either exposes ``_c`` (per the original
                author's note, a scripted ``nn.Module``), in which case
                ``scripted._c`` is converted, or otherwise exposes ``graph``,
                in which case ``scripted.graph`` is converted.
        """
        # Objects carrying ``_c`` are converted via that attribute; anything
        # else (presumably a scripted function — confirm) via its ``graph``.
        source = scripted._c if hasattr(scripted, "_c") else scripted.graph
        self.static_module = torch._C._jit_to_static_module(source)

    def __call__(self, *inputs, **named_inputs):
        """Invoke the underlying static module and return its result unchanged."""
        runtime = self.static_module
        return runtime(*inputs, **named_inputs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        """Forward to the underlying module's ``benchmark``.

        Whatever that call returns is discarded, so this method returns
        ``None``. (Presumably the underlying call reports timings itself —
        not visible from here.)
        """
        runtime = self.static_module
        runtime.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        """Forward to the underlying module's ``runAsync`` and return its result."""
        runtime = self.static_module
        return runtime.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        """Forward to the underlying module's ``benchmark_individual_ops``.

        Returns whatever the underlying call returns, unmodified.
        """
        result = self.static_module.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )
        return result