xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_iconnectable.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
# Copyright © 2020 Arm Ltd. All rights reserved.
# SPDX-License-Identifier: MIT
3import pytest
4
5import pyarmnn as ann
6
7
@pytest.fixture(scope="function")
def network():
    """Provide a newly constructed ``ann.INetwork`` to each test function.

    The fixture is function-scoped, so layers added by one test are never
    visible to another test.
    """
    fresh_network = ann.INetwork()
    return fresh_network
11
12
class TestIInputIOutputIConnectable:
    """Tests for the pyarmnn IInputSlot, IOutputSlot and IConnectableLayer bindings."""

    @staticmethod
    def _create_addition_network(network):
        """Build ``input1, input2 -> addition -> output`` inside *network*.

        All four layers are created first, then wired as
        ``input1 -> add[0]``, ``input2 -> add[1]`` and ``add -> output[0]``.

        Args:
            network: The ``ann.INetwork`` to add the layers to.

        Returns:
            A ``(input1, input2, add, output)`` tuple of the created layers.
        """
        # Create input, addition & output layer
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Connect the input/output slots for each layer
        input1.GetOutputSlot(0).Connect(add.GetInputSlot(0))
        input2.GetOutputSlot(0).Connect(add.GetInputSlot(1))
        add.GetOutputSlot(0).Connect(output.GetInputSlot(0))

        return input1, input2, add, output

    def test_input_slot(self, network):
        """IInputSlot.GetConnection() returns the IOutputSlot feeding that input."""
        _, _, add, _ = self._create_addition_network(network)

        # Check IInputSlot GetConnection()
        input_slot = add.GetInputSlot(0)
        input_slot_connection = input_slot.GetConnection()
        assert isinstance(input_slot_connection, ann.IOutputSlot)

        # After dropping the Python reference to the returned connection,
        # GetConnection() must still return an IOutputSlot (presumably guards
        # against the wrapper owning/freeing the underlying slot).
        del input_slot_connection
        assert input_slot.GetConnection()
        assert isinstance(input_slot.GetConnection(), ann.IOutputSlot)

        # Likewise, dropping the slot wrapper must leave the layer's slot usable.
        del input_slot
        assert add.GetInputSlot(0)

    def test_output_slot(self, network):
        """Exercise the IOutputSlot API on slots reached through connections."""
        _, _, add, output = self._create_addition_network(network)

        # Check IInputSlot GetConnection(): each call yields the IOutputSlot
        # feeding that input. add's input 0 is fed by input1's output slot;
        # output's input 0 is fed by add's output slot.
        add_get_input_connection = add.GetInputSlot(0).GetConnection()
        output_get_input_connection = output.GetInputSlot(0).GetConnection()

        # Check IOutputSlot GetConnection(): connection 0 of add's output slot
        # is output's input slot 0, whose own GetConnection() leads back to an
        # IOutputSlot.
        add_get_output_connect = add.GetOutputSlot(0).GetConnection(0)
        assert isinstance(add_get_output_connect.GetConnection(), ann.IOutputSlot)

        # Test IOutputSlot GetNumConnections(), __len__, __getitem__ &
        # CalculateIndexOnOwner() on input1's output slot (one connection).
        assert add_get_input_connection.GetNumConnections() == 1
        assert len(add_get_input_connection) == 1
        assert add_get_input_connection[0]
        assert add_get_input_connection.CalculateIndexOnOwner() == 0

        # Check GetOwningLayerGuid(): the two output slots are owned by
        # different layers (input1 and add), so their GUIDs must differ.
        assert add_get_input_connection.GetOwningLayerGuid() != output_get_input_connection.GetOwningLayerGuid()

        # Set TensorInfo
        test_tensor_info = ann.TensorInfo(ann.TensorShape((2, 3)), ann.DataType_Float32)

        # Check IsTensorInfoSet() flips from False to True after SetTensorInfo()
        assert not add_get_input_connection.IsTensorInfoSet()
        add_get_input_connection.SetTensorInfo(test_tensor_info)
        assert add_get_input_connection.IsTensorInfoSet()

        # Check GetTensorInfo() reflects the (2, 3) shape that was set
        output_tensor_info = add_get_input_connection.GetTensorInfo()
        assert 2 == output_tensor_info.GetNumDimensions()
        assert 6 == output_tensor_info.GetNumElements()

        # Check Disconnect(): add's output slot 0 has exactly one connection
        # (to output's input slot 0); disconnecting it leaves none.
        assert output_get_input_connection.GetNumConnections() == 1
        add.GetOutputSlot(0).Disconnect(output.GetInputSlot(0))
        assert output_get_input_connection.GetNumConnections() == 0

    def test_output_slot__out_of_range(self, network):
        """Indexing an IOutputSlot beyond its connections raises ValueError."""
        # Create input layer to check output slot get item handling
        input1 = network.AddInputLayer(0, "input1")

        # The slot is never connected here, so index 1 is out of range.
        output_slot = input1.GetOutputSlot(0)
        with pytest.raises(ValueError) as err:
            output_slot[1]

        assert "Invalid index 1 provided" in str(err.value)

    def test_iconnectable_guid(self, network):
        """Distinct layers receive distinct GUIDs."""
        # Check IConnectable GetGuid()
        # Note Guid can change based on which tests are run so
        # checking here that each layer does not have the same guid
        add_id = network.AddAdditionLayer().GetGuid()
        output_id = network.AddOutputLayer(0).GetGuid()
        assert add_id != output_id

    def test_iconnectable_layer_functions(self, network):
        """Check slot counts, names and slot accessors on unconnected layers."""
        # Create input, addition & output layer (deliberately left unconnected)
        input1 = network.AddInputLayer(0, "input1")
        input2 = network.AddInputLayer(1, "input2")
        add = network.AddAdditionLayer("addition")
        output = network.AddOutputLayer(0, "output")

        # Check GetNumInputSlots(), GetName() & GetNumOutputSlots()
        assert input1.GetNumInputSlots() == 0
        assert input1.GetName() == "input1"
        assert input1.GetNumOutputSlots() == 1

        assert input2.GetNumInputSlots() == 0
        assert input2.GetName() == "input2"
        assert input2.GetNumOutputSlots() == 1

        assert add.GetNumInputSlots() == 2
        assert add.GetName() == "addition"
        assert add.GetNumOutputSlots() == 1

        assert output.GetNumInputSlots() == 1
        assert output.GetName() == "output"
        assert output.GetNumOutputSlots() == 0

        # Check GetOutputSlot(): nothing is connected yet
        input1_get_output = input1.GetOutputSlot(0)
        assert input1_get_output.GetNumConnections() == 0
        assert len(input1_get_output) == 0

        # Check GetInputSlot(). add's input slot 0 is unconnected here; the
        # GetConnection() call only checks that it can be invoked without
        # raising (its return value is not inspected).
        add_get_input = add.GetInputSlot(0)
        add_get_input.GetConnection()
        assert isinstance(add_get_input, ann.IInputSlot)
143