xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/VmapInterpreter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/functorch/Interpreter.h>
3 
4 namespace at::functorch {
5 
6 // This is the interpreter that handles the functionalize() transform.
7 // See NOTE: [functorch interpreter stack] for more details.
8 
9 struct VmapInterpreterPtr {
VmapInterpreterPtrVmapInterpreterPtr10   explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); }
keyVmapInterpreterPtr11   TransformType key() const { return base_->key(); }
levelVmapInterpreterPtr12   int64_t level() const { return base_->level(); }
13   void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
14   void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
batchSizeVmapInterpreterPtr15   c10::SymInt batchSize() const {
16     return std::get<VmapInterpreterMeta>(base_->meta()).batchSize_;
17   }
randomnessVmapInterpreterPtr18   RandomnessType randomness() const {
19     return std::get<VmapInterpreterMeta>(base_->meta()).randomness_;
20   }
21  private:
22   const Interpreter* base_;
23 };
24 
25 } // namespace at::functorch
26