xref: /aosp_15_r20/external/executorch/examples/models/llama/experimental/test_subclass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10
11from .subclass import from_float, to_float
12
13
class TestGGMLTensorSubclass(unittest.TestCase):
    """Round-trip checks for the ``from_float`` / ``to_float`` helpers."""

    def test_from_float_and_to_float(self) -> None:
        """Pack a single 32-element fp16 row and verify the unpacked values.

        Checks the packed tensor's dtype and element count, then confirms
        that unpacking reproduces the input to within one sixteenth of the
        input's value range.
        """
        # A locally seeded generator keeps the input reproducible from run
        # to run without disturbing the global RNG state of other tests.
        rng = torch.Generator().manual_seed(0)
        weight = torch.randn([1, 32], generator=rng, dtype=torch.float16)

        packed = from_float(weight)

        self.assertEqual(packed.dtype, torch.uint8)
        # expected size = (1 * sizeof(uint16_t)) + (32 * sizeof(uint8_t)/2) = 18
        self.assertEqual(packed.numel(), 18)

        unpacked = to_float(packed).reshape(weight.shape)

        # Allowed absolute error: 1/16 of the input's value range.
        # NOTE(review): presumably mirrors the 16 levels of a 4-bit encoding;
        # confirm against the packing code in subclass.py.
        tolerance = ((torch.max(weight) - torch.min(weight)) / 16).item()

        # Computed only to make a failure self-explanatory.
        max_error = (weight.float() - unpacked.float()).abs().max().item()

        self.assertTrue(
            torch.allclose(weight, unpacked, atol=tolerance),
            msg=(
                f"round-trip max abs error {max_error} "
                f"exceeds tolerance {tolerance}"
            ),
        )
29