1 #pragma once 2 #include <ATen/functorch/Interpreter.h> 3 4 namespace at::functorch { 5 6 // This is the interpreter that handles the functionalize() transform. 7 // See NOTE: [functorch interpreter stack] for more details. 8 9 struct VmapInterpreterPtr { VmapInterpreterPtrVmapInterpreterPtr10 explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); } keyVmapInterpreterPtr11 TransformType key() const { return base_->key(); } levelVmapInterpreterPtr12 int64_t level() const { return base_->level(); } 13 void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack); 14 void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case); batchSizeVmapInterpreterPtr15 c10::SymInt batchSize() const { 16 return std::get<VmapInterpreterMeta>(base_->meta()).batchSize_; 17 } randomnessVmapInterpreterPtr18 RandomnessType randomness() const { 19 return std::get<VmapInterpreterMeta>(base_->meta()).randomness_; 20 } 21 private: 22 const Interpreter* base_; 23 }; 24 25 } // namespace at::functorch 26