xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Shape.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/Resize.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/TensorShape.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <c10/core/MemoryFormat.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/narrow.h>
#endif

namespace at::native {

constexpr int CAT_ARRAY_BATCH_SIZE = 128;
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;
constexpr int ALIGNED_VEC_LOAD_BYTES = 16;

namespace {

inline bool is_aligned_vec4(const void* ptr) {
  auto iptr = reinterpret_cast<uintptr_t>(ptr);
  return !(iptr % alignof(int4));
}

inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) {
  const int numSM = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  // The x dimension of the grid cooperates on a single tensor in the cat.
  // Given half of the GPU, full utilization will always occur.

  // This lets catting two tensors fill the entire grid, while preventing
  // many threads from needlessly loading metadata when the inputs are small.

  grid = dim3( 2LL * numSM, (long long) nTensors );

  return true;
}
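
// Rough illustration of the sizing above (the SM count is assumed, not from
// the source): on a device with numSM = 80, catting nTensors = 3 inputs
// launches a (2 * 80, 3) = (160, 3) grid, i.e. 160 blocks cooperate on each
// input tensor along gridDim.x.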

template<typename T>
inline std::tuple<dim3, dim3> getCatGridRocm(unsigned int max_elements_per_tensor,
  ptrdiff_t nTensors) {
  constexpr unsigned int threads_per_block = 256;
  constexpr unsigned int elements_per_thread = 8;
  constexpr unsigned int max_tb_per_sm = 32;

  unsigned int max_threads = ceil_div(max_elements_per_tensor, elements_per_thread);
  unsigned int thread_blocks = ceil_div(max_threads, threads_per_block);

  // Limit the number of thread blocks so that threads operating on very small
  // tensors do not needlessly load the metadata.

  const unsigned int num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  thread_blocks = std::min(num_sm * max_tb_per_sm, thread_blocks);

  dim3 block = dim3(threads_per_block);
  dim3 grid = dim3(thread_blocks, (long long)nTensors);

  return std::make_tuple(grid, block);
}

template<typename T>
inline std::tuple<dim3, dim3> getCatGridContig(unsigned int max_elements_per_tensor,
  ptrdiff_t nTensors) {
  constexpr unsigned int threads_per_block = 128;
  constexpr unsigned int min_aligned_vec_per_thread = 1;
  constexpr unsigned int max_tb_per_sm = 32;

  unsigned int elements_per_thread = ALIGNED_VEC_LOAD_BYTES / sizeof(T) *
    min_aligned_vec_per_thread;
  unsigned int max_threads = ceil_div(max_elements_per_tensor, elements_per_thread);
  unsigned int thread_blocks = ceil_div(max_threads, threads_per_block);

  // Limit the number of thread blocks so that threads operating on very small
  // tensors do not needlessly load the metadata.

  const unsigned int num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  thread_blocks = std::min(num_sm * max_tb_per_sm, thread_blocks);

  dim3 block = dim3(threads_per_block);
  dim3 grid = dim3(thread_blocks, (long long)nTensors);

  return std::make_tuple(grid, block);
}
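
// Worked example for getCatGridContig (the input size and SM count are
// assumed here for illustration): for T = float,
// elements_per_thread = 16 / 4 * 1 = 4. With
// max_elements_per_tensor = 1 << 20, max_threads = ceil_div(1048576, 4) =
// 262144 and thread_blocks = ceil_div(262144, 128) = 2048. On an 80-SM
// device the cap is 80 * 32 = 2560, so the launch ends up as a
// (2048, nTensors) grid of 128-thread blocks.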

// Similar to any other IndexToOffset calculation for copying along a given
// dimension.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
  static inline __device__ IndexType compute(
      const IndexType tensorSize[Dims],
      const IndexType tensorStride[Dims],
      const IndexType dimSize,
      const unsigned int concatDim,
      IndexType linearIndex) {
    // linearIndex is not really a linear index, but rather the offset into the
    // input tensor. If the input tensor is contiguous, this offset is the
    // linear index, but if the input tensor is channels last, it is the linear
    // index of the permuted contiguous tensor.
    IndexType offset = 0;

    #pragma unroll
    for (int i = Dims - 1; i >= 1; --i) {
      IndexType curDimSize = i == concatDim ? dimSize : tensorSize[i];
      IndexType nextDimIndex = linearIndex / curDimSize;
      IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex;
      IndexType curDimOffset = curDimIndex * tensorStride[i];
      offset += curDimOffset;
      linearIndex = nextDimIndex;
    }

    return offset + linearIndex * tensorStride[0];
  }
};
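
// Worked example (hypothetical shapes, not from the source): concatenating
// inputs of shape (2, 3, 4) and (2, 2, 4) along concatDim = 1 into a
// contiguous output of shape (2, 5, 4) with strides (20, 4, 1). For the first
// input (dimSize = 3) and linearIndex = 7, i.e. input element (0, 1, 3):
//   i = 2: curDimSize = 4, nextDimIndex = 1, curDimIndex = 3, offset = 3
//   i = 1: curDimSize = dimSize = 3, nextDimIndex = 0, curDimIndex = 1,
//          offset = 3 + 1 * 4 = 7
//   return 7 + 0 * 20 = 7, which is output element (0, 1, 3).
// The kernels below then add dataOffset = offset * dimStride (cumulative size
// along concatDim times the output stride at concatDim) to place each input
// at its slot in the output.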

template<typename IndexType, unsigned int MaxDims>
struct TensorSizeStride {
  IndexType tensorSize[MaxDims];
  IndexType tensorStride[MaxDims];
};

/**
  * Kernel used to concatenate gridDim.y tensors into an output tensor. Uses a
  * grid-stride loop based off of blockIdx.x, threadIdx.x for each input to
  * copy each element from each input tensor into the output.
  *
  * output: base pointer to the storage associated with the output tensor
  * inputs: GPU-allocated array of input metadata for each input to concatenate
  *         in the kernel
  * os: the size/stride vectors for the output tensor
  * concatDim: dimension along which we are concatenating
  * dimStride: the stride of the output tensor at the concatDim
  *
  * The most important assumption made is that the input tensors are contiguous.
  */


// Pass metadata directly through kernel arguments instead of pinned memory.
// In the contiguous case we do not need per-input strides, so stride_size is
// set to 1 as a placeholder so that the code still compiles.
template <typename T, typename IndexType, int n, int stride_size>
struct CatArrInputTensorMetadata {
  const T* input[n];
  IndexType offset[n];
  IndexType dimSize[n];
  IndexType nElements[n];
  bool isContiguous[n];
  TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> tensorStride[stride_size];
};

template <typename T, typename IndexType, int Dims, int batch_size, int stride_size>
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {

    IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
    IndexType nElements = inputs.nElements[blockIdx.y];
    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> ins = stride_size > 1 ? inputs.tensorStride[blockIdx.y] : inputs.tensorStride[0];
    bool isContig = inputs.isContiguous[blockIdx.y];

    if (tid >= nElements) return;

    const T* data = inputs.input[blockIdx.y];
    IndexType offset = inputs.offset[blockIdx.y];
    IndexType dimSize = inputs.dimSize[blockIdx.y];
    IndexType dataOffset = offset * dimStride;

    IndexType stride = gridDim.x * blockDim.x;

    while (tid < nElements) {
      IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
                    os.tensorSize, os.tensorStride, dimSize, concatDim, tid);
      if (isContig) {
        output[dataOffset + elementOffset] = data[tid];
      } else {
        IndexType inElementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
                    ins.tensorSize, ins.tensorStride, dimSize, concatDim, tid);
        output[dataOffset + elementOffset] = data[inElementOffset];
      }
      tid += stride;
    }
}

template <typename T, typename IndexType, int Dims, int batch_size, int stride_size>
__global__ void CatArrayBatchedCopy_contig(
    T* output,
    CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {

    IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
    IndexType nElements = inputs.nElements[blockIdx.y];

    if (tid >= nElements) return;

    const T* data = inputs.input[blockIdx.y];
    IndexType offset = inputs.offset[blockIdx.y];
    IndexType dimSize = inputs.dimSize[blockIdx.y];
    IndexType dataOffset = offset * dimStride;

    IndexType stride = gridDim.x * blockDim.x;

    while (tid < nElements) {
      IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
                    os.tensorSize, os.tensorStride, dimSize, concatDim, tid);
      output[dataOffset + elementOffset] = data[tid];
      tid += stride;
    }
}

/*
  Specialized implementation of the CatArrayBatchedCopy kernel, written to
  generate wide memory loads to improve memory bandwidth throughput.
*/
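
// A sketch of the vector widths this yields (derived from the constants in
// this file; the dtype sizes are the usual ones and assumed here):
// kILP = ALIGNED_VEC_LOAD_BYTES / sizeof(T), so float loads 4 elements per
// 16-byte access and double loads 2. The HANDLE_CASE dispatch below only
// takes this path when 4 <= sizeof(scalar_t) <= 8.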

template <typename T, typename IndexType, int Dims, int batch_size, int stride_size>
__global__ void CatArrayBatchedCopy_aligned16_contig(
    T* output,
    CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {

    // This kernel tries to use 128-bit loads
    constexpr int kILP = ALIGNED_VEC_LOAD_BYTES / sizeof(T);
    IndexType inputOffset = (blockIdx.x * blockDim.x + threadIdx.x) * kILP;
    IndexType inputStride = gridDim.x * blockDim.x * kILP;

    IndexType nElements = inputs.nElements[blockIdx.y];
    if (inputOffset >= nElements) {
      return;
    }

    const T* data = inputs.input[blockIdx.y];
    IndexType offset = inputs.offset[blockIdx.y];
    IndexType dimSize = inputs.dimSize[blockIdx.y];
    IndexType dataOffset = offset * dimStride;

    IndexType v_elementOffset[kILP];
    T reg_data[kILP];

    while (inputOffset + kILP <= nElements) {
      for (int i = 0; i < kILP; ++i) {
        v_elementOffset[i] = CatArrIndexToOffset<IndexType, Dims>::compute(os.tensorSize,
          os.tensorStride, dimSize, concatDim, inputOffset + i);
      }

      using LT = at::native::memory::aligned_vector<T, kILP>;
      ((LT*)reg_data)[0] = const_cast<LT*>((LT*)(data + inputOffset))[0];

      #pragma unroll
      for (int i = 0; i < kILP; ++i) {
        output[dataOffset + v_elementOffset[i]] = reg_data[i];
      }

      inputOffset += inputStride;
    }

    // Handle the remaining tail in case nElements is not an exact multiple
    // of kILP.

    while (inputOffset < nElements) {
      v_elementOffset[0] = CatArrIndexToOffset<IndexType, Dims>::compute(os.tensorSize,
        os.tensorStride, dimSize, concatDim, inputOffset);
      output[dataOffset + v_elementOffset[0]] = data[inputOffset];
      inputOffset++;
    }
}

template <typename scalar_t, int batch_size, int stride_size>
void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, int64_t dimension,
                  int nDims, c10::MemoryFormat memory_format) {
  // First, let's set up our kernel parameters. We start with a raw pointer to
  // the storage for the output Tensor.
  scalar_t *data = (scalar_t *)(out.mutable_data_ptr());
  CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
  TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

  // Next, let's initialize the size, stride arrays for the output Tensor.
  if (memory_format == c10::MemoryFormat::Contiguous) {
    for (int i = 0; i < nDims; ++i) {
      outputParam.tensorSize[i] = out.size(i);
      outputParam.tensorStride[i] = out.stride(i);
    }
  } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
    // Permute the semantics of dims from NCHW to NHWC so that the input
    // tensor is now contiguous.
    outputParam.tensorSize[0] = out.size(0);
    outputParam.tensorStride[0] = out.stride(0);
    for (int i = 1; i < nDims - 1; ++i) {
      outputParam.tensorSize[i] = out.size(i + 1);
      outputParam.tensorStride[i] = out.stride(i + 1);
    }
    outputParam.tensorSize[nDims - 1] = out.size(1);
    outputParam.tensorStride[nDims - 1] = out.stride(1);
  } else {
    TORCH_CHECK(false, "unsupported memory format");
  }
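
  // Illustration of the permutation above (example shape assumed here): a
  // ChannelsLast output of logical shape (N, C, H, W) = (2, 3, 4, 5) with
  // strides (60, 1, 15, 3) is presented to the kernel as sizes (2, 4, 5, 3)
  // and strides (60, 15, 3, 1), i.e. an ordinary contiguous NHWC layout.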

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  // If all batches are contiguous we can call a specialized implementation
  // which requires the input tensor addresses to be aligned to a
  // 16 Byte boundary.

  bool isContig = true;
  bool isAligned = true;
  unsigned int max_elements_per_tensor = 0;

  // Now we loop
  int batchCounter = 0;
  int64_t offset = 0;
  for (unsigned i = 0; i < inputs.size() ; i += batch_size) {
    for (batchCounter = 0;
          batchCounter < batch_size &&
            (i+batchCounter) < inputs.size();
          ++batchCounter) {
      int64_t dimSize = 0;
      // There is a legacy case where a 1-D empty tensor can be concatenated
      // with a high-dimensional tensor.
      if (inputs[i+batchCounter].get().numel() > 0) {
        dimSize = inputs[i+batchCounter].get().size(dimension);
      }

      catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr());
      catMetaData.offset[batchCounter] = offset;
      catMetaData.dimSize[batchCounter] = dimSize;
      catMetaData.nElements[batchCounter] = inputs[i+batchCounter].get().numel();

#ifdef USE_ROCM
      // On ROCm, CatArrayBatchedCopy_contig is faster
      isAligned = false;
#else
      // If at least one of the inputs is not aligned, we can't call the
      // CatArrayBatchedCopy_aligned16_contig kernel.
      isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]);
#endif

      if (stride_size > 1) {
        auto strides = inputs[i+batchCounter].get().strides();
        auto sizes = inputs[i+batchCounter].get().sizes();
        for (int j = 0; j < nDims; j++) {
          catMetaData.tensorStride[batchCounter].tensorSize[j] = sizes[j];
          catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
        }
        catMetaData.isContiguous[batchCounter] = false;
        isContig = false;
      } else {
        catMetaData.isContiguous[batchCounter] = true;
      }

      // Update offset
      offset += dimSize;

      // We need max elements per tensor to compute grid parameters
      max_elements_per_tensor = std::max(max_elements_per_tensor,
        catMetaData.nElements[batchCounter]);
    }

    // Skip if the tensor is empty. Otherwise, the grid dim is invalid.
    if (max_elements_per_tensor == 0)
      continue;

    dim3 applyBlock, catGrid;

#ifdef USE_ROCM
    // Always base the grid size on max_elements_per_tensor.
    {
      std::tuple<dim3, dim3> launchParams = getCatGridRocm<scalar_t>(
          max_elements_per_tensor, batchCounter);
      catGrid = std::get<0>(launchParams);
      applyBlock = std::get<1>(launchParams);
    }
#else
    if (isContig && sizeof(scalar_t) > 2) {
      std::tuple<dim3, dim3> launchParams = getCatGridContig<scalar_t>(
          max_elements_per_tensor, batchCounter);
      catGrid = std::get<0>(launchParams);
      applyBlock = std::get<1>(launchParams);
    } else {
      applyBlock = dim3(32 * 16);
      getCatGrid(batchCounter, catGrid);
    }
#endif

    if (memory_format != c10::MemoryFormat::Contiguous) {
      switch (dimension) {
      case 0:
        break;
      case 1:
        dimension = nDims - dimension;
        break;
      default:
        dimension--;
      }
    }
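
    // For example (shapes assumed for illustration): with a 4-D channels-last
    // cat along the channel dim, dimension == 1 is remapped to
    // nDims - 1 == 3, matching the NHWC ordering of outputParam set up above;
    // a cat along dim 0 stays at 0, and the spatial dims shift down by one.
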
    // Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
    if (isContig && isAligned && sizeof(scalar_t) >= 4 && sizeof(scalar_t) <= 8) {\
      CatArrayBatchedCopy_aligned16_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
          catGrid, applyBlock, 0, stream.stream()>>>(\
              data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
    } else if (isContig) {\
      CatArrayBatchedCopy_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
          catGrid, applyBlock, 0, stream.stream()>>>(\
              data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
    } else {\
      CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
          catGrid, applyBlock, 0, stream.stream()>>>(\
              data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
    }\
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    switch (nDims) {
      case 1:
        HANDLE_CASE(1);
        break;
      case 2:
        HANDLE_CASE(2);
        break;
      case 3:
        HANDLE_CASE(3);
        break;
      case 4:
        HANDLE_CASE(4);
        break;
    }
#undef HANDLE_CASE
  }
}
// The kernels are templated on an opaque, self-aligned type of the correct
// size to avoid redundant kernels for different types of the same size.
template <unsigned N> struct alignas(N) OpaqueType { char data[N]; };

} // namespace

TORCH_IMPL_FUNC(cat_out_cuda)
(const ITensorListRef& tensors,
 int64_t dim,
 int64_t valid,
 bool all_contiguous,
 bool all_same_dtype,
 bool all_same_sizes_and_stride,
 MemoryFormat memory_format,
 const Tensor& result) {
  if (result.numel() == 0) {
    return;
  }

  auto materialized = tensors.materialize();

  // We parallelize the copy if all 6 conditions pass:
  //
  // 1. There is more than one input tensor
  // 2. The out tensor is 32-bit indexable
  // 3. The number of dimensions is <= 4
  // 4. All input tensors are contiguous (output tensor may be non-contig)
  // 5. All input tensors can use 32-bit indexing
  // 6. All input tensors have the same dtype

  const bool all32BitIndexable = std::all_of(materialized.begin(), materialized.end(),
    [] (const Tensor& t) {
      return at::cuda::detail::canUse32BitIndexMath(t);
    });

  int nDims = materialized[valid].get().dim();

  // We support contiguous inputs and non-contiguous inputs (<= 4 dims) in
  // different ways. For contiguous inputs, we don't need to pass stride
  // metadata to the CUDA kernel through constant memory, so we can pass more
  // inputs to the CUDA threads. For non-contiguous inputs, we reduce the
  // number of inputs passed to the CUDA kernel due to the limited amount of
  // constant memory.

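  // Rough arithmetic behind the two batch sizes below (my estimates, not
  // from the source): with IndexType = unsigned int, the by-value kernel
  // argument CatArrInputTensorMetadata occupies about
  //   128 * (8 + 4 + 4 + 4 + 1) + 1 * 32  ~= 2.7 KB  for <128, 1>, and
  //    64 * (8 + 4 + 4 + 4 + 1) + 64 * 32 ~= 3.4 KB  for <64, 64>,
  // both of which stay within the roughly 4 KB of kernel-parameter space the
  // launch has to fit in, which is why the non-contiguous path halves the
  // batch size while adding per-input size/stride metadata.
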
  if (materialized.size() > 1 &&
      result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
      at::cuda::detail::canUse32BitIndexMath(result) &&
      all_contiguous &&
      all32BitIndexable &&
      all_same_dtype) {
      if (isBitsType(result.scalar_type())) {
        AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
          using dtype = OpaqueType<sizeof(scalar_t)>;
          parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
        });
      } else {
        AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
          using dtype = OpaqueType<sizeof(scalar_t)>;
          parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
        }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
      }
  } else if (materialized.size() > 1 &&
      result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
      at::cuda::detail::canUse32BitIndexMath(result) &&
      nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
      all32BitIndexable &&
      all_same_dtype &&
      memory_format == c10::MemoryFormat::Contiguous) {
      if (isBitsType(result.scalar_type())) {
        AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
          using dtype = OpaqueType<sizeof(scalar_t)>;
          parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
        });
      } else {
        AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
            using dtype = OpaqueType<sizeof(scalar_t)>;
            parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
        }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
      }
  } else {
    int64_t offset = 0;
    for (const Tensor& t : materialized) {
      if (cat_should_skip_tensor(t)) continue;
      int64_t dimSize = t.size(dim);
      Tensor nt = at::narrow(result, dim, offset, dimSize);
      copy_(nt, t);
      offset += dimSize;
    }
  }
}
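
// Usage sketch (illustrative, not part of this file): from C++ ATen code, a
// call such as
//
//   auto a = at::randn({2, 3}, at::device(at::kCUDA));
//   auto b = at::randn({4, 3}, at::device(at::kCUDA));
//   auto out = at::cat({a, b}, /*dim=*/0);
//
// dispatches here for CUDA tensors; with two contiguous float inputs it takes
// the parallel_cat<..., CAT_ARRAY_BATCH_SIZE, 1> path above, while mixed
// dtypes or inputs with more than 4 dims fall back to the narrow-and-copy
// loop at the end.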

} // namespace at::native