xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/static_module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Owner(s): ["module: unknown"]
3
4import torch
5
6
class StaticModule:
    """Python-side wrapper around the object returned by ``torch._C._jit_to_static_module``.

    Every method forwards its arguments unchanged to the wrapped object,
    which is stored on ``self.static_module``.
    """

    def __init__(self, scripted):
        """Build the wrapped static module from ``scripted``.

        If ``scripted`` exposes a ``_c`` attribute, that attribute is
        converted; otherwise ``scripted.graph`` is converted instead.
        """
        # NOTE(review): presumably ``_c`` is present on scripted nn.Modules and
        # absent on scripted functions (which only carry ``graph``) -- confirm.
        source = scripted._c if hasattr(scripted, "_c") else scripted.graph
        self.static_module = torch._C._jit_to_static_module(source)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped static module and return whatever it returns."""
        runner = self.static_module
        return runner(*args, **kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        """Forward to the wrapped module's ``benchmark``.

        The delegate's return value is intentionally not propagated; this
        method always returns ``None``.
        """
        backend = self.static_module
        backend.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        """Forward to the wrapped module's ``runAsync`` and return its result."""
        backend = self.static_module
        return backend.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        """Forward to the wrapped module's ``benchmark_individual_ops``.

        Returns the delegate's result unchanged.
        """
        backend = self.static_module
        metrics = backend.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )
        return metrics
28