# mypy: ignore-errors

from collections import defaultdict
from collections.abc import Iterable
from functools import reduce

import numpy as np
import torch

import hypothesis
from hypothesis import assume
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra import numpy as stnp
from hypothesis.strategies import SearchStrategy

from torch.testing._internal.common_quantized import _calculate_dynamic_qparams, _calculate_dynamic_per_channel_qparams

# Setup for the hypothesis tests.
# Each quantized data type maps to an enforced zero_point. If the enforced
# value is None, any zero_point within the range of the data type is OK.

# Tuple with all quantized data types.
_ALL_QINT_TYPES = (
    torch.quint8,
    torch.qint8,
    torch.qint32,
)

# Enforced zero point for every quantized data type.
# If None, any zero_point within the range of the data type is OK.
_ENFORCED_ZERO_POINT = defaultdict(lambda: None, {
    torch.quint8: None,
    torch.qint8: None,
    torch.qint32: 0
})

def _get_valid_min_max(qparams):
    scale, zero_point, quantized_type = qparams
    adjustment = 1 + torch.finfo(torch.float).eps
    _long_type_info = torch.iinfo(torch.long)
    long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment
    # make sure intermediate results are within the range of long
    min_value = max((long_min - zero_point) * scale, (long_min / scale + zero_point))
    max_value = min((long_max - zero_point) * scale, (long_max / scale + zero_point))
    return np.float32(min_value), np.float32(max_value)
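
# Illustrative sketch (an assumption, not part of the original file): for a
# scale of 1.0 and zero_point of 0, the bounds above reduce to roughly the
# long range itself, so any value inside them keeps both the quantize and
# dequantize intermediates representable in a 64-bit long.
#
#   lo, hi = _get_valid_min_max((1.0, 0, torch.quint8))
#   assert lo < 0 < hi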

# This wrapper wraps around `st.floats` and checks the version of `hypothesis`;
# if it is too old, it removes the `width` parameter (which was introduced in
# 3.67.0).
def _floats_wrapper(*args, **kwargs):
    if 'width' in kwargs and hypothesis.version.__version_info__ < (3, 67, 0):
        # As long as nan, inf, min, max are not specified, reimplement the width
        # parameter for older versions of hypothesis.
        no_nan_and_inf = (
            not kwargs.get('allow_nan', False) and
            not kwargs.get('allow_infinity', False)
        )
        min_and_max_not_specified = (
            len(args) == 0 and
            'min_value' not in kwargs and
            'max_value' not in kwargs
        )
        if no_nan_and_inf and min_and_max_not_specified:
            if kwargs['width'] == 16:
                kwargs['min_value'] = torch.finfo(torch.float16).min
                kwargs['max_value'] = torch.finfo(torch.float16).max
            elif kwargs['width'] == 32:
                kwargs['min_value'] = torch.finfo(torch.float32).min
                kwargs['max_value'] = torch.finfo(torch.float32).max
            elif kwargs['width'] == 64:
                kwargs['min_value'] = torch.finfo(torch.float64).min
                kwargs['max_value'] = torch.finfo(torch.float64).max
        kwargs.pop('width')
    return st.floats(*args, **kwargs)

def floats(*args, **kwargs):
    if 'width' not in kwargs:
        kwargs['width'] = 32
    return _floats_wrapper(*args, **kwargs)
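
# Usage sketch (illustrative, not part of the original file): `floats` behaves
# like `st.floats` but defaults to 32-bit width, so drawn values survive a
# round-trip through `np.float32`:
#
#   @given(x=floats(-1.0, 1.0))
#   def test_roundtrip(self, x):
#       assert np.float32(x) == x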

"""Hypothesis filter to avoid overflows with quantized tensors.

Args:
    tensor: Tensor of floats to filter.
    qparams: Quantization parameters as returned by the `qparams` strategy.

Returns:
    True

Raises:
    hypothesis.UnsatisfiedAssumption

Note: This filter is slow. Use it only when filtering of the test cases is
      absolutely necessary!
"""
def assume_not_overflowing(tensor, qparams):
    min_value, max_value = _get_valid_min_max(qparams)
    assume(tensor.min() >= min_value)
    assume(tensor.max() <= max_value)
    return True
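
# Usage sketch (illustrative; `_some_quantized_op` is a hypothetical test
# target):
#
#   @given(X=tensor(shapes=((3, 4),), qparams=qparams()))
#   def test_op(self, X):
#       X, (scale, zero_point, torch_type) = X
#       assume_not_overflowing(torch.from_numpy(X),
#                              (scale, zero_point, torch_type))
#       _some_quantized_op(X, scale, zero_point, torch_type)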

"""Strategy for generating the quantization parameters.

Args:
    dtypes: quantized data types to sample from.
    scale_min / scale_max: Min and max scales. If None, set to
        `torch.finfo(torch.float).eps` and `torch.finfo(torch.float).max`.
    zero_point_min / zero_point_max: Min and max for the zero point. If None,
        set to the minimum and maximum of the quantized data type.
        Note: The min and max are only valid if the zero_point is not enforced
              by the data type itself.

Generates:
    scale: Sampled scale.
    zero_point: Sampled zero point.
    quantized_type: Sampled quantized type.
"""
@st.composite
def qparams(draw, dtypes=None, scale_min=None, scale_max=None,
            zero_point_min=None, zero_point_max=None):
    if dtypes is None:
        dtypes = _ALL_QINT_TYPES
    if not isinstance(dtypes, (list, tuple)):
        dtypes = (dtypes,)
    quantized_type = draw(st.sampled_from(dtypes))

    _type_info = torch.iinfo(quantized_type)
    qmin, qmax = _type_info.min, _type_info.max

    # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`.
    _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type]
    if _zp_enforced is not None:
        zero_point = _zp_enforced
    else:
        _zp_min = qmin if zero_point_min is None else zero_point_min
        _zp_max = qmax if zero_point_max is None else zero_point_max
        zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max))

    if scale_min is None:
        scale_min = torch.finfo(torch.float).eps
    if scale_max is None:
        scale_max = torch.finfo(torch.float).max
    scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32))

    return scale, zero_point, quantized_type
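
# Usage sketch (illustrative): restrict the drawn parameters to a single dtype.
#
#   @given(qp=qparams(dtypes=torch.quint8, scale_max=100.0))
#   def test_qparams(self, qp):
#       scale, zero_point, torch_type = qp
#       assert torch_type == torch.quint8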

"""Strategy to create different shapes.
Args:
    min_dims / max_dims: minimum and maximum rank.
    min_side / max_side: minimum and maximum dimensions per rank.

Generates:
    Possible shapes for a tensor, constrained to the rank and dimensionality.

Example:
    # Generates 3D and 4D tensors.
    @given(Q=tensor(shapes=array_shapes(min_dims=3, max_dims=4)))
    def some_test(self, Q): ...
"""
@st.composite
def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
    """Return a strategy for array shapes (tuples of int >= 1)."""
    assert min_dims < 32
    if max_dims is None:
        max_dims = min(min_dims + 2, 32)
    assert max_dims < 32
    if max_side is None:
        max_side = min_side + 5
    candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
    if max_numel is not None:
        candidate = candidate.filter(lambda x: reduce(int.__mul__, x, 1) <= max_numel)
    return draw(candidate.map(tuple))
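
# Usage sketch (illustrative): `max_numel` bounds the total element count so
# generated tensors stay small.
#
#   shapes = array_shapes(min_dims=2, max_dims=4, max_side=8, max_numel=4096)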


"""Strategy for generating test cases for tensors.
The resulting tensor is in float32 format.

Args:
    shapes: Shapes under test for the tensor. Could be either a hypothesis
            strategy, or an iterable of different shapes to sample from.
    elements: Elements to generate from for the returned data type.
              If None, the strategy resolves to float within range [-1e6, 1e6].
    qparams: Instance of the qparams strategy. This is used to filter the
             tensor such that overflow would not happen.

Generates:
    X: Tensor of type float32. Note that NaN and +/-inf are not included.
    qparams: (If `qparams` arg is set) Quantization parameters for X.
        The returned parameters are `(scale, zero_point, quantization_type)`.
        If the `qparams` arg is None, returns None.
"""
@st.composite
def tensor(draw, shapes=None, elements=None, qparams=None, dtype=np.float32):
    if isinstance(shapes, SearchStrategy):
        _shape = draw(shapes)
    else:
        _shape = draw(st.sampled_from(shapes))
    if qparams is None:
        if elements is None:
            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
        X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
        assume(not (np.isnan(X).any() or np.isinf(X).any()))
        return X, None
    qparams = draw(qparams)
    if elements is None:
        min_value, max_value = _get_valid_min_max(qparams)
        elements = floats(min_value, max_value, allow_infinity=False,
                          allow_nan=False, width=32)
    X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
    # Recompute the scale and zero_point according to the X statistics.
    scale, zp = _calculate_dynamic_qparams(X, qparams[2])
    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
    if enforced_zp is not None:
        zp = enforced_zp
    return X, (scale, zp, qparams[2])
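
# Usage sketch (illustrative): quantizing the drawn array with the recomputed
# parameters.
#
#   @given(X=tensor(shapes=((2, 3),), qparams=qparams()))
#   def test_quantize(self, X):
#       X, (scale, zero_point, torch_type) = X
#       qX = torch.quantize_per_tensor(torch.from_numpy(X), scale,
#                                      zero_point, torch_type)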

@st.composite
def per_channel_tensor(draw, shapes=None, elements=None, qparams=None):
    if isinstance(shapes, SearchStrategy):
        _shape = draw(shapes)
    else:
        _shape = draw(st.sampled_from(shapes))
    if qparams is None:
        if elements is None:
            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
        X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
        assume(not (np.isnan(X).any() or np.isinf(X).any()))
        return X, None
    qparams = draw(qparams)
    if elements is None:
        min_value, max_value = _get_valid_min_max(qparams)
        elements = floats(min_value, max_value, allow_infinity=False,
                          allow_nan=False, width=32)
    X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
    # Recompute the scale and zero_point according to the X statistics.
    scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2])
    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
    if enforced_zp is not None:
        zp = enforced_zp
    # Permute to model quantization along an axis
    axis = int(np.random.randint(0, X.ndim, 1))
    permute_axes = np.arange(X.ndim)
    permute_axes[0] = axis
    permute_axes[axis] = 0
    X = np.transpose(X, permute_axes)

    return X, (scale, zp, axis, qparams[2])
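
# Usage sketch (illustrative): quantizing along the sampled axis; the scale
# and zero_point come back with one entry per channel.
#
#   @given(X=per_channel_tensor(shapes=((2, 3, 4),), qparams=qparams()))
#   def test_per_channel(self, X):
#       X, (scale, zp, axis, torch_type) = X
#       qX = torch.quantize_per_channel(torch.from_numpy(X),
#                                       torch.tensor(scale),
#                                       torch.tensor(zp), axis, torch_type)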

"""Strategy for generating test cases for tensors used in Conv.
The resulting tensors are in float32 format.

Args:
    spatial_dim: Spatial dimension for the feature maps. If given as an
                 iterable, randomly picks one from the pool to make it the
                 spatial dimension.
    batch_size_range: Range to generate `batch_size`.
                      Must be tuple of `(min, max)`.
    input_channels_per_group_range:
        Range to generate `input_channels_per_group`.
        Must be tuple of `(min, max)`.
    output_channels_per_group_range:
        Range to generate `output_channels_per_group`.
        Must be tuple of `(min, max)`.
    feature_map_range: Range to generate feature map size for each spatial_dim.
                       Must be tuple of `(min, max)`.
    kernel_range: Range to generate kernel size for each spatial_dim. Must be
                  tuple of `(min, max)`.
    max_groups: Maximum number of groups to generate.
    can_be_transposed: If True, the strategy may also draw weights shaped for
                       a transposed convolution.
    elements: Elements to generate from for the returned data type.
              If None, the strategy resolves to float within range [-1e6, 1e6].
    qparams: Strategy for quantization parameters for X, W, and b.
             Could be either a single strategy (used for all) or a list of
             three strategies for X, W, b.

Generates:
    (X, W, b, groups, tr): Tensors of type `float32` with the following drawn
    shapes:
        X: `(batch_size, input_channels) + feature_map_shape`
        W: `(output_channels, input_channels_per_group) + kernel_shape`
        b: `(output_channels,)`
        groups: Number of groups the input is divided into
        tr: Whether the weights are shaped for a transposed convolution

Note: X, W, b are tuples of (Tensor, qparams), where qparams could be either
      None or (scale, zero_point, quantized_type)

Example:
    @given(tensor_conv(
        spatial_dim=2,
        batch_size_range=(1, 3),
        input_channels_per_group_range=(1, 7),
        output_channels_per_group_range=(1, 7),
        feature_map_range=(6, 12),
        kernel_range=(3, 5),
        max_groups=4,
        elements=st.floats(-1.0, 1.0),
        qparams=qparams()
    ))
"""
@st.composite
def tensor_conv(
    draw, spatial_dim=2, batch_size_range=(1, 4),
    input_channels_per_group_range=(3, 7),
    output_channels_per_group_range=(3, 7), feature_map_range=(6, 12),
    kernel_range=(3, 7), max_groups=1, can_be_transposed=False,
    elements=None, qparams=None
):

    # Resolve the minibatch, in_channels, out_channels, iH/iW, kH/kW
    batch_size = draw(st.integers(*batch_size_range))
    input_channels_per_group = draw(
        st.integers(*input_channels_per_group_range))
    output_channels_per_group = draw(
        st.integers(*output_channels_per_group_range))
    groups = draw(st.integers(1, max_groups))
    input_channels = input_channels_per_group * groups
    output_channels = output_channels_per_group * groups

    if isinstance(spatial_dim, Iterable):
        spatial_dim = draw(st.sampled_from(spatial_dim))

    feature_map_shape = []
    for i in range(spatial_dim):
        feature_map_shape.append(draw(st.integers(*feature_map_range)))

    kernels = []
    for i in range(spatial_dim):
        kernels.append(draw(st.integers(*kernel_range)))

    tr = False
    weight_shape = (output_channels, input_channels_per_group) + tuple(kernels)
    bias_shape = output_channels
    if can_be_transposed:
        tr = draw(st.booleans())
        if tr:
            weight_shape = (input_channels, output_channels_per_group) + tuple(kernels)
            bias_shape = output_channels

    # Resolve the tensors
    if qparams is None:
        qparams = [None] * 3
    elif isinstance(qparams, (list, tuple)):
        assert len(qparams) == 3, "Need 3 qparams for X, w, b"
    else:
        qparams = [qparams] * 3

    X = draw(tensor(shapes=(
        (batch_size, input_channels) + tuple(feature_map_shape),),
        elements=elements, qparams=qparams[0]))
    W = draw(tensor(shapes=(weight_shape,), elements=elements,
                    qparams=qparams[1]))
    b = draw(tensor(shapes=(bias_shape,), elements=elements,
                    qparams=qparams[2]))

    return X, W, b, groups, tr
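
# Usage sketch (illustrative): each of X, W, b is an `(array, qparams_or_None)`
# pair as produced by the `tensor` strategy above.
#
#   @given(args=tensor_conv(max_groups=4, qparams=qparams()))
#   def test_conv(self, args):
#       (X, X_qp), (W, W_qp), (b, b_qp), groups, transposed = args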

# We set the deadline in the currently loaded profile.
# Creating (and loading) a separate profile overrides any settings the user
# already specified.
hypothesis_version = hypothesis.version.__version_info__
current_settings = settings._profiles[settings._current_profile].__dict__
current_settings['deadline'] = None
if (3, 16, 0) <= hypothesis_version < (5, 0, 0):
    current_settings['timeout'] = hypothesis.unlimited

def assert_deadline_disabled():
    if hypothesis_version < (3, 27, 0):
        import warnings
        warning_message = (
            "Your version of hypothesis is outdated. "
            "To avoid `DeadlineExceeded` errors, please update. "
            f"Current hypothesis version: {hypothesis.__version__}"
        )
        warnings.warn(warning_message)
    else:
        assert settings().deadline is None
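
# Usage sketch (illustrative): test modules can call this after importing this
# file to verify that the deadline override above took effect.
#
#   assert_deadline_disabled()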