# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from .subclass import from_float, to_float


class TestGGMLTensorSubclass(unittest.TestCase):
    """Round-trip tests for the ``from_float`` / ``to_float`` helpers."""

    def test_from_float_and_to_float(self) -> None:
        """Pack a 1x32 fp16 tensor, check the packed buffer, then unpack it.

        Asserts that the packed result is a ``uint8`` tensor holding 18
        elements, and that unpacking (reshaped back to the input shape)
        reproduces the input to within 1/16 of the input's value range.
        """
        # A locally seeded generator makes the random input -- and therefore
        # the tolerance-based assertion below -- reproducible across runs,
        # without mutating the global torch RNG state.
        generator = torch.Generator().manual_seed(0)
        weight = torch.randn([1, 32], dtype=torch.float16, generator=generator)

        packed = from_float(weight)

        self.assertEqual(packed.dtype, torch.uint8)
        # expected size = (1 * sizeof(uint16_t)) + (32 * sizeof(uint8_t)/2) = 18
        # NOTE(review): presumably one 16-bit scale followed by 32 4-bit
        # values packed two per byte -- confirm against subclass.from_float.
        self.assertEqual(packed.numel(), 18)

        unpacked = to_float(packed).reshape(weight.shape)

        # Allowed absolute error: 1/16 of the input's value range.
        # NOTE(review): presumably chosen to cover 4-bit quantization error;
        # confirm against the quantization scheme in subclass.py.
        tolerance = (torch.max(weight) - torch.min(weight)) / 16

        # NOTE(review): torch.allclose requires matching dtypes; this assumes
        # to_float returns float16 like the input -- confirm.
        self.assertTrue(torch.allclose(weight, unpacked, atol=tolerance.item()))