xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_modeloption.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3import pytest
4
5from pyarmnn import BackendOptions, BackendOption, BackendId, OptimizerOptions, ShapeInferenceMethod_InferAndValidate
6
7
@pytest.mark.parametrize("data", (True, -100, 128, 0.12345, 'string'))
def test_backend_option_ctor(data):
    """Build a BackendOption from a bool, int, float or str and check its name.

    Only the name is verified; the stored value is not read back here.
    """
    option = BackendOption("name", data)
    option_name = option.GetName()
    assert option_name == "name"
12
13
def test_backend_options_ctor():
    """Build BackendOptions from a BackendId and from another BackendOptions.

    Both instances must report the backend id 'a'.
    """
    original = BackendOptions(BackendId('a'))
    assert str(original.GetBackendId()) == 'a'

    # Constructing from an existing BackendOptions must carry the id over.
    duplicate = BackendOptions(original)
    assert str(duplicate.GetBackendId()) == 'a'
22
23
def test_backend_options_add():
    """Add options to a BackendOptions and check count, indexing and iteration."""
    target_backend = BackendId('a')
    options = BackendOptions(target_backend)
    first = BackendOption("name", 1)
    options.AddOption(first)

    # GetOptionCount() and len() must agree after the first insertion.
    assert options.GetOptionCount() == 1
    assert len(options) == 1

    # The option is reachable by subscript, by GetOption() and by iteration.
    assert options[0].GetName() == 'name'
    assert options.GetOption(0).GetName() == 'name'
    assert all(entry.GetName() == 'name' for entry in options)

    options.AddOption(BackendOption("name2", 2))

    assert options.GetOptionCount() == 2
    assert len(options) == 2
42
43
def test_backend_option_ownership():
    """Check the `thisown` flags around options stored in a BackendOptions.

    The option passed to AddOption() keeps `thisown` set, while objects
    obtained by indexing the container report `thisown` as false. The stored
    option must remain readable after either kind of Python object is deleted.
    """
    owner_id = BackendId('b')
    container = BackendOptions(owner_id)
    added = BackendOption('option', True)
    container.AddOption(added)

    # Adding the option does not clear `thisown` on the object we passed in.
    assert added.thisown

    # Deleting that object must leave the container's entry in place.
    del added
    assert container.GetOptionCount() == 1

    fetched = container[0]
    assert not fetched.thisown
    assert fetched.GetName() == 'option'

    # Deleting an object obtained by indexing must not remove the entry either.
    del fetched

    fetched_again = container[0]
    assert not fetched_again.thisown
    assert fetched_again.GetName() == 'option'
64
65
def test_optimizer_options_with_model_opt():
    """Pass model options to OptimizerOptions via the constructor and via assignment.

    A list is used for the constructor argument and a tuple for the later
    assignment to `m_ModelOptions`; both must be readable back in order.
    """
    a = BackendOptions(BackendId('a'))

    optimizer_opts = OptimizerOptions(
        True, False, False,
        ShapeInferenceMethod_InferAndValidate,
        True,
        [a],
        True)

    model_opts = optimizer_opts.m_ModelOptions
    assert len(model_opts) == 1
    assert str(model_opts[0].GetBackendId()) == 'a'

    b = BackendOptions(BackendId('b'))
    c = BackendOptions(BackendId('c'))

    # Replace the single-entry options with three entries given as a tuple.
    optimizer_opts.m_ModelOptions = (a, b, c)
    model_opts = optimizer_opts.m_ModelOptions

    assert len(optimizer_opts.m_ModelOptions) == 3

    # Entries must come back in the order they were assigned.
    for position, expected_id in enumerate(('a', 'b', 'c')):
        assert str(model_opts[position].GetBackendId()) == expected_id
95
96
def test_optimizer_option_default():
    """OptimizerOptions built without model options exposes an empty m_ModelOptions."""
    optimizer_opts = OptimizerOptions(
        True, False, False,
        ShapeInferenceMethod_InferAndValidate,
        True)

    model_opts = optimizer_opts.m_ModelOptions
    assert len(model_opts) == 0
105
106
def test_optimizer_options_fail():
    """Invalid model options must be rejected with a TypeError.

    Three cases are covered:
      * a bare BackendOptions (not wrapped in a sequence) passed to the
        OptimizerOptions constructor;
      * a string assigned to `m_ModelOptions`;
      * a list mixing a string with a BackendOptions assigned to
        `m_ModelOptions`.

    For the assignment cases the OptimizerOptions object is built outside the
    `pytest.raises` block, so that only the assignment itself is expected to
    raise. A TypeError coming from the setup code then fails the test directly
    instead of being captured as if it were the error under test.
    """
    a = BackendOptions(BackendId('a'))

    # Case 1: the constructor itself must reject a non-sequence argument.
    with pytest.raises(TypeError) as err:
        OptimizerOptions(True,
                         False,
                         False,
                         ShapeInferenceMethod_InferAndValidate,
                         True,
                         a,
                         True)

    assert "Wrong number or type of arguments" in str(err.value)

    setter_error = "in method 'OptimizerOptions_m_ModelOptions_set', argument 2"

    # Case 2: assigning a plain string. A fresh object is built for each case
    # so that a rejected assignment cannot influence the next one.
    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True)

    with pytest.raises(TypeError) as err:
        oo.m_ModelOptions = 'nonsense'

    assert setter_error in str(err.value)

    # Case 3: assigning a list whose first element has the wrong type.
    oo = OptimizerOptions(True,
                          False,
                          False,
                          ShapeInferenceMethod_InferAndValidate,
                          True)

    with pytest.raises(TypeError) as err:
        oo.m_ModelOptions = ['nonsense', a]

    assert setter_error in str(err.value)
142