# Owner(s): ["module: unknown"] import functools import unittest import torch import torch.nn.functional as F import torch.utils.flop_counter from torch._subclasses.fake_tensor import FakeTensorMode from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, ) from torch.testing._internal.common_utils import ( run_tests, TEST_WITH_TORCHDYNAMO, TestCase, skipIfRocm, ) try: from torchvision import models as torchvision_models HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") HAS_CUDA = torch.cuda.is_available() def FlopCounterMode(*args, **kwargs): return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False) def get_total_flops(mode): return str(sum(v for _, v in mode.flop_counts["Global"].items())) def T(*shape, requires_grad=False): return torch.randn(*shape, requires_grad=requires_grad) @unittest.skipIf( TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now" ) class TestFlopCounter(TestCase): def test_flop_counter_variety(self): mod = torch.nn.Linear(9, 10) with FlopCounterMode() as mode: torch.mm(T(4, 5), T(5, 6)) torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5) torch.matmul(T(5, 6), T(6, 7)) torch.einsum("ab,bc->ac", T(6, 7), T(7, 8)) mod(T(8, 9)) self.assertExpectedInline(get_total_flops(mode), """3012""") def test_op(self): with FlopCounterMode() as mode: torch.mm(T(4, 5), T(5, 6)) # 4 * 6 * 2 * 5 = 240 self.assertExpectedInline(get_total_flops(mode), """240""") with mode: torch.bmm(T(3, 4, 5), T(3, 5, 6)) # 3 * 4 * 6 * 2 * 5 = 720 self.assertExpectedInline(get_total_flops(mode), """720""") with mode: torch.addmm(T(4, 6), T(4, 5), T(5, 6)) torch.addmm(T(4, 1), T(4, 5), T(5, 6)) torch.addmm(T(6), T(4, 5), T(5, 6)) # 4 * 6 * 2 * 5 = 240 self.assertExpectedInline(get_total_flops(mode), """720""") with mode: torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6)) # 3 * 4 * 6 * 2 * 5 = 720 self.assertExpectedInline(get_total_flops(mode), """720""") with mode: torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1) # out_image_size = 2 * 5 * 5 # kernel_size = 4 * 4 # c_out = 6 # c_in = 3 # out_image_size * kernel_size * c_out * 2 * c_in # NB: I don't think this properly accounts for padding? self.assertExpectedInline(get_total_flops(mode), """28800""") with mode: torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1) # out_image_size = 2 * 5 # kernel_size = 4 # c_out = 6 # c_in = 3 # out_image_size * kernel_size * c_out * 2 * c_in # NB: I don't think this properly accounts for padding? 
self.assertExpectedInline(get_total_flops(mode), """1440""") def test_backward(self): with FlopCounterMode() as mode: a = T(4, 5, requires_grad=True) a = torch.mm(a, T(5, 6)) a = a.unsqueeze(0).expand(7, 4, 6) a = torch.bmm(a, T(7, 6, 7)) a.sum().backward() self.assertExpectedInline(get_total_flops(mode), """5184""") def test_backward_reset(self): with FlopCounterMode() as mode: a = T(4, 5, requires_grad=True) a.mm(a.t()).sum().backward() a.mm(a.t()).sum().backward() self.assertExpectedInline(get_total_flops(mode), """960""") def test_torchscript(self): def foo(x): return torch.mm(x, x) with FlopCounterMode() as mode: foo(T(5, 5)) unscripted_flops = get_total_flops(mode) ts_foo = torch.jit.script(foo) with mode: ts_foo(T(5, 5)) self.assertEqual(unscripted_flops, get_total_flops(mode)) def test_autograd_op(self): class _CustomOp(torch.autograd.Function): @staticmethod def forward(ctx, input: torch.Tensor) -> torch.Tensor: return torch.mm(input, input) @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: return torch.mm(grad_output, grad_output) + torch.mm( grad_output, grad_output ) a = T(5, 5, requires_grad=True) with FlopCounterMode() as mode: a = _CustomOp.apply(a) a.sum().backward() self.assertExpectedInline(get_total_flops(mode), """750""") def test_conv_backwards_as_decomposition(self): # [conv backwards decomposition as conv forwards] class onlyConvs(torch.autograd.Function): @staticmethod def forward(inp, weight, transposed): if not transposed: return F.conv1d(inp, weight) else: return F.conv_transpose1d(inp, weight) @staticmethod def setup_context(ctx, inputs, output): inp, weight, transposed = inputs ctx.save_for_backward(inp, weight) ctx.transposed = transposed @staticmethod def backward(ctx, grad_out): inp, weight = ctx.saved_tensors if not ctx.transposed: grad_inp = F.conv_transpose1d(grad_out, weight) grad_weight = F.conv1d(inp, grad_out) return grad_inp, grad_weight, None else: grad_inp = F.conv1d(grad_out, weight) grad_weight = F.conv1d( grad_out.transpose(1, 0), inp.transpose(1, 0) ) return grad_inp, grad_weight.transpose(1, 0), None from torch.func import grad x = torch.randn(2, 3, 16, dtype=torch.float64) weight = torch.randn(3, 4, 4, dtype=torch.float64) def boring_conv(x, weight, transposed): if not transposed: return F.conv1d(x, weight).pow(2).sum() else: return F.conv_transpose1d(x, weight).pow(2).sum() def only_convs(x, weight, transposed): return onlyConvs.apply(x, weight, transposed).pow(2).sum() boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True) fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True) self.assertEqual(boring_grads, fun_grads) def test_convs(self): def assert_equivalence(f, expected_forward=None): with FlopCounterMode() as mode: f() conv_forward_flops = mode.get_flop_counts()["Global"][ torch.ops.aten.convolution ] conv_backward_flops = mode.get_flop_counts()["Global"][ torch.ops.aten.convolution_backward ] self.assertEqual(conv_forward_flops * 2, conv_backward_flops) if expected_forward is not None: self.assertEqual(conv_forward_flops, expected_forward) x = torch.rand(1, 1, 2, 2, requires_grad=True) weight = torch.randn(1, 1, 2, 2, requires_grad=True) assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32) x = torch.rand(1, 1, 2, 2, requires_grad=True) weight = torch.randn(1, 1, 1, 1, requires_grad=True) assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8) for in_channels, out_channels, groups in [ (1, 1, 1), (1, 3, 1), (3, 1, 1), (3, 7, 1), (2, 4, 2), (4, 2, 2), ]: x 
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())

            transposed_weight = torch.randn(
                in_channels, out_channels, 2, 2, requires_grad=True
            )
            assert_equivalence(
                lambda: F.conv_transpose2d(x, transposed_weight).sum().backward()
            )

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        with FlopCounterMode(resnet18) as mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")

        layer1_conv_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution_backward
        ]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        with FlopCounterMode() as mode:
            for i in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(
            custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}
        )
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

        def count(*args, out_val):
            return out_val.numel()

        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        with FlopCounterMode() as mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
            backend,
            with_backward=False,
        ):
            query = torch.randn(
                batch_size,
                n_heads,
                seq_len_q,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            key = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            value = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim_v,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0, is_causal=True
                )
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_q,
            head_dim,
            head_dim,
            dtype,
        )

        flops = [
            run_uniform_flops(backend, with_backward=False)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
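        # Back-of-envelope check: the forward pass is dominated by two batched
        # matmuls (q @ k^T and attn @ v), each roughly
        # 2 * batch_size * n_heads * seq_len_q * seq_len_k * head_dim
        # = 2 * 4 * 8 * 128 * 128 * 64 = 67108864 FLOPs, i.e. 134217728 in total,
        # which lines up with the expected forward count asserted below.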
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [
            run_uniform_flops(backend, with_backward=True)
            for backend in ["math", "flash", "mem_efficient"]
        ]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        )
        # Flash does not support non-uniform attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [
            run_nonuniform_flops(backend, with_backward=False)
            for backend in non_uniform_backends
        ]
        flops_fw_math, flops_fw_efficient = flops

        self.assertEqual(flops_fw_math, flops_fw_efficient)
        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [
            run_nonuniform_flops(backend, with_backward=True)
            for backend in non_uniform_backends
        ]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa_nested_tensor(self):
        def get_flops(q, k, v, backend, with_backward=False):
            mode = FlopCounterMode()

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=False
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=False, enable_mem_efficient=True
                )

            with backend, mode:
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0, is_causal=True
                )
                if with_backward:
                    if out.is_nested:
                        out.values().sum().backward()
                    else:
                        out.sum().backward()
            return int(get_total_flops(mode))

        def get_nested_inputs(
            batch_size,
            n_heads,
            max_seq_len_q,
            max_seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        ):
            q_lengths = torch.tensor(
                [
                    max_seq_len_q // 4,
                    max_seq_len_q // 4 * 2,
                    max_seq_len_q // 4 * 3,
                    max_seq_len_q // 4 * 4,
                ]
            )
            k_lengths = torch.tensor(
                [
                    max_seq_len_k // 4,
                    max_seq_len_k // 4 * 2,
                    max_seq_len_k // 4 * 3,
                    max_seq_len_k // 4 * 4,
                ]
            )
            q_offsets, k_offsets = (
                torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda()
                for lengths in (q_lengths, k_lengths)
            )
            q_values = torch.randn(
                q_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            k_values = torch.randn(
                k_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            v_values = torch.randn(
                k_offsets[-1],
                head_dim_v * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )

            q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets)
            k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets)
            v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets)

            q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2)

            return q, k, v

        def get_dense_flops(q, k, v, backend, with_backward=False):
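            # Dense reference: split the nested batch into per-sequence dense
            # tensors, count FLOPs for each sequence independently, and sum them.
            # The nested-tensor count from get_flops should match this total.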
            def split_tensor(x):
                return (
                    y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True)
                    for y in x.transpose(1, 2).unbind(0)
                )

            q_tensors = split_tensor(q)
            k_tensors = split_tensor(k)
            v_tensors = split_tensor(v)
            flops = 0
            for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors):
                flops += get_flops(
                    q_i, k_i, v_i, backend=backend, with_backward=with_backward
                )

            return flops

        uniform_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 128,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        # max_seq_len_q != max_seq_len_k doesn't work for flash attention with dense tensors.
        differing_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 256,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_nested_attention_fake_tensors(self):
        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
        max_seqlen = 40
        with FakeTensorMode() as fake_mode:
            fake_x = fake_mode.from_tensor(x)
            fake_offsets = fake_mode.from_tensor(offsets)
            with FlopCounterMode() as fake_flop_counter_mode:
                torch.ops.aten._flash_attention_forward(
                    fake_x,
                    fake_x,
                    fake_x,
                    fake_offsets,
                    fake_offsets,
                    max_seqlen,
                    max_seqlen,
                    0.0,
                    False,
                    False,
                )

        dense_x = torch.randn(
            4, 40, 4, 16, dtype=torch.bfloat16, device="cuda"
        ).transpose(1, 2)

        with FlopCounterMode() as real_flop_counter_mode:
            torch.ops.aten._flash_attention_forward(
                dense_x,
                dense_x,
                dense_x,
                None,
                None,
                max_seqlen,
                max_seqlen,
                0.0,
                False,
                False,
            )

        self.assertEqual(
            int(get_total_flops(fake_flop_counter_mode)),
            int(get_total_flops(real_flop_counter_mode)),
        )

    def test_addmm_out(self):
        def f(x):
            y = torch.zeros(10, 10)
            return torch.mm(x, x, out=y)

        with FlopCounterMode() as mode:
            f(torch.randn(10, 10))

        self.assertExpectedInline(get_total_flops(mode), """2000""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        with FlopCounterMode() as mode:
            self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1)
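            # While active, the mode installs one global forward pre-hook and one
            # global forward hook (used to attribute FLOPs to modules); both should
            # be gone once the context manager exits, as checked after the block.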
            self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0)
        self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        with FlopCounterMode() as mode:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000"""
        )

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        with FlopCounterMode() as mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000"""
        )

    def test_warning(self):
        mod = torch.nn.Linear(2, 2)
        with self.assertWarnsRegex(UserWarning, "not needed"):
            FlopCounterMode(mod)

    def test_custom_op(self):
        from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        called = 0

        with self.assertRaisesRegex(
            ValueError, "expected each target to be OpOverloadPacket"
        ):
            register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)

        @register_flop_formula(torch.ops.mylib.foo)
        def formula(*args, **kwargs):
            nonlocal called
            called += 1
            return 9001

        x = torch.randn(3)
        with FlopCounterMode(display=False) as mode:
            y = foo(x)

        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")


if __name__ == "__main__":
    run_tests()