Home
last modified time | relevance | path

Searched refs:UnitModule (Results 1 – 8 of 8) sorted by relevance

/aosp_15_r20/external/pytorch/test/distributed/_composable/fully_shard/
H A Dtest_fully_shard_init.py19 UnitModule,
60 ModuleWrapPolicy({UnitModule}),
181 policy=ModuleWrapPolicy({UnitModule}),
198 policy = ModuleWrapPolicy({UnitModule})
250 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
256 policy=ModuleWrapPolicy({UnitModule}),
290 fully_shard(composable_module, policy=ModuleWrapPolicy({UnitModule}))
H A Dtest_fully_shard_model_checkpoint.py19 UnitModule,
61 fully_shard(save_composable, policy=ModuleWrapPolicy({UnitModule}))
68 copy.deepcopy(local_model), policy=ModuleWrapPolicy({UnitModule})
H A Dtest_fully_shard_util.py12 from torch.testing._internal.common_dist_composable import CompositeModel, UnitModule
44 policy=ModuleWrapPolicy({UnitModule}),
H A Dtest_fully_shard_runtime.py21 UnitModule,
64 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
70 policy=ModuleWrapPolicy({UnitModule}),
H A Dtest_fully_shard_optim_checkpoint.py14 UnitModule,
90 fully_shard(composable_model, policy=ModuleWrapPolicy({UnitModule}))
/aosp_15_r20/external/pytorch/torch/testing/_internal/
H A Dcommon_dist_composable.py11 class UnitModule(nn.Module): class
30 self.u1 = UnitModule(device)
31 self.u2 = UnitModule(device)
57 self.u1 = UnitModule(device)
58 self.u2 = UnitModule(device)
/aosp_15_r20/external/pytorch/test/distributed/checkpoint/
H A Dtest_state_dict.py44 UnitModule,
181 strategy = {UnitModule}
217 "wrapping": [tuple(), (nn.Linear, UnitModule)],
325 fully_shard(dist_model, policy=ModuleWrapPolicy({UnitModule}))
330 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
440 auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
552 UnitModule.get_extra_state = get_extra_state
553 UnitModule.set_extra_state = set_extra_state
/aosp_15_r20/external/pytorch/test/distributed/_composable/
H A Dtest_compose.py18 UnitModule,
87 model = UnitModule(device=torch.device("cuda"))
303 fully_shard(test_model, policy=ModuleWrapPolicy({UnitModule}))
314 fully_shard(test_model.u2, policy=ModuleWrapPolicy({UnitModule}))