Searched refs:scalarToMetalTypeString (Results 1 – 16 of 16) sorted by relevance
28 …"fused_adam_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);58 …"fused_adam_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]);
28 …"fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]…58 …"fused_adamw_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_steps[0]…
30 …"fused_adam_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_s…62 …"fused_adam_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_s…
30 …"fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_…61 …"fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0]) + "_" + scalarToMetalTypeString(state_…
34 …const std::string kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_t…68 const auto kernel_name = "fused_sgd_momentum_" + scalarToMetalTypeString(params[0].scalar_type());125 const auto kernel_name = "fused_sgd_" + scalarToMetalTypeString(params[0].scalar_type());190 …const std::string kernel_name = "fused_sgd_" + mps::scalarToMetalTypeString(params[0].scalar_type(…
36 … {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
395 …return lib.getPipelineStateForFunc(fname, {scalarToMetalTypeString(t1), scalarToMetalTypeString(t2…
63 string key = "renorm_" + scalarToMetalTypeString(self);
221 const std::string kernel = "searchsorted_" + scalarToMetalTypeString(input) + "_" +222 scalarToMetalTypeString(result) + (sorter.defined() ? "_sorter" : "");
107 auto crossPSO = lib.getPipelineStateForFunc("cross_" + scalarToMetalTypeString(out));
800 …onst std::string kernel = fmt::format("int4pack_mm_{}_{}", qGroupSize, scalarToMetalTypeString(A));853 kernel = fmt::format("int8pack_mv_{}", scalarToMetalTypeString(A));855 kernel = fmt::format("large_m_int8pack_mm_{}", scalarToMetalTypeString(A));
257 const std::string kernel = "histogramdd_" + scalarToMetalTypeString(input);
272 const std::string kernel = func_name + "_" + scalarToMetalTypeString(input);
87 …auto matmulPSO = lib.getPipelineStateForFunc("naive_matmul_" + mps::scalarToMetalTypeString(output…
72 std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);73 static inline std::string scalarToMetalTypeString(const Tensor& t) { in scalarToMetalTypeString() function74 return scalarToMetalTypeString(t.scalar_type()); in scalarToMetalTypeString()
209 std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {