#import <XCTest/XCTest.h>

#include <cstdint>
#include <vector>

#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/script.h>

/// Smoke tests that load the lite-interpreter (`.ptl`) models shipped in this
/// test bundle and run a forward pass on each of them.
@interface TestAppTests : XCTestCase

@end

@implementation TestAppTests

#pragma mark - Core ML

/// Loads `model_coreml.ptl` from the test bundle, runs it on a 1x3x224x224
/// float tensor of ones and checks that the output holds 1000 elements.
- (void)testCoreML {
  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model_coreml"
                                                                         ofType:@"ptl"];
  XCTAssertNotNil(modelPath, @"model_coreml.ptl not found in the test bundle.");
  if (modelPath == nil) {
    // With continueAfterFailure left at its default the assertion above only
    // records the failure; return so a NULL UTF8String never reaches the loader.
    return;
  }
  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
  c10::InferenceMode mode;
  auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
  auto outputTensor = module.forward({input}).toTensor();
  // XCTAssertEqual reports the actual element count when the check fails.
  XCTAssertEqual(outputTensor.numel(), static_cast<int64_t>(1000));
}

#pragma mark - Helpers

/// Runs both variants of a bundled model: the stored one (`<modelName>.ptl`)
/// and the one generated on the fly (`<modelName>_temp.ptl`).
///
/// XCTest only auto-discovers parameterless `test...` methods, so this helper
/// is invoked exclusively by the per-model tests below.
///
/// @param modelName Resource name of the model, without suffix or extension.
- (void)testModel:(NSString*)modelName {
  NSBundle* bundle = [NSBundle bundleForClass:[self class]];

  NSString* modelPath = [bundle pathForResource:modelName ofType:@"ptl"];
  XCTAssertNotNil(modelPath, @"Model not found. See https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test.");
  if (modelPath != nil) {
    // Skip the run when the resource is missing; the failure is already recorded
    // and loading from a nil path would crash the test runner.
    [self runModel:modelPath];
  }

  // model generated on the fly
  // The previous "%@" format reproduced modelName unchanged, so the stored model
  // was simply loaded a second time and the on-the-fly artefact was never tested.
  // NOTE(review): "_temp" is the suffix the model generation script is expected
  // to append to on-the-fly models - confirm against the generator/setup.rb.
  NSString* onTheFlyModelName = [NSString stringWithFormat:@"%@_temp", modelName];
  NSString* onTheFlyModelPath = [bundle pathForResource:onTheFlyModelName ofType:@"ptl"];
  XCTAssertNotNil(onTheFlyModelPath, @"On-the-fly model not found. Follow https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test to generate them and run the setup.rb script again.");
  if (onTheFlyModelPath != nil) {
    [self runModel:onTheFlyModelPath];
  }
}

/// Loads the lite-interpreter model at `modelPath` and asserts that a forward
/// pass does not throw.
///
/// When the module exposes `get_all_bundled_inputs`, the first bundled input is
/// used as the argument list; otherwise `forward` is called without arguments.
///
/// @param modelPath File system path of a `.ptl` model. A nil path is reported
///                  as a test failure instead of being passed to the loader.
- (void)runModel:(NSString*)modelPath {
  if (modelPath == nil) {
    XCTFail(@"runModel: was called without a model path.");
    return;
  }

  c10::InferenceMode mode;
  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
  auto has_bundled_input = module.find_method("get_all_bundled_inputs");
  if (has_bundled_input) {
    c10::IValue bundled_inputs = module.run_method("get_all_bundled_inputs");
    c10::List<at::IValue> all_inputs = bundled_inputs.toList();
    if (all_inputs.empty()) {
      // Indexing the first element of an empty list would be out of bounds.
      XCTFail(@"Model at %@ declares bundled inputs but provides none.", modelPath);
      return;
    }
    // run with the first bundled input
    // Only that input is unpacked; the remaining ones are never used here.
    std::vector<at::IValue> first_input = all_inputs.get(0).toTupleRef().elements();
    XCTAssertNoThrow(module.forward(first_input));
  } else {
    XCTAssertNoThrow(module.forward({}));
  }
}

#pragma mark - Model tests

// TODO remove this once updated test script
- (void)testLiteInterpreter {
  XCTAssertTrue(true);
}

- (void)testMobileNetV2 {
  [self testModel:@"mobilenet_v2"];
}

- (void)testPointwiseOps {
  [self testModel:@"pointwise_ops"];
}

- (void)testReductionOps {
  [self testModel:@"reduction_ops"];
}

- (void)testComparisonOps {
  [self testModel:@"comparison_ops"];
}

- (void)testOtherMathOps {
  [self testModel:@"other_math_ops"];
}

- (void)testSpectralOps {
  [self testModel:@"spectral_ops"];
}

- (void)testBlasLapackOps {
  [self testModel:@"blas_lapack_ops"];
}

- (void)testSamplingOps {
  [self testModel:@"sampling_ops"];
}

- (void)testTensorOps {
  [self testModel:@"tensor_general_ops"];
}

- (void)testTensorCreationOps {
  [self testModel:@"tensor_creation_ops"];
}

- (void)testTensorIndexingOps {
  [self testModel:@"tensor_indexing_ops"];
}

- (void)testTensorTypingOps {
  [self testModel:@"tensor_typing_ops"];
}

- (void)testTensorViewOps {
  [self testModel:@"tensor_view_ops"];
}

- (void)testConvolutionOps {
  [self testModel:@"convolution_ops"];
}

- (void)testPoolingOps {
  [self testModel:@"pooling_ops"];
}

- (void)testPaddingOps {
  [self testModel:@"padding_ops"];
}

- (void)testActivationOps {
  [self testModel:@"activation_ops"];
}

- (void)testNormalizationOps {
  [self testModel:@"normalization_ops"];
}

- (void)testRecurrentOps {
  [self testModel:@"recurrent_ops"];
}

- (void)testTransformerOps {
  [self testModel:@"transformer_ops"];
}

- (void)testLinearOps {
  [self testModel:@"linear_ops"];
}

- (void)testDropoutOps {
  [self testModel:@"dropout_ops"];
}

- (void)testSparseOps {
  [self testModel:@"sparse_ops"];
}

- (void)testDistanceFunctionOps {
  [self testModel:@"distance_function_ops"];
}

- (void)testLossFunctionOps {
  [self testModel:@"loss_function_ops"];
}

- (void)testVisionFunctionOps {
  [self testModel:@"vision_function_ops"];
}

- (void)testShuffleOps {
  [self testModel:@"shuffle_ops"];
}

- (void)testNNUtilsOps {
  [self testModel:@"nn_utils_ops"];
}

- (void)testQuantOps {
  [self testModel:@"general_quant_ops"];
}

- (void)testDynamicQuantOps {
  [self testModel:@"dynamic_quant_ops"];
}

- (void)testStaticQuantOps {
  [self testModel:@"static_quant_ops"];
}

- (void)testFusedQuantOps {
  [self testModel:@"fused_quant_ops"];
}

- (void)testTorchScriptBuiltinQuantOps {
  [self testModel:@"torchscript_builtin_ops"];
}

- (void)testTorchScriptCollectionQuantOps {
  [self testModel:@"torchscript_collection_ops"];
}

@end