xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/VulkanPackedContext.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <torch/custom_class.h>
6 
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 namespace ops {
11 
12 class VulkanPackedContext {
13  protected:
14   c10::impl::GenericList packed_;
15 
16  public:
VulkanPackedContext()17   VulkanPackedContext() : packed_{c10::AnyType::get()} {}
18   VulkanPackedContext(const VulkanPackedContext&) = default;
19   VulkanPackedContext(VulkanPackedContext&&) = default;
20 
get_val(int64_t i)21   inline const c10::IValue get_val(int64_t i) const {
22     return packed_.get(i);
23   }
24 
set_val(int64_t i,const c10::IValue & val)25   inline void set_val(int64_t i, const c10::IValue& val) const {
26     return packed_.set(i, val);
27   }
28 
29   virtual const c10::impl::GenericList unpack() const = 0;
30 
31   virtual ~VulkanPackedContext() = default;
32 };
33 
34 } // namespace ops
35 } // namespace vulkan
36 } // namespace native
37 } // namespace at
38 
39 #endif /* USE_VULKAN_API */
40