# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import executorch.exir as exir
import executorch.exir.tests.models as models
import torch
from parameterized import parameterized


class TestCapture(unittest.TestCase):
    """Checks that ``exir.capture`` output matches the eager module output."""

    # pyre-ignore
    @parameterized.expand(models.MODELS)
    def test_module_call(self, model_name: str, model: torch.nn.Module) -> None:
        """Compare the eager forward pass against the captured program.

        The test is expanded once per entry of ``models.MODELS``; each entry
        is expected to unpack into ``(model_name, model)``.
        NOTE(review): assumes every model returns a single tensor, since the
        result is passed straight to ``torch.allclose`` -- confirm.
        """
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        sample_inputs = model.get_random_inputs()

        # Reference result from running the module eagerly.
        eager_output = model(*sample_inputs)

        # TODO(ycao): Replace it with capture_multiple
        capture_config = exir.CaptureConfig()
        captured = exir.capture(model, sample_inputs, capture_config)

        # Run the captured program on the same inputs and compare using
        # torch.allclose's default tolerances.
        captured_output = captured(*sample_inputs)
        outputs_match = torch.allclose(eager_output, captured_output)
        self.assertTrue(outputs_match)