xref: /aosp_15_r20/external/pytorch/c10/core/StorageImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/StorageImpl.h>
2 #include <c10/util/flat_hash_map.h>
3 
4 namespace c10 {
5 
6 // The array to save function pointer for custom storageImpl create.
7 C10_API std::array<StorageImplCreateHelper, at::COMPILE_TIME_MAX_DEVICE_TYPES>
8     StorageImplCreate;
9 
10 // A allowlist of device type, currently available is PrivateUse1
GetBackendMetaAllowlist()11 inline ska::flat_hash_set<c10::DeviceType>& GetBackendMetaAllowlist() {
12   static ska::flat_hash_set<c10::DeviceType> DeviceTypeAllowList{
13       DeviceType::PrivateUse1};
14   return DeviceTypeAllowList;
15 }
16 
throwNullDataPtrError()17 void throwNullDataPtrError() {
18   TORCH_CHECK(
19       false,
20       "Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). "
21       "If you're using torch.compile/export/fx, it is likely that we are erroneously "
22       "tracing into a custom kernel. To fix this, please wrap the custom kernel into "
23       "an opaque custom op. Please see the following for details: "
24       "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html");
25 }
26 
27 // NOTE: [FakeTensor.data_ptr deprecation]
28 // Today:
29 // - FakeTensor.data_ptr errors out in torch.compile.
30 // - FakeTensor.data_ptr raises the following deprecation warning otherwise.
31 // - the following deprecation warning is only for FakeTensor (for now).
32 //   In the future we can consider extending to more wrapper Tensor subclasses.
warnDeprecatedDataPtr()33 void warnDeprecatedDataPtr() {
34   TORCH_WARN_ONCE(
35       "Accessing the data pointer of FakeTensor is deprecated and will error in "
36       "PyTorch 2.5. This is almost definitely a bug in your code and will "
37       "cause undefined behavior with subsystems like torch.compile. "
38       "Please wrap calls to tensor.data_ptr() in an opaque custom op; "
39       "If all else fails, you can guard accesses to tensor.data_ptr() on "
40       "isinstance(tensor, FakeTensor).")
41 }
42 
SetStorageImplCreate(DeviceType t,StorageImplCreateHelper fptr)43 void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
44   // Allowlist verification.
45   // Only if the devicetype is in the allowlist,
46   // we allow the extension to be registered for storageImpl create.
47   const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
48   TORCH_CHECK(
49       DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
50       "It is only allowed to register the storageImpl create method ",
51       "for PrivateUse1. ",
52       "If you have related storageImpl requirements, ",
53       "please expand the allowlist");
54   // Register function pointer.
55   int device_type = static_cast<int>(t);
56   TORCH_CHECK(
57       StorageImplCreate[device_type] == nullptr,
58       "The StorageImplCreate function pointer for ",
59       t,
60       " has been registered.");
61   StorageImplCreate[device_type] = fptr;
62 }
63 
GetStorageImplCreate(DeviceType t)64 StorageImplCreateHelper GetStorageImplCreate(DeviceType t) {
65   int device_type = static_cast<int>(t);
66   return StorageImplCreate[device_type];
67 }
68 
make_storage_impl(c10::StorageImpl::use_byte_size_t use_byte_size,c10::SymInt size_bytes,c10::DataPtr data_ptr,c10::Allocator * allocator,bool resizable,std::optional<at::Device> device_opt)69 c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
70     c10::StorageImpl::use_byte_size_t use_byte_size,
71     c10::SymInt size_bytes,
72     c10::DataPtr data_ptr,
73     c10::Allocator* allocator,
74     bool resizable,
75     std::optional<at::Device> device_opt) {
76   // This will be non-nullptr only when there is a custom StorageImpl
77   // constructor for the given device
78   c10::StorageImplCreateHelper fptr = nullptr;
79   if (device_opt.has_value()) {
80     // We only need to check this here as this is the only case where we can
81     // have a device that is not CPU (and thus for which the StorageImpl
82     // constructor can be overwritten).
83     fptr = c10::GetStorageImplCreate(device_opt.value().type());
84   }
85 
86   if (fptr != nullptr) {
87     return fptr(
88         use_byte_size,
89         std::move(size_bytes),
90         std::move(data_ptr),
91         allocator,
92         resizable);
93   }
94 
95   // Create a c10::StorageImpl object.
96   if (data_ptr != nullptr) {
97     return c10::make_intrusive<c10::StorageImpl>(
98         use_byte_size,
99         std::move(size_bytes),
100         std::move(data_ptr),
101         allocator,
102         resizable);
103   }
104   return c10::make_intrusive<c10::StorageImpl>(
105       use_byte_size, std::move(size_bytes), allocator, resizable);
106 }
107 
108 } // namespace c10
109