/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue functor applying exp(accumulator - LSE) elementwise, where
         LSE is a log-sum-exp value supplied through the bias fragment.
*/

#pragma once

#include <cuda_fp16.h>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/functional.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// Elementwise exponential over an Array<Element, ElementsPerAccess>.
///
/// Generic version: every element is passed through `expf` one at a time and
/// the result is converted back to `Element` on assignment.
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
  /// Returns an array whose i-th element is expf(input[i]).
  CUTLASS_HOST_DEVICE
  Array<Element, ElementsPerAccess> operator()(
      Array<Element, ElementsPerAccess> const& input) const {
    Array<Element, ElementsPerAccess> result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      result[i] = expf(input[i]);
    }

    return result;
  }
};

/// Elementwise exponential specialized for half_t.
///
/// Elements are processed two at a time through the packed `h2exp` intrinsic
/// from <cuda_fp16.h>. When ElementsPerAccess is odd, the one element left
/// over after the packed loop is handled with the scalar `hexp` intrinsic so
/// that every element of the returned array is written.
///
/// Device-only (CUTLASS_DEVICE), unlike the generic version.
/// NOTE(review): the storage is reinterpreted as __half2 / __half; this
/// presumably relies on half_t being layout-compatible with __half and on the
/// array storage being suitably aligned for __half2 -- confirm against
/// cutlass/array.h.
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
  /// Returns an array whose i-th element is exp(input[i]) in half precision.
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
      Array<half_t, ElementsPerAccess> const& input) const {
    Array<half_t, ElementsPerAccess> result;

    // Number of complete (lo, hi) pairs that can go through h2exp.
    int const kVectorCount = ElementsPerAccess / 2;

    __half2 const* input_ptr =
        reinterpret_cast<__half2 const*>(input.raw_data());
    __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = h2exp(input_ptr[i]);
    }

    // Odd width: the packed loop above never touches the last element, so
    // compute it with the scalar intrinsic instead of returning it
    // uninitialized. The condition is a compile-time constant.
    if (ElementsPerAccess % 2) {
      __half const* input_scalar =
          reinterpret_cast<__half const*>(input.raw_data());
      __half* res_scalar = reinterpret_cast<__half*>(result.raw_data());
      res_scalar[ElementsPerAccess - 1] =
          hexp(input_scalar[ElementsPerAccess - 1]);
    }

    return result;
  }
};
} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies:
///   output <- (input - lse).exp()
///
/// `input` is the accumulator fragment and `lse` is read from the fragment
/// passed in the bias position of operator(). Both are converted to
/// ElementCompute, subtracted, exponentiated, and converted to ElementOutput.
template <
    typename ElementOutput_, // output
    typename ElementLSE_, // accumulator from LSE
    typename ElementAccumulator_, // accumulator from matmul
    typename ElementCompute_, // intermediate compute (and exp calculation)
    int ElementsPerAccess>
class ApplyLogSumExp {
 public:
  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;
  using ElementLSE = ElementLSE_;

  /// Number of elements processed per call to operator().
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;
  static const ScaleType::Kind kScale =
      cutlass::epilogue::thread::ScaleType::NoBetaScaling;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
  using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h

 public:
  //
  // Methods
  //

  /// Stateless functor: nothing to initialize.
  CUTLASS_HOST_DEVICE
  ApplyLogSumExp() {}

  /// Returns true if source is needed (always, for this functor)
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return true;
  }

  /// Functionally required for serial reduction in the epilogue.
  /// Intentionally a no-op: both arguments are ignored.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {}

  /// Computes exp(AB - bias) elementwise.
  ///
  /// \param AB            accumulator fragment
  /// \param scale_unused  present only to match the (accumulator, scale, bias)
  ///                      call shape; never read
  /// \param bias          fragment holding the LSE values to subtract
  /// \return              the result converted to ElementOutput
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      FragmentAccumulator const& AB,
      FragmentLSE const& scale_unused,
      // bias used as LSE
      FragmentLSE const& bias) const {
    // Bring both operands into the compute type before doing any arithmetic.
    FragmentCompute frag_AB = NumericArrayConverter<
        ElementCompute,
        ElementAccumulator,
        kElementsPerAccess>()(AB);
    FragmentCompute frag_lse_compute =
        NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>()(
            bias);
    FragmentCompute frag_compute;

    minus<FragmentCompute> minus_lse;
    detail::ArrayExponential<ElementCompute, kElementsPerAccess> apply_exp;
    frag_compute = minus_lse(frag_AB, frag_lse_compute);
    frag_compute = apply_exp(frag_compute);

    return NumericArrayConverter<
        ElementOutput,
        ElementCompute,
        kElementsPerAccess>()(frag_compute);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////