# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
import pytest

import pyarmnn as ann


@pytest.fixture(scope="function")
def network():
    """Provide a fresh ``ann.INetwork`` instance for each test function."""
    return ann.INetwork()


class TestIInputIOutputIConnectable:
    """Tests for the input-slot, output-slot and connectable-layer bindings."""

    def test_input_slot(self, network):
        """IInputSlot.GetConnection() returns the IOutputSlot feeding it.

        Also checks that deleting Python references to the returned slot
        objects does not break later lookups of the same slots.
        """
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Connect the input/output slots for each layer
        input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
        input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
        add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        # Check IInputSlot GetConnection()
        input_slot = add.GetInputSlot(0)
        input_slot_connection = input_slot.GetConnection()

        assert isinstance(input_slot_connection, ann.IOutputSlot)

        # Dropping the Python reference must not affect a fresh lookup.
        del input_slot_connection

        assert input_slot.GetConnection()
        assert isinstance(input_slot.GetConnection(), ann.IOutputSlot)

        del input_slot

        assert add.GetInputSlot(0)

    def test_output_slot(self, network):
        """Exercise the IOutputSlot API on an input -> add -> output graph.

        Covers GetConnection(), GetNumConnections()/len()/indexing,
        CalculateIndexOnOwner(), GetOwningLayerGuid(),
        SetTensorInfo()/IsTensorInfoSet()/GetTensorInfo() and Disconnect().
        """
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Connect the input/output slots for each layer
        input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
        input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
        add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        # Check IInputSlot GetConnection(): walking back from an input slot
        # yields the output slot that feeds it.
        input1_output_slot = add.GetInputSlot(0).GetConnection()  # input1.OutputSlot0
        add_output_slot = output.GetInputSlot(0).GetConnection()  # add.OutputSlot0

        # Check IOutputSlot GetConnection(): add.OutputSlot0 -> output.InputSlot0,
        # and that input slot's own connection is again an IOutputSlot.
        output_input_slot = add.GetOutputSlot(0).GetConnection(0)
        assert isinstance(output_input_slot.GetConnection(), ann.IOutputSlot)

        # Test IOutputSlot GetNumConnections() & CalculateIndexOnOwner().
        # input1.OutputSlot0 feeds only add.InputSlot0 and is slot 0 of input1.
        assert input1_output_slot.GetNumConnections() == 1
        assert len(input1_output_slot) == 1
        assert input1_output_slot[0]
        assert input1_output_slot.CalculateIndexOnOwner() == 0

        # Check GetOwningLayerGuid(): the two slots are owned by different
        # layers (input1 and add), so their owning GUIDs must differ.
        assert (input1_output_slot.GetOwningLayerGuid()
                != add_output_slot.GetOwningLayerGuid())

        # Set TensorInfo
        test_tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)

        # Check IsTensorInfoSet()
        assert not input1_output_slot.IsTensorInfoSet()
        input1_output_slot.SetTensorInfo(test_tensor_info)
        assert input1_output_slot.IsTensorInfoSet()

        # Check GetTensorInfo(): a (2, 3) shape has 2 dimensions and 6 elements.
        output_tensor_info = input1_output_slot.GetTensorInfo()
        assert output_tensor_info.GetNumDimensions() == 2
        assert output_tensor_info.GetNumElements() == 6

        # Check Disconnect(): add.OutputSlot0 has exactly one connection,
        # to output.InputSlot0, and none after disconnecting it.
        assert add_output_slot.GetNumConnections() == 1
        add.GetOutputSlot(0).Disconnect(output.GetInputSlot(0))
        assert add_output_slot.GetNumConnections() == 0

    def test_output_slot__out_of_range(self, network):
        """Indexing an IOutputSlot with an invalid index raises ValueError."""
        # Create input layer to check output slot get item handling
        input1 = network.AddInputLayer(0, "input1")

        # Nothing is connected to this slot, so index 1 is not valid.
        output_slot = input1.GetOutputSlot(0)
        with pytest.raises(ValueError, match="Invalid index 1 provided"):
            output_slot[1]

    def test_iconnectable_guid(self, network):
        """Distinct layers report distinct GUIDs."""
        # Check IConnectable GetGuid()
        # Note Guid can change based on which tests are run so
        # checking here that each layer does not have the same guid
        add_id = network.AddAdditionLayer().GetGuid()
        output_id = network.AddOutputLayer(0).GetGuid()
        assert add_id != output_id

    def test_iconnectable_layer_functions(self, network):
        """Check slot counts, names and slot accessors on unconnected layers."""
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Check GetNumInputSlots(), GetName() & GetNumOutputSlots()
        assert input1.GetNumInputSlots() == 0
        assert input1.GetName() == "input1"
        assert input1.GetNumOutputSlots() == 1

        assert input2.GetNumInputSlots() == 0
        assert input2.GetName() == "input2"
        assert input2.GetNumOutputSlots() == 1

        assert add.GetNumInputSlots() == 2
        assert add.GetName() == "addition"
        assert add.GetNumOutputSlots() == 1

        assert output.GetNumInputSlots() == 1
        assert output.GetName() == "output"
        assert output.GetNumOutputSlots() == 0

        # Check GetOutputSlot(): nothing has been connected in this test.
        input1_get_output = input1.GetOutputSlot(0)
        assert input1_get_output.GetNumConnections() == 0
        assert len(input1_get_output) == 0

        # Check GetInputSlot()
        add_get_input = add.GetInputSlot(0)
        # The slot is unconnected here; the result is deliberately not
        # asserted, so this only checks that the call does not raise.
        add_get_input.GetConnection()
        assert isinstance(add_get_input, ann.IInputSlot)