# mypy: allow-untyped-defs
#
# Code generator for the AVX2+FMA EmbeddingLookup / EmbeddingLookupIdx kernels
# (optionally fused 8-bit rowwise). The generated C++ is written to --filename,
# or to a default file name derived from the --fused / --use-offsets flags.


import argparse
import sys


sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1}


def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    def compute(regid, InType, use_weights, isa, prefetch):
        code = []
        if InType == "float":
            code.append(
                "        vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"  # noqa
                % (regid, regid, regid)
            )
        elif InType == "at::Half":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtph_ps(\n"
                "                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"  # noqa
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "at::BFloat16":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                "                _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
                "                    reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
                "                16)),\n"  # noqa
                "            vop%d);" % (regid, regid, regid)
            )
        elif InType == "uint8_t":
            code.append(
                "        vop%d = _mm256_fmadd_ps(\n"
                "            vwgt,\n"
                "            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
                "                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n"  # noqa
                "            _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
            )
        else:
            assert False

        if prefetch:
            code.append(
                "        _mm_prefetch(\n"
                "            reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
                % (regid)
            )
        else:
            code.append(
                "        // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)
            )

        return code

    code = []
    code.append("    // unrolling " + str(uf) + " times")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")
    for i in range(0, uf):
        j = 8 * i
        code.append("      __m256 vop" + str(j) + " = _mm256_setzero_ps();")

    # inner loop
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            ? (dataInd + prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )
    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

    for i in range(0, uf):
        j = 8 * i
        cachelinesize = 64
        byteoffset = sizeof[InType] * j
        prefetch = (byteoffset % cachelinesize) == 0
        code.extend(compute(j, InType, use_weights, isa, prefetch))
    code.append("      }")

    if use_offsets:
        code.append("      if (!normalize_by_lengths || length == 0) {")
    else:
        code.append("      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {")
    for i in range(0, uf):
        j = 8 * i
        code.append("        _mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
    code.append("      } else {")
    # inv of length
    if use_offsets:
        code.append("        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);")
    else:
        code.append(
            "        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);"
        )
    for i in range(0, uf):
        j = 8 * i
        code.append(
            "        _mm256_storeu_ps(&op["
            + str(j)
            + "], _mm256_mul_ps("
            + "vop"
            + str(j)
            + ", vlen_inv));"
        )
    code.append("      }")

    code.append("    }")
    return code


def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
    def compute(InType, use_weights, isa):
        code = []
        if InType == "float":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));"  # noqa
            )
        elif InType == "at::Half":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtph_ps(_mm_loadu_si128(\n"
                "                      reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "at::BFloat16":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                "                      _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
                "                          reinterpret_cast<const __m128i*>(&ip[j]))),\n"
                "                      16)),\n"
                "                  _mm256_loadu_ps(&op[j])));"
            )
        elif InType == "uint8_t":
            code.append(
                "          _mm256_storeu_ps(\n"
                "              &op[j],\n"
                "              _mm256_fmadd_ps(\n"
                "                  vwgt,\n"
                "                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(\n"  # noqa
                "                      reinterpret_cast<const __m128i*>(&ip[j])))),\n"
                "                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));"
            )
        else:
            assert False

        code.append(
            "          _mm_prefetch(\n"
            "              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);"
        )

        return code

    code = []
    if InType == "at::Half":
        code.append("    alignas(64) at::Half vtmp1[8] = {0};")
    if InType == "at::BFloat16":
        code.append("    alignas(64) at::BFloat16 vtmp1[8] = {0};")

    if use_offsets:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    else:
        code.append(
            "    for ("
            + IndexType
            + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {"
        )
    code.append("      " + OutType + "* op = &out[rangeIndex * block_size];")

    # initialize to 0
    code.append("      int64_t j = 0;")
    code.append("      for (; j + 8 <= block_size; j += 8) {")
    code.append("        _mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("      }")
    code.append("      for (; j < block_size; j++) {")
    code.append("        op[j] = 0.0f;")
    code.append("      }")

    # inner loop
    if use_offsets:
        code.append(
            "      if (dataInd != offsets[rangeIndex] - offsets[0]) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append("""\
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];""")
        code.append(
            "      for ("
            + "int64_t"
            + " start = dataInd; dataInd < end_offset - offsets[0];\n           ++dataInd) {"  # noqa
        )
    else:
        code.append(
            "      if (dataInd + lengths[rangeIndex] > index_size) {\n"
            + "        return false;\n"
            + "      }"
        )
        code.append(
            "      for ("
            + IndexType
            + " start = dataInd; dataInd < start + lengths[rangeIndex];\n           ++dataInd) {"  # noqa
        )
    code.append("        const " + IndexType + " idx = indices[dataInd];")
    code.append(
        "        if (idx < 0 || idx >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )

    if InType == "uint8_t":
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
        code.append("        " + OutType + " bio;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
        if fused:
            code.append(
                "        const float* scale_bias = reinterpret_cast<const float*>(\n"
                "            &input[idx * fused_block_size + block_size]);"
            )
            code.append("        bio = wgt * scale_bias[1];")
            code.append("        wgt = wgt * scale_bias[0];")
        else:
            code.append("        bio = wgt * scale_bias[2 * idx + 1];")
            code.append("        wgt = wgt * scale_bias[2 * idx];")
        code.append("        __m256 vbio = _mm256_set1_ps(bio);")
    else:
        code.append("        " + OutType + " wgt = 1.f;")
        code.append("        if (weights) {")
        code.append(
            "          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];"  # noqa
        )
        code.append("        }")
    code.append("        __m256 vwgt = _mm256_set1_ps(wgt);")
    code.append("        const {}* ip = &input[idx * fused_block_size];".format(InType))
    code.append(
        "        const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            ? (dataInd + prefdist_T0)\n"
        "            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
        "            : dataInd;".format(
            IndexType
        )
    )
    code.append("        const " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {\n"
        + "          return false;\n"
        + "        }"
    )
    code.append(
        "        const {}* ip_next_T0 = "
        "&input[idx_pref_T0 * fused_block_size];".format(InType)
    )

    # compute and store main loop
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.extend(compute(InType, use_weights, isa))
    code.append("        }")
    # leftover
    code.append("        for (; j < block_size; j++) {")
    if InType == "float":
        code.append("          op[j] = std::fma(wgt, ip[j], op[j]);")
    elif InType == "at::Half":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 =\n"
            "              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "at::BFloat16":
        code.append("          vtmp1[0] = ip[j];")
        code.append(
            "          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n"
            "              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),\n"
            "              16));"
        )
        code.append("          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
    elif InType == "uint8_t":
        code.append("          op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
    else:
        assert False

    code.append("        }")
    code.append("      }")

    if use_offsets:
        code.append("      if (normalize_by_lengths && length) {")
        code.append("        float len_inv = 1.0f / length;")
    else:
        code.append("      if (normalize_by_lengths && lengths[rangeIndex]) {")
        code.append("        float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("        __m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("        j = 0;")
    code.append("        for (; j + 8 <= block_size; j += 8) {")
    code.append(
        "          _mm256_storeu_ps(\n"
        "              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));"
    )
    code.append("        }")
    code.append("        for (; j < block_size; j++) {")
    code.append("          op[j] = len_inv * op[j];")
    code.append("        }")
    code.append("      }")

    code.append("    }")
    return code


# start main code
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--filename", help="file name")
parser.add_argument("--fused", action="store_true")
parser.add_argument("--use-offsets", action="store_true")
opts = parser.parse_args()
if opts.filename:
    filename = opts.filename
elif opts.fused:
    if opts.use_offsets:
        filename = "embedding_lookup_fused_8bit_rowwise_idx_avx2.cc"
    else:
        filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
else:
    if opts.use_offsets:
        filename = "embedding_lookup_idx_avx2.cc"
    else:
        filename = "embedding_lookup_avx2.cc"

options = [
    ["int32_t", "int", "float", "float", "float", "float"],
    ["int64_t", "int64_t", "float", "float", "float", "float"],
    ["int32_t", "int", "half", "at::Half", "float", "float"],
    ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
    ["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"],
    ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
    ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
    ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]

code = []
# includes
code.append("//// --------------------------")
code.append("//// ATTENTION:")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY {}".format(sys.argv[0]))
code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n")
code.append("#include <c10/util/BFloat16.h>")
code.append("#include <c10/util/Half.h>")
code.append("#include <immintrin.h>")
code.append("namespace caffe2 {\n")

for o in options:
    [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o

    prefix = "Fused8BitRowwise" if opts.fused else ""
    code.append("template <bool IS_WEIGHT_POSITIONAL>")
    if opts.use_offsets:
        fn_base = "{}EmbeddingLookupIdx_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    else:
        fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
            prefix, IndexTypeName, InTypeName, OutTypeName
        )
    suffix = "__avx2_fma"
    fn = "static bool " + fn_base + suffix
    code.append(fn + "(")

    args = []
    args.append("    const int64_t block_size,")
    args.append("    const int64_t output_size,")
    args.append("    const int64_t index_size,")
    args.append("    const int64_t data_size,")
    args.append("    const " + InType + "* input,")
    args.append("    const " + IndexType + "* indices,")
    if opts.use_offsets:
        args.append("    const " + IndexType + "* offsets,")
    else:
        args.append("    const int* lengths,")
    args.append("    const float* weights,")
    if not opts.fused:
        args.append("    const float* scale_bias,")
    args.append("    bool normalize_by_lengths,")
    args.append("    " + OutType + "* out) {")
    code += args

    code.append("  const " + IndexType + " prefdist_T0 = 16;")
    code.append(
        "  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)"
    )
    # block_size is the number of elements and fused_block_size is the size of
    # an entire row, including scale and bias.
    offset = (8 // sizeof[InType]) if opts.fused else 0
    code.append(
        "  const {} fused_block_size = block_size + {};".format(IndexType, offset)
    )
    if opts.use_offsets:
        code.append("  int64_t dataInd = 0;")
    else:
        code.append("  " + IndexType + " dataInd = 0;")
    # code.append("printf(\"calling " + fn + "\\n\");");
    code.append("  if (block_size == 128) {")
    code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 64) {")
    code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 32) {")
    code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else if (block_size == 16) {")
    code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  } else {")
    code.append("    // generic code")
    code.append(
        "    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)"
    )
    code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
    code.append("  }")
    code.append("  return dataInd == index_size;")
    code.append("}")

    for is_weight_positional in ["false", "true"]:
        code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(")
        code += args
        # Resolve the Lint warnings: Limit of 80 characters in one line.
        extra_space = "\n      "
        ret_string = "  return " + fn_base + suffix + "<" + is_weight_positional + ">("
        if len(ret_string) <= 80:
            code.append(ret_string)
        else:
            code.append(
                "  return "
                + fn_base
                + suffix
                + "<"
                + extra_space
                + is_weight_positional
                + ">("
            )
        code.append("      block_size,")
        code.append("      output_size,")
        code.append("      index_size,")
        code.append("      data_size,")
        code.append("      input,")
        code.append("      indices,")
        if opts.use_offsets:
            code.append("      offsets,")
        else:
            code.append("      lengths,")
        code.append("      weights,")
        if not opts.fused:
            code.append("      scale_bias,")
        code.append("      normalize_by_lengths,")
        code.append("      out);")
        code.append("}")
    code.append("")

code.append("} // namespace caffe2")

with open(filename, "w") as fout:
    for c in code:
        # print(c, file = fout)
        fout.write(c + "\n")

print("Created " + filename)
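
# Example invocations (a sketch based on the argparse flags above; the script
# file name shown is a placeholder for wherever this generator is saved, and
# my_kernels.cc is a hypothetical output name):
#
#   python <this_script>.py --use-offsets
#       -> embedding_lookup_idx_avx2.cc
#   python <this_script>.py --fused --use-offsets
#       -> embedding_lookup_fused_8bit_rowwise_idx_avx2.cc
#   python <this_script>.py -f my_kernels.cc
#       -> my_kernels.cc (an explicit --filename overrides the defaults)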