#include #include #include #include #include #include #include #include #include #include namespace torch::utils { // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs static c10::TensorOptions typeIdWithDefault( PythonArgs& r, int device_idx, c10::DispatchKey dispatch_key) { auto options = dispatchKeyToTensorOptions(dispatch_key); if (!r.isNone(device_idx)) { options = options.device(r.device(device_idx)); } return options; } at::Tensor nested_tensor_ctor( c10::DispatchKey dispatch_key, at::ScalarType scalar_type, torch::PythonArgs& r) { TORCH_CHECK(r.idx == 0, "nested_tensor(): invalid arguments"); PyObject* data = r.pyobject(0); // Check if data is a list: Only List[Tensor] and List[List...[Scalar]] are // accepted for now TORCH_CHECK_TYPE( PyList_Check(data), "Only lists (List[Tensor] and List[List...[Scalar]]) are accepted in nested_tensor"); auto dtype_val = r.scalartypeWithDefault(1, scalar_type); auto tensor_options = typeIdWithDefault(r, 2, dispatch_key); bool pin_memory = r.toBool(3); bool args_requires_grad = r.toBool(4); TORCH_CHECK( PyList_Size(data) >= 0, "Something went really wrong and your list has negative size"); // Check whether we are dealing with lists of tensors or not std::vector new_list(PyList_Size(data)); for (const auto i : c10::irange(PyList_Size(data))) { PyObject* elem = PyList_GetItem(data, i); if (THPVariable_Check(elem)) { new_list[i] = THPVariable_Unpack(PyList_GetItem(data, i)).detach(); TORCH_CHECK( !new_list[i].is_nested(), "We do not accept nested tensors as input to nested tensors"); TORCH_CHECK( new_list[i].layout() == kStrided, "We do not accept non-strided layouts as input to nested tensors"); } else { PythonArgs elem_r(r); std::array elem_args = { elem, // data r.args[1], // dtpye nullptr, // device (cpu) nullptr, // no pinned memory r.args[4], // requires grad nullptr // names }; elem_r.args = elem_args.data(); new_list[i] = tensor_ctor(dispatch_key, scalar_type, elem_r); } } at::ScalarType final_dtype = dtype_val; if (r.isNone(1) && !new_list.empty()) { final_dtype = c10::typeMetaToScalarType(new_list[0].dtype()); } at::Device final_device = tensor_options.device(); if (r.isNone(2) && !new_list.empty()) { final_device = new_list[0].device(); } auto out = at::_nested_tensor_from_tensor_list( new_list, final_dtype, std::nullopt, final_device, pin_memory); out.requires_grad_(args_requires_grad); return out; } } // namespace torch::utils