xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_tensor.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2020 Arm Ltd. All rights reserved.
2# SPDX-License-Identifier: MIT
3from copy import copy
4
5import pytest
6import numpy as np
7import pyarmnn as ann
8
9
def __get_tensor_info(dt):
    """Return a ``TensorInfo`` describing a 2x3 tensor of data type ``dt``.

    Shared helper for the tests below: every tensor they build uses this
    shape, i.e. 2 * 3 = 6 elements.
    """
    shape = ann.TensorShape((2, 3))
    return ann.TensorInfo(shape, dt)
14
15
@pytest.mark.parametrize("dt", [ann.DataType_Float32, ann.DataType_Float16,
                                ann.DataType_QAsymmU8, ann.DataType_QSymmS8,
                                ann.DataType_QAsymmS8])
def test_create_tensor_with_info(dt):
    """A Tensor built from a TensorInfo reports that info's size and data type."""
    info = __get_tensor_info(dt)
    expected_elements, expected_bytes = info.GetNumElements(), info.GetNumBytes()

    tensor = ann.Tensor(info)

    # The info handed back by the tensor is expected not to be the caller's object.
    assert info != tensor.GetInfo(), "Different objects"
    assert tensor.GetNumElements() == expected_elements
    assert tensor.GetNumBytes() == expected_bytes
    assert tensor.GetDataType() == dt
31
32
def test_create_tensor_undefined_datatype():
    """Building a Tensor from an info with an unsupported data type raises ValueError."""
    bogus_info = ann.TensorInfo()
    # Arbitrary data type id that Tensor is expected to reject.
    bogus_info.SetDataType(99)

    with pytest.raises(ValueError) as exc_info:
        ann.Tensor(bogus_info)

    message = str(exc_info.value)
    assert 'The data type provided for this Tensor is not supported.' in message
41
42
@pytest.mark.parametrize("dt", [ann.DataType_Float32])
def test_tensor_memory_output(dt):
    """Check the numpy memory area backing a freshly created Float32 Tensor.

    Inference has not been run, so the buffer's contents are arbitrary
    ("random stuff"); only its size and element width are meaningful and
    therefore asserted on.
    """
    tensor_info = __get_tensor_info(dt)
    tensor = ann.Tensor(tensor_info)

    memory_area = tensor.get_memory_area()
    # Assert the buffer's extent instead of the truthiness of its values:
    # there must be one slot per tensor element.
    assert memory_area.size == tensor.GetNumElements(), "one slot per element"
    assert 4 == memory_area.itemsize, "it is float32"
51
52
@pytest.mark.parametrize("dt", [ann.DataType_Float32, ann.DataType_Float16,
                                ann.DataType_QAsymmU8, ann.DataType_QSymmS8,
                                ann.DataType_QAsymmS8])
def test_tensor__str__(dt):
    """str(Tensor) summarises data type, byte size, rank and element count."""
    info = __get_tensor_info(dt)
    expected = ("Tensor{{DataType: {}, NumBytes: {}, NumDimensions: "
                "{}, NumElements: {}}}").format(dt,
                                                info.GetNumBytes(),
                                                info.GetNumDimensions(),
                                                info.GetNumElements())

    tensor = ann.Tensor(info)

    assert str(tensor) == expected
67
68
def test_create_empty_tensor():
    """A default-constructed Tensor has no elements, no bytes and no memory area."""
    empty = ann.Tensor()

    assert empty.GetNumElements() == 0
    assert empty.GetNumBytes() == 0
    assert empty.get_memory_area() is None
75
76
@pytest.mark.parametrize("dt", [ann.DataType_Float32, ann.DataType_Float16,
                                ann.DataType_QAsymmU8, ann.DataType_QSymmS8,
                                ann.DataType_QAsymmS8])
def test_create_tensor_from_tensor(dt):
    """Constructing a Tensor from another gives a new object over the same buffer."""
    original = ann.Tensor(__get_tensor_info(dt))
    duplicate = ann.Tensor(original)

    assert duplicate != original, "Different objects"
    assert duplicate.GetInfo() != original.GetInfo(), "Different objects"

    # ndarray.ctypes.data is the address of each buffer's first byte; equal
    # addresses mean the copy aliases the original's memory.
    original_address = original.get_memory_area().ctypes.data
    duplicate_address = duplicate.get_memory_area().ctypes.data
    assert duplicate_address == original_address, "Same memory area"

    assert duplicate.GetNumElements() == original.GetNumElements()
    assert duplicate.GetNumBytes() == original.GetNumBytes()
    assert duplicate.GetDataType() == original.GetDataType()
91
92
@pytest.mark.parametrize("dt", [ann.DataType_Float32, ann.DataType_Float16,
                                ann.DataType_QAsymmU8, ann.DataType_QSymmS8,
                                ann.DataType_QAsymmS8])
def test_copy_tensor(dt):
    """copy.copy(Tensor) gives a new object over the same buffer."""
    original = ann.Tensor(__get_tensor_info(dt))
    shallow = copy(original)

    assert shallow != original, "Different objects"
    assert shallow.GetInfo() != original.GetInfo(), "Different objects"

    # Equal buffer start addresses mean the copy shares the original's memory.
    original_address = original.get_memory_area().ctypes.data
    shallow_address = shallow.get_memory_area().ctypes.data
    assert shallow_address == original_address, "Same memory area"

    assert shallow.GetNumElements() == original.GetNumElements()
    assert shallow.GetNumBytes() == original.GetNumBytes()
    assert shallow.GetDataType() == original.GetDataType()
106
107
@pytest.mark.parametrize("dt", [ann.DataType_Float32, ann.DataType_Float16,
                                ann.DataType_QAsymmU8, ann.DataType_QSymmS8,
                                ann.DataType_QAsymmS8])
def test_copied_tensor_has_memory_area_access_after_deletion_of_original_tensor(dt):
    """A copied Tensor still reads the shared buffer after the original is deleted."""
    source = ann.Tensor(__get_tensor_info(dt))

    # Plant a sentinel in the first element, then take an independent
    # snapshot (np.array copies) of the whole buffer.
    source.get_memory_area()[0] = 100
    snapshot = np.array(source.get_memory_area())
    assert snapshot[0] == 100

    clone = ann.Tensor(source)
    del source

    np.testing.assert_array_equal(clone.get_memory_area(), snapshot)
    assert clone.get_memory_area()[0] == 100
126
127
def test_create_const_tensor_incorrect_args():
    """ann.Tensor rejects an argument list it has no matching constructor for."""
    expected_error_message = "Incorrect number of arguments or type of arguments provided to create Tensor."

    with pytest.raises(ValueError) as exc_info:
        ann.Tensor('something', 'something')

    assert expected_error_message in str(exc_info.value)
134
135
@pytest.mark.parametrize("dt", [ann.DataType_Float16])
def test_tensor_memory_output_fp16(dt):
    """Check a Float16 Tensor's size, data type and backing memory area.

    The 2x3 shape gives 6 elements; the expected 12 bytes therefore implies
    2 bytes per float16 element.
    """
    tensor_info = __get_tensor_info(dt)
    tensor = ann.Tensor(tensor_info)

    assert tensor.GetNumElements() == 6
    assert tensor.GetNumBytes() == 12
    assert tensor.GetDataType() == ann.DataType_Float16

    # Like test_tensor_memory_output does for float32, inspect the numpy
    # buffer itself: 2-byte elements covering exactly GetNumBytes() bytes.
    memory_area = tensor.get_memory_area()
    assert 2 == memory_area.itemsize, "it is float16"
    assert memory_area.nbytes == tensor.GetNumBytes()
145