Home
last modified time | relevance | path

Searched refs:scalarToMetalTypeString (Results 1 – 16 of 16) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/operations/
H A DFusedAdamKernelImpl.mm28 …"fused_adam_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);
58 …"fused_adam_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);
H A DFusedAdamWKernelImpl.mm28 …"fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]…
58 …"fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]…
H A DFusedAdamAmsgradKernelImpl.mm30 …"fused_adam_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_s…
62 …"fused_adam_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_s…
H A DFusedAdamWAmsgradKernelImpl.mm30 …"fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_…
61 …"fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_…
H A DFusedSgdKernel.mm34 …const std::string kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_t…
68 const auto kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type());
125 const auto kernel_name = "fused_sgd_" + scalarToMetalTypeString(params[0].scalar_type());
190 …const std::string kernel_name = "fused_sgd_" + mps::scalarToMetalTypeString(params[0].scalar_type(…
H A DUnaryKernel.mm36 … {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
H A DGamma.mm395 …return lib.getPipelineStateForFunc(fname, {scalarToMetalTypeString(t1), scalarToMetalTypeString(t2…
H A DRenormKernel.mm63 string key = "renorm_" + scalarToMetalTypeString(self);
H A DBucketization.mm221 const std::string kernel = "searchsorted_" + scalarToMetalTypeString(input) + "_" +
222 scalarToMetalTypeString(result) + (sorter.defined() ? "_sorter" : "");
H A DCrossKernel.mm107 auto crossPSO = lib.getPipelineStateForFunc("cross_" + scalarToMetalTypeString(out));
H A DQuantized.mm800 …onst std::string kernel = fmt::format("int4pack_mm_{}_{}", qGroupSize, scalarToMetalTypeString(A));
853 kernel = fmt::format("int8pack_mv_{}", scalarToMetalTypeString(A));
855 kernel = fmt::format("large_m_int8pack_mm_{}", scalarToMetalTypeString(A));
H A DHistogramKernel.mm257 const std::string kernel = "histogramdd_" + scalarToMetalTypeString(input);
H A DBinaryKernel.mm272 const std::string kernel = func_name + "_" + scalarToMetalTypeString(input);
H A DLinearAlgebra.mm87 …auto matmulPSO = lib.getPipelineStateForFunc("naive_matmul_" + mps::scalarToMetalTypeString(output…
/aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/
H A DOperationUtils.h72 std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
73 static inline std::string scalarToMetalTypeString(const Tensor& t) { in scalarToMetalTypeString() function
74 return scalarToMetalTypeString(t.scalar_type()); in scalarToMetalTypeString()
H A DOperationUtils.mm209 std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {