xref: /aosp_15_r20/external/pytorch/ios/TestApp/TestAppTests/TestLiteInterpreter.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <XCTest/XCTest.h>
2
3#include <torch/csrc/jit/mobile/import.h>
4#include <torch/csrc/jit/mobile/module.h>
5#include <torch/script.h>
6
/// XCTest suite that loads lite-interpreter (`.ptl`) models from the test
/// bundle through `torch::jit::_load_for_mobile` and checks that a forward pass
/// can be executed on each of them.
@interface TestAppTests : XCTestCase

@end

@implementation TestAppTests

#pragma mark - Helpers

/// Looks up `<modelName>.ptl` in the bundle that contains this test class.
///
/// @param modelName Resource name without the `.ptl` extension.
/// @return The resource path, or nil when the bundle has no such resource.
- (NSString*)bundledPathForModel:(NSString*)modelName {
  NSBundle* testBundle = [NSBundle bundleForClass:[self class]];
  return [testBundle pathForResource:modelName ofType:@"ptl"];
}

/// Loads the model at `modelPath` and runs a single forward pass.
///
/// When the module exposes `get_all_bundled_inputs`, the first bundled input is
/// used; otherwise `forward` is invoked with no arguments. A failure is recorded
/// if the forward pass throws, if the path is nil, or if the bundled-input list
/// is empty.
///
/// @param modelPath File-system path of a `.ptl` model.
- (void)runModel:(NSString*)modelPath {
  // Messaging nil would yield a NULL C string, which must never reach the loader.
  if (modelPath == nil) {
    XCTFail(@"runModel: was called without a model path.");
    return;
  }

  c10::InferenceMode mode;
  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);

  if (!module.find_method("get_all_bundled_inputs")) {
    XCTAssertNoThrow(module.forward({}));
    return;
  }

  c10::IValue bundledInputs = module.run_method("get_all_bundled_inputs");
  c10::List<at::IValue> allInputs = bundledInputs.toList();
  // Indexing an empty list would be undefined behaviour; report it instead.
  if (allInputs.empty()) {
    XCTFail(@"get_all_bundled_inputs returned an empty list for %@", modelPath);
    return;
  }

  // run with the first bundled input
  const at::IValue firstBundle = allInputs.get(0);
  std::vector<at::IValue> firstInput = firstBundle.toTupleRef().elements();
  XCTAssertNoThrow(module.forward(firstInput));
}

/// Runs the bundled model called `modelName`, then repeats the lookup and run
/// for the "on-the-fly" variant of the same model.
///
/// XCTAssertNotNil only records a failure and lets the method continue, so each
/// run is additionally guarded to keep a missing resource from reaching
/// `runModel:`.
///
/// @param modelName Resource name without the `.ptl` extension.
- (void)testModel:(NSString*)modelName {
  NSString* modelPath = [self bundledPathForModel:modelName];
  XCTAssertNotNil(modelPath, @"Model not found. See https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test.");
  if (modelPath != nil) {
    [self runModel:modelPath];
  }

  // model generated on the fly
  // NOTE(review): the format below adds no suffix, so this resolves to the same
  // resource as the lookup above and the model is run twice -- confirm intent.
  NSString* onTheFlyModelName = [NSString stringWithFormat:@"%@", modelName];
  NSString* onTheFlyModelPath = [self bundledPathForModel:onTheFlyModelName];
  XCTAssertNotNil(onTheFlyModelPath, @"On-the-fly model not found. Follow https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test to generate them and run the setup.rb script again.");
  if (onTheFlyModelPath != nil) {
    [self runModel:onTheFlyModelPath];
  }
}

#pragma mark - Standalone tests

/// Loads `model_coreml.ptl`, feeds it a 1x3x224x224 float tensor of ones and
/// checks that the output tensor holds exactly 1000 elements.
- (void)testCoreML {
  NSString* modelPath = [self bundledPathForModel:@"model_coreml"];
  XCTAssertNotNil(modelPath, @"model_coreml.ptl was not found in the test bundle.");
  if (modelPath == nil) {
    // Bail out: a nil path would hand a NULL C string to the loader.
    return;
  }

  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
  c10::InferenceMode mode;
  auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
  auto outputTensor = module.forward({input}).toTensor();
  // XCTAssertEqual reports the actual element count when the check fails.
  XCTAssertEqual(outputTensor.numel(), static_cast<int64_t>(1000));
}

// TODO remove this once updated test script
- (void)testLiteInterpreter {
  XCTAssertTrue(true);
}

#pragma mark - Per-model tests

- (void)testMobileNetV2 {
  [self testModel:@"mobilenet_v2"];
}

- (void)testPointwiseOps {
  [self testModel:@"pointwise_ops"];
}

- (void)testReductionOps {
  [self testModel:@"reduction_ops"];
}

- (void)testComparisonOps {
  [self testModel:@"comparison_ops"];
}

- (void)testOtherMathOps {
  [self testModel:@"other_math_ops"];
}

- (void)testSpectralOps {
  [self testModel:@"spectral_ops"];
}

- (void)testBlasLapackOps {
  [self testModel:@"blas_lapack_ops"];
}

- (void)testSamplingOps {
  [self testModel:@"sampling_ops"];
}

- (void)testTensorOps {
  [self testModel:@"tensor_general_ops"];
}

- (void)testTensorCreationOps {
  [self testModel:@"tensor_creation_ops"];
}

- (void)testTensorIndexingOps {
  [self testModel:@"tensor_indexing_ops"];
}

- (void)testTensorTypingOps {
  [self testModel:@"tensor_typing_ops"];
}

- (void)testTensorViewOps {
  [self testModel:@"tensor_view_ops"];
}

- (void)testConvolutionOps {
  [self testModel:@"convolution_ops"];
}

- (void)testPoolingOps {
  [self testModel:@"pooling_ops"];
}

- (void)testPaddingOps {
  [self testModel:@"padding_ops"];
}

- (void)testActivationOps {
  [self testModel:@"activation_ops"];
}

- (void)testNormalizationOps {
  [self testModel:@"normalization_ops"];
}

- (void)testRecurrentOps {
  [self testModel:@"recurrent_ops"];
}

- (void)testTransformerOps {
  [self testModel:@"transformer_ops"];
}

- (void)testLinearOps {
  [self testModel:@"linear_ops"];
}

- (void)testDropoutOps {
  [self testModel:@"dropout_ops"];
}

- (void)testSparseOps {
  [self testModel:@"sparse_ops"];
}

- (void)testDistanceFunctionOps {
  [self testModel:@"distance_function_ops"];
}

- (void)testLossFunctionOps {
  [self testModel:@"loss_function_ops"];
}

- (void)testVisionFunctionOps {
  [self testModel:@"vision_function_ops"];
}

- (void)testShuffleOps {
  [self testModel:@"shuffle_ops"];
}

- (void)testNNUtilsOps {
  [self testModel:@"nn_utils_ops"];
}

- (void)testQuantOps {
  [self testModel:@"general_quant_ops"];
}

- (void)testDynamicQuantOps {
  [self testModel:@"dynamic_quant_ops"];
}

- (void)testStaticQuantOps {
  [self testModel:@"static_quant_ops"];
}

- (void)testFusedQuantOps {
  [self testModel:@"fused_quant_ops"];
}

- (void)testTorchScriptBuiltinQuantOps {
  [self testModel:@"torchscript_builtin_ops"];
}

- (void)testTorchScriptCollectionQuantOps {
  [self testModel:@"torchscript_collection_ops"];
}

@end
198