# Owner(s): ["module: unknown"] import itertools import logging import re import torch from torch import nn from torch.ao.pruning import ( BaseSparsifier, FakeSparsity, NearlyDiagonalSparsifier, WeightNormSparsifier, ) from torch.nn.utils.parametrize import is_parametrized from torch.testing._internal.common_pruning import ( ImplementedSparsifier, MockSparseLinear, SimpleLinear, ) from torch.testing._internal.common_utils import TestCase logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) class TestBaseSparsifier(TestCase): def test_constructor(self): # Cannot instantiate the abstract base self.assertRaises(TypeError, BaseSparsifier) # Can instantiate the model with no configs model = SimpleLinear() sparsifier = ImplementedSparsifier(test=3) sparsifier.prepare(model, config=None) assert len(sparsifier.groups) == 5 sparsifier.step() # Can instantiate the model with configs sparsifier = ImplementedSparsifier(test=3) sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) assert len(sparsifier.groups) == 1 assert sparsifier.groups[0]["tensor_fqn"] == "linear1.weight" assert "test" in sparsifier.groups[0] assert sparsifier.groups[0]["test"] == 3 def test_prepare_config(self): model = SimpleLinear() sparsifier = ImplementedSparsifier(test=3) # Make sure there are no parametrizations before `prepare` assert not hasattr(model.seq[0], "parametrizations") assert not hasattr(model.linear1, "parametrizations") assert not hasattr(model.linear2, "parametrizations") sparsifier.prepare( model, config=[ {"tensor_fqn": "seq.0.weight", "test": 42}, # No 'linear1' to make sure it will be skipped in the sparsification {"tensor_fqn": "linear2.weight"}, ], ) assert len(sparsifier.groups) == 2 # Check if default argument is not assigned if explicit assert sparsifier.groups[0]["tensor_fqn"] == "seq.0.weight" assert sparsifier.groups[0]["test"] == 42 # Check if FQN and module are pointing to the same location assert sparsifier.groups[1]["tensor_fqn"] == "linear2.weight" assert sparsifier.groups[1]["module"] == model.linear2 # Check if parameterizations are attached assert hasattr(model.seq[0], "parametrizations") assert not hasattr(model.linear1, "parametrizations") assert hasattr(model.linear2, "parametrizations") def test_step(self): model = SimpleLinear() sparsifier = ImplementedSparsifier(test=3) sparsifier.enable_mask_update = True sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) sparsifier.step() assert torch.all(model.linear1.parametrizations.weight[0].mask[0] == 0) def test_state_dict(self): step_count = 3 model0 = SimpleLinear() sparsifier0 = ImplementedSparsifier(test=3) sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}]) mask = model0.linear1.parametrizations["weight"][0].mask mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape) for step in range(step_count): sparsifier0.step() state_dict = sparsifier0.state_dict() # Check the expected keys in the state_dict assert "state" in state_dict assert "step_count" in state_dict["state"]["linear1.weight"] assert state_dict["state"]["linear1.weight"]["step_count"] == 3 assert "groups" in state_dict assert "test" in state_dict["groups"][0] assert "tensor_fqn" in state_dict["groups"][0] assert state_dict["groups"][0]["tensor_fqn"] == "linear1.weight" # Check loading static_dict creates an equivalent model model1 = SimpleLinear() sparsifier1 = ImplementedSparsifier() sparsifier1.prepare(model1, None) assert sparsifier0.state != sparsifier1.state # Make sure the 
masks are different in the beginning for mg in sparsifier0.groups: if mg["tensor_fqn"] == "linear1.weight": mask0 = mg["module"].parametrizations.weight[0].mask for mg in sparsifier1.groups: if mg["tensor_fqn"] == "linear1.weight": mask1 = mg["module"].parametrizations.weight[0].mask self.assertNotEqual(mask0, mask1) sparsifier1.load_state_dict(state_dict) # Make sure the states are loaded, and are correct assert sparsifier0.state == sparsifier1.state # Make sure the masks (and all dicts) are the same after loading assert len(sparsifier0.groups) == len(sparsifier1.groups) for idx in range(len(sparsifier0.groups)): mg0 = sparsifier0.groups[idx] mg1 = sparsifier1.groups[idx] for key in mg0.keys(): assert key in mg1 if key == "module": # We cannot compare modules as they are different param0 = mg0[key].parametrizations.weight[0] param1 = mg1[key].parametrizations.weight[0] assert hasattr(param0, "mask") assert hasattr(param1, "mask") self.assertEqual(param0.__dict__, param1.__dict__) else: assert mg0[key] == mg1[key] def test_convert(self): model = SimpleLinear() sparsifier = ImplementedSparsifier(test=3) sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) new_model = sparsifier.convert( model, mapping={nn.Linear: MockSparseLinear}, inplace=False ) assert isinstance(new_model.linear1, MockSparseLinear) assert isinstance(new_model.seq[0], nn.Linear) assert isinstance(new_model.linear2, nn.Linear) def test_mask_squash(self): model = SimpleLinear() sparsifier = ImplementedSparsifier(test=3) sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) assert hasattr(model.linear1.parametrizations.weight[0], "mask") assert is_parametrized(model.linear1, "weight") assert not is_parametrized(model.seq[0], "weight") sparsifier.squash_mask() assert not is_parametrized(model.seq[0], "weight") assert not is_parametrized(model.linear1, "weight") def test_mask_squash_with_params1(self): model = SimpleLinear() sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) sparsifier.prepare( model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] ) sparsifier.squash_mask( params_to_keep_per_layer={"linear1": ("foo", "bar"), "seq.0": ("baz",)} ) assert not is_parametrized(model.seq[0], "weight") assert not is_parametrized(model.linear1, "weight") assert hasattr(model.seq[0], "sparse_params") assert hasattr(model.linear1, "sparse_params") assert model.seq[0].sparse_params.get("foo", None) is None assert model.seq[0].sparse_params.get("bar", None) is None assert model.seq[0].sparse_params.get("baz", None) == 1 assert model.linear1.sparse_params.get("foo", None) == 3 assert model.linear1.sparse_params.get("bar", None) == 2 assert model.linear1.sparse_params.get("baz", None) is None def test_mask_squash_with_params2(self): model = SimpleLinear() sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) sparsifier.prepare( model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] ) sparsifier.squash_mask(params_to_keep=("foo", "bar")) assert not is_parametrized(model.seq[0], "weight") assert not is_parametrized(model.linear1, "weight") assert hasattr(model.seq[0], "sparse_params") assert hasattr(model.linear1, "sparse_params") assert model.seq[0].sparse_params.get("foo", None) == 3 assert model.seq[0].sparse_params.get("bar", None) == 2 assert model.seq[0].sparse_params.get("baz", None) is None assert model.linear1.sparse_params.get("foo", None) == 3 assert model.linear1.sparse_params.get("bar", None) == 2 assert model.linear1.sparse_params.get("baz", None) is None def 
test_mask_squash_with_params3(self): model = SimpleLinear() sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) sparsifier.prepare( model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] ) sparsifier.squash_mask( params_to_keep=("foo", "bar"), params_to_keep_per_layer={"seq.0": ("baz",)} ) assert not is_parametrized(model.seq[0], "weight") assert not is_parametrized(model.linear1, "weight") assert hasattr(model.seq[0], "sparse_params") assert hasattr(model.linear1, "sparse_params") assert model.seq[0].sparse_params.get("foo", None) == 3 assert model.seq[0].sparse_params.get("bar", None) == 2 assert model.seq[0].sparse_params.get("baz", None) == 1 assert model.linear1.sparse_params.get("foo", None) == 3 assert model.linear1.sparse_params.get("bar", None) == 2 assert model.linear1.sparse_params.get("baz", None) is None class TestWeightNormSparsifier(TestCase): def test_constructor(self): model = SimpleLinear() sparsifier = WeightNormSparsifier() sparsifier.prepare(model, config=None) for g in sparsifier.groups: assert isinstance(g["module"], nn.Linear) # The groups are unordered assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2") def test_step(self): model = SimpleLinear() sparsifier = WeightNormSparsifier(sparsity_level=0.5) sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) for g in sparsifier.groups: # Before step module = g["module"] assert ( 1.0 - module.parametrizations["weight"][0].mask.mean() ) == 0 # checking sparsity level is 0 sparsifier.enable_mask_update = True sparsifier.step() self.assertAlmostEqual( model.linear1.parametrizations["weight"][0].mask.mean().item(), 0.5, places=2, ) for g in sparsifier.groups: # After step module = g["module"] assert ( 1.0 - module.parametrizations["weight"][0].mask.mean() ) > 0 # checking sparsity level has increased # Test if the mask collapses to all zeros if the weights are randomized iters_before_collapse = 1000 for _ in range(iters_before_collapse): model.linear1.weight.data = torch.randn(model.linear1.weight.shape) sparsifier.step() for g in sparsifier.groups: # After step module = g["module"] assert ( 1.0 - module.parametrizations["weight"][0].mask.mean() ) > 0 # checking sparsity level did not collapse def test_step_2_of_4(self): model = SimpleLinear() sparsifier = WeightNormSparsifier( sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 ) sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) sparsifier.step() # make sure the sparsity level is approximately 50% mask = model.linear1.parametrizations["weight"][0].mask.to( torch.float ) # mean works on float only self.assertAlmostEqual(mask.mean().item(), 0.5, places=2) # Make sure each block has exactly 50% zeros module = sparsifier.groups[0]["module"] mask = module.parametrizations["weight"][0].mask for row in mask: for idx in range(0, len(row), 4): block = row[idx : idx + 4] block, _ = block.sort() assert (block[:2] == 0).all() assert (block[2:] != 0).all() def test_prepare(self): model = SimpleLinear() sparsifier = WeightNormSparsifier() sparsifier.prepare(model, config=None) for g in sparsifier.groups: module = g["module"] # Check mask exists assert hasattr(module.parametrizations["weight"][0], "mask") # Check parametrization exists and is correct assert is_parametrized(module, "weight") assert type(module.parametrizations.weight[0]) == FakeSparsity def test_mask_squash(self): model = SimpleLinear() sparsifier = WeightNormSparsifier() sparsifier.prepare(model, config=None) 
sparsifier.squash_mask() for g in sparsifier.groups: module = g["module"] assert not is_parametrized(module, "weight") assert not hasattr(module, "mask") def test_sparsity_levels(self): sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0] sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)] zeros_per_blocks = [0, 1, 2, 3, 4] testcases = itertools.tee( itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks) ) # Create a config and model with all the testcases model = nn.Sequential() sparsifier = WeightNormSparsifier() sparsity_per_layer_config = [] p = re.compile(r"[-\.\s]") for sl, sbs, zpb in testcases[0]: # Make sure the number of zeros is not > values in a block if zpb > sbs[0] * sbs[1]: continue layer_name = f"{sl}_{sbs}_{zpb}" layer_name = p.sub("_", layer_name) layer = nn.Linear(12, 12, bias=False) layer.weight = nn.Parameter(torch.ones(12, 12)) model.add_module(layer_name, layer) config = { "tensor_fqn": layer_name + ".weight", "sparsity_level": sl, "sparse_block_shape": sbs, "zeros_per_block": zpb, } sparsity_per_layer_config.append(config) sparsifier.prepare(model, sparsity_per_layer_config) sparsifier.step() sparsifier.squash_mask() model.eval() for sl, sbs, zpb in testcases[1]: if zpb > sbs[0] * sbs[1]: continue layer_name = f"{sl}_{sbs}_{zpb}" layer_name = p.sub("_", layer_name) layer = getattr(model, layer_name) # Level of sparsity is achieved sparse_mask = (layer.weight == 0).float() if zpb == 0: assert sparse_mask.mean() == 0 else: # Ratio of individual zeros in the tensor true_sl = min(max(sl, 0.0), 1.0) true_sl = true_sl * zpb / sbs[0] / sbs[1] assert sparse_mask.mean() == true_sl class TestNearlyDiagonalSparsifier(TestCase): def test_constructor(self): model = SimpleLinear() sparsifier = NearlyDiagonalSparsifier(nearliness=1) sparsifier.prepare(model, config=None) for g in sparsifier.groups: assert isinstance(g["module"], nn.Linear) # The groups are unordered assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2") def test_step(self): model = SimpleLinear() sparsifier = NearlyDiagonalSparsifier(nearliness=1) sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) for g in sparsifier.groups: # Before step module = g["module"] assert ( 1.0 - module.parametrizations["weight"][0].mask.mean() ) == 0 # checking sparsity level is 0 sparsifier.enable_mask_update = True sparsifier.step() mask = module.parametrizations["weight"][0].mask height, width = mask.shape assert torch.all(mask == torch.eye(height, width)) for g in sparsifier.groups: # After step module = g["module"] assert ( 1.0 - module.parametrizations["weight"][0].mask.mean() ) > 0 # checking sparsity level has increased # Test if the mask collapses to all zeros if the weights are randomized iters_before_collapse = 1000 for _ in range(iters_before_collapse): model.linear1.weight.data = torch.randn(model.linear1.weight.shape) sparsifier.step() for g in sparsifier.groups: # After step module = g["module"] assert ( 1.0 - module.parametrizations["weight"][0].mask.mean() ) > 0 # checking sparsity level did not collapse def test_prepare(self): model = SimpleLinear() sparsifier = NearlyDiagonalSparsifier(nearliness=1) sparsifier.prepare(model, config=None) for g in sparsifier.groups: module = g["module"] # Check mask exists assert hasattr(module.parametrizations["weight"][0], "mask") # Check parametrization exists and is correct assert is_parametrized(module, "weight") assert type(module.parametrizations.weight[0]) == FakeSparsity def test_mask_squash(self): model = SimpleLinear() 
sparsifier = NearlyDiagonalSparsifier(nearliness=1) sparsifier.prepare(model, config=None) sparsifier.step() sparsifier.squash_mask() for g in sparsifier.groups: module = g["module"] assert not is_parametrized(module, "weight") assert not hasattr(module, "mask") weights = module.weight height, width = weights.shape assert torch.all( weights == torch.eye(height, width) * weights ) # only diagonal to be present def test_sparsity_levels(self): nearliness_levels = list(range(-1, 100)) model = nn.Sequential() p = re.compile(r"[-\.\s]") for nearliness in nearliness_levels: sparsifier = NearlyDiagonalSparsifier(nearliness=1) layer_name = f"{nearliness}" layer_name = p.sub("_", layer_name) layer = nn.Linear(32, 32, bias=False) layer.weight = nn.Parameter(torch.ones(32, 32)) width, height = layer.weight.shape model.add_module(layer_name, layer) config = {"tensor_fqn": layer_name + ".weight", "nearliness": nearliness} sparsifier.prepare(model, [config]) # should raise a ValueError when nearliness arg is illegal if (nearliness > 0 and nearliness % 2 == 0) or ( nearliness // 2 >= min(width, height) ): with self.assertRaises(ValueError): sparsifier.step() else: sparsifier.step() sparsifier.squash_mask() model.eval() layer = getattr(model, layer_name) # verify that mask created corresponds to the nearliness self._verify_nearliness(layer.weight, nearliness) # helper function to verify nearliness of a mask def _verify_nearliness(self, mask: torch.Tensor, nearliness: int): if nearliness <= 0: assert torch.all(mask == torch.zeros(mask.shape[0], mask.shape[1])) else: height, width = mask.shape dist_to_diagonal = nearliness // 2 for row in range(0, height): for col in range(0, width): if abs(row - col) <= dist_to_diagonal: assert mask[row, col] == 1 else: assert mask[row, col] == 0