1 // 2 // Copyright © 2021 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <armnn/utility/Assert.hpp> 7 8 #include <cl/ClImportTensorHandleFactory.hpp> 9 10 #include <doctest/doctest.h> 11 12 TEST_SUITE("ClImportTensorHandleFactoryTests") 13 { 14 using namespace armnn; 15 16 TEST_CASE("ImportTensorFactoryAskedToCreateManagedTensorThrowsException") 17 { 18 // Create the factory to import tensors. 19 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc), 20 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 21 TensorInfo tensorInfo; 22 // This factory is designed to import the memory of tensors. Asking for a handle that requires 23 // a memory manager should result in an exception. 24 REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, true), InvalidArgumentException); 25 REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, DataLayout::NCHW, true), InvalidArgumentException); 26 } 27 28 TEST_CASE("ImportTensorFactoryCreateMallocTensorHandle") 29 { 30 // Create the factory to import tensors. 31 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc), 32 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 33 TensorShape tensorShape{ 6, 7, 8, 9 }; 34 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32); 35 // Start with the TensorInfo factory method. Create an import tensor handle and verify the data is 36 // passed through correctly. 37 auto tensorHandle = factory.CreateTensorHandle(tensorInfo); 38 ARMNN_ASSERT(tensorHandle); 39 ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); 40 ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape); 41 42 // Same method but explicitly specifying isManaged = false. 43 tensorHandle = factory.CreateTensorHandle(tensorInfo, false); 44 CHECK(tensorHandle); 45 ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); 46 ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape); 47 48 // Now try TensorInfo and DataLayout factory method. 49 tensorHandle = factory.CreateTensorHandle(tensorInfo, DataLayout::NHWC); 50 CHECK(tensorHandle); 51 ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); 52 ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape); 53 } 54 55 TEST_CASE("CreateSubtensorOfImportTensor") 56 { 57 // Create the factory to import tensors. 58 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc), 59 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 60 // Create a standard inport tensor. 61 TensorShape tensorShape{ 224, 224, 1, 1 }; 62 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32); 63 auto tensorHandle = factory.CreateTensorHandle(tensorInfo); 64 // Use the factory to create a 16x16 sub tensor. 65 TensorShape subTensorShape{ 16, 16, 1, 1 }; 66 // Starting at an offset of 1x1. 67 uint32_t origin[4] = { 1, 1, 0, 0 }; 68 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); 69 CHECK(subTensor); 70 ARMNN_ASSERT(subTensor->GetShape() == subTensorShape); 71 ARMNN_ASSERT(subTensor->GetParent() == tensorHandle.get()); 72 } 73 74 TEST_CASE("CreateSubtensorNonZeroXYIsInvalid") 75 { 76 // Create the factory to import tensors. 77 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc), 78 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 79 // Create a standard import tensor. 80 TensorShape tensorShape{ 224, 224, 1, 1 }; 81 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32); 82 auto tensorHandle = factory.CreateTensorHandle(tensorInfo); 83 // Use the factory to create a 16x16 sub tensor. 84 TensorShape subTensorShape{ 16, 16, 1, 1 }; 85 // This looks a bit backwards because of how Cl specifies tensors. Essentially we want to trigger our 86 // check "(coords.x() != 0 || coords.y() != 0)" 87 uint32_t origin[4] = { 0, 0, 1, 1 }; 88 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); 89 // We expect a nullptr. 90 ARMNN_ASSERT(subTensor == nullptr); 91 } 92 93 TEST_CASE("CreateSubtensorXYMustMatchParent") 94 { 95 // Create the factory to import tensors. 96 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc), 97 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 98 // Create a standard import tensor. 99 TensorShape tensorShape{ 224, 224, 1, 1 }; 100 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32); 101 auto tensorHandle = factory.CreateTensorHandle(tensorInfo); 102 // Use the factory to create a 16x16 sub tensor but make the CL x and y axis different. 103 TensorShape subTensorShape{ 16, 16, 2, 2 }; 104 // We want to trigger our ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y())) 105 uint32_t origin[4] = { 1, 1, 0, 0 }; 106 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); 107 // We expect a nullptr. 108 ARMNN_ASSERT(subTensor == nullptr); 109 } 110 111 TEST_CASE("CreateSubtensorMustBeSmallerThanParent") 112 { 113 // Create the factory to import tensors. 114 ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc), 115 static_cast<MemorySourceFlags>(MemorySource::Malloc)); 116 // Create a standard import tensor. 117 TensorShape tensorShape{ 224, 224, 1, 1 }; 118 TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32); 119 auto tensorHandle = factory.CreateTensorHandle(tensorInfo); 120 // Ask for a subtensor that's the same size as the parent. 121 TensorShape subTensorShape{ 224, 224, 1, 1 }; 122 uint32_t origin[4] = { 1, 1, 0, 0 }; 123 // This should result in a nullptr. 124 auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); 125 ARMNN_ASSERT(subTensor == nullptr); 126 } 127 128 } 129