Home
last modified time | relevance | path

Searched refs:load_tensor (Results 1 – 8 of 8) sorted by relevance

/aosp_15_r20/external/pytorch/test/distributed/checkpoint/
test_file_system_checkpoint_cpu.py
195 def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor: member in TestDistributedReshardOnLoad
301 store_tensor = self.load_tensor(model_to_save.sharded_tensor)
303 load_tensor = self.load_tensor(model_to_load.sharded_tensor)
307 torch.allclose(store_tensor, load_tensor),
351 store_tensor = self.load_tensor(model_to_save.sharded_tensor)
352 load_tensor = self.load_tensor(model_to_load.sharded_tensor)
355 self.assertTrue(torch.allclose(store_tensor, load_tensor))
448 save_dict_sharded = self.load_tensor(save_dict["sharded"])
449 load_dict_replicated = self.load_tensor(load_dict["replicated"])
test_file_system_checkpoint.py
218 def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor: member in TestDistributedReshardOnLoad
330 store_tensor = self.load_tensor(model_to_save.sharded_tensor)
332 load_tensor = self.load_tensor(model_to_load.sharded_tensor)
336 torch.allclose(store_tensor, load_tensor),
385 store_tensor = self.load_tensor(model_to_save.sharded_tensor)
386 load_tensor = self.load_tensor(model_to_load.sharded_tensor)
389 self.assertTrue(torch.allclose(store_tensor, load_tensor))
483 save_dict_sharded = self.load_tensor(save_dict["sharded"])
484 load_dict_replicated = self.load_tensor(load_dict["replicated"])
/aosp_15_r20/external/pytorch/test/
test_content_store.py
92 x2 = torch.ops.debugprims.load_tensor.default(
96 x3 = torch.ops.debugprims.load_tensor.default(
112 x4 = torch.ops.debugprims.load_tensor.default(
119 x5 = torch.ops.debugprims.load_tensor.default(
125 x6 = torch.ops.debugprims.load_tensor.default(
/aosp_15_r20/external/pytorch/torch/distributed/fsdp/
_state_dict_utils.py
478 load_tensor = state_dict[fqn]
480 load_tensor, ShardedTensor
487 shards = load_tensor.local_shards()
490 load_tensor = shards[0].tensor
495 assert load_tensor.numel() < flat_param.numel(), (
499 load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
501 load_tensor = flat_param
503 state_dict[fqn] = load_tensor
/aosp_15_r20/external/pytorch/torch/_functorch/
fx_minifier.py
81 and node.target is torch.ops.debugprims.load_tensor.default
119 node.target = torch.ops.debugprims.load_tensor.default
388 torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs)
/aosp_15_r20/external/pytorch/torch/onnx/_internal/
onnx_proto_utils.py
114 tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined]
119 tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined]
/aosp_15_r20/external/pytorch/torch/package/
package_importer.py
241 def load_tensor(dtype, size, key, location, restore_location): function
265 load_tensor(
/aosp_15_r20/external/pytorch/torch/
serialization.py
1467 def load_tensor(dtype, numel, key, location): function
1511 typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))