#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}

#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalInverses.h>
#include <ATen/MemoryOverlap.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#include <ATen/NativeFunctions.h>
#else
// needed for the meta tensor calls to get stride info in functionalization
#include <ATen/ops/empty_strided_native.h>
// needed for special handling of copy_().
// See Note [functionalizating copy_() and not preserving strides]
#include <ATen/ops/to_ops.h>
#include <ATen/ops/expand_copy_ops.h>

$ops_headers
#endif

namespace at {
namespace functionalization {

// This keyset is used by functionalization when it calls into meta kernels
// to accurately propagate stride metadata.
// Exclude any modes: the purpose of calling into meta kernels is only as an
// implementation detail to perform shape inference, and we don't want any
// modal keys to run.
// Specifically, we want to prevent functionalization and Python modes from running.
constexpr auto exclude_keys_for_meta_dispatch =
    c10::functorch_transforms_ks |
    c10::DispatchKeySet({
        c10::DispatchKey::FuncTorchDynamicLayerBackMode,
        c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
        c10::DispatchKey::Python,
        c10::DispatchKey::PreDispatch,
    });

// Helper around at::has_internal_overlap.
// The ATen util is used in hot-path eager mode: it's always fast,
// but might return TOO_HARD sometimes.
// During functionalization, we're ok taking a bit longer
// to detect memory overlap.
inline bool has_internal_overlap_helper(const at::Tensor t) {
  auto has_overlap = at::has_internal_overlap(t);
  if (has_overlap == at::MemOverlap::Yes) return true;
  if (has_overlap == at::MemOverlap::No) return false;
  // TOO_HARD: conservatively treat the tensor as non-overlapping.
  return false;
}

// Creates a meta tensor mirroring t's sizes, strides, dtype, and layout,
// without allocating any storage.
inline Tensor to_meta(const Tensor& t) {
  if (!t.defined()) return t;
  return at::native::empty_strided_meta_symint(
      t.sym_sizes(),
      t.sym_strides(),
      /*dtype=*/std::make_optional(t.scalar_type()),
      /*layout=*/std::make_optional(t.layout()),
      /*device=*/std::make_optional(c10::Device(kMeta)),
      /*pin_memory=*/std::nullopt);
}

inline std::optional<Tensor> to_meta(const std::optional<Tensor>& t) {
  if (t.has_value()) {
    return std::make_optional(to_meta(*t));
  }
  return std::nullopt;
}

inline std::vector<Tensor> to_meta(at::ITensorListRef t_list) {
  std::vector<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    outputs.push_back(to_meta(tensor));
  }
  return outputs;
}

inline c10::List<Tensor> to_meta(const c10::List<Tensor>& t_list) {
  c10::List<Tensor> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(to_meta(t_list[i]));
  }
  return outputs;
}

inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optional<Tensor>>& t_list) {
  c10::List<::std::optional<Tensor>> outputs;
  outputs.reserve(t_list.size());
  for (const auto i : c10::irange(t_list.size())) {
    outputs.push_back(to_meta(t_list[i]));
  }
  return outputs;
}

${func_definitions}

}  // namespace functionalization

namespace {

TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  ${func_registrations};
}

}  // namespace

}  // namespace at
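
// ---------------------------------------------------------------------------
// Editorial sketch (not emitted by codegen): roughly the shape of a kernel
// that ${func_definitions} expands to, showing how exclude_keys_for_meta_dispatch
// and to_meta() work together. `add.Tensor` stands in for an arbitrary op, the
// function name is hypothetical, and the unwrap/redispatch/wrap plumbing is
// elided; only the meta-dispatch guard pattern is the point.
//
//   at::Tensor add_functionalize(const at::Tensor& self, const at::Tensor& other,
//                                const at::Scalar& alpha) {
//     // Run the op once on meta copies to infer the output's sizes/strides,
//     // so the functionalized result can carry accurate stride metadata.
//     at::Tensor reference_out;
//     {
//       // Skip Functionalize so this call doesn't recurse back here, and
//       // exclude functorch/Python/PreDispatch modes: this dispatch exists
//       // purely as a shape-inference implementation detail.
//       at::AutoDispatchSkipFunctionalize func_guard;
//       c10::impl::ExcludeDispatchKeyGuard meta_guard(exclude_keys_for_meta_dispatch);
//       reference_out = at::_ops::add_Tensor::call(to_meta(self), to_meta(other), alpha);
//     }
//     // ... unwrap the functional wrappers, redispatch to the real kernel,
//     // and wrap the output using reference_out's sizes/strides ...
//   }
// ---------------------------------------------------------------------------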
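
// ---------------------------------------------------------------------------
// Editorial sketch: what to_meta() guarantees. A meta copy mirrors sizes,
// strides, dtype, and layout but owns no storage, which is exactly what meta
// kernels need to propagate stride metadata cheaply. Hypothetical snippet,
// assuming the standard at::empty_strided factory:
//
//   at::Tensor t = at::empty_strided({2, 3}, {1, 2});   // column-major strides
//   at::Tensor m = at::functionalization::to_meta(t);
//   TORCH_INTERNAL_ASSERT(m.is_meta());                 // meta device, no storage
//   TORCH_INTERNAL_ASSERT(m.sym_sizes() == t.sym_sizes());
//   TORCH_INTERNAL_ASSERT(m.sym_strides() == t.sym_strides());
// ---------------------------------------------------------------------------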