1 /***************************************************************************************************
2  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
3  *reserved. SPDX-License-Identifier: BSD-3-Clause
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright notice,
9  *this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29  *POSSIBILITY OF SUCH DAMAGE.
30  *
31  **************************************************************************************************/
32 /*! \file
  \brief Epilogue functor computing exp(accumulator - log-sum-exp) element-wise, plus its element-wise exponential helpers.
34 */
35 
36 #pragma once
37 
38 #include <cuda_fp16.h>
39 
40 #include <cutlass/array.h>
41 #include <cutlass/cutlass.h>
42 #include <cutlass/epilogue/thread/activation.h>
43 #include <cutlass/functional.h>
44 #include <cutlass/numeric_conversion.h>
45 #include <cutlass/numeric_types.h>
46 
47 /////////////////////////////////////////////////////////////////////////////////////////////////
48 
49 namespace cutlass {
50 namespace epilogue {
51 namespace thread {
52 
53 /////////////////////////////////////////////////////////////////////////////////////////////////
54 
55 namespace detail {
56 
/// Element-wise natural exponential over a fixed-size cutlass::Array.
///
/// Generic implementation: each element is converted to float, passed to
/// `expf`, and the float result is converted back to `Element`. The explicit
/// conversions also admit element types that only provide *explicit*
/// conversions to/from float. For types that convert implicitly (float,
/// double, ...) the results are identical to an implicit-conversion version.
///
/// Note: because the computation goes through `expf`, wider element types
/// such as double are evaluated in single precision.
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
  CUTLASS_HOST_DEVICE
  Array<Element, ElementsPerAccess> operator()(
      Array<Element, ElementsPerAccess> const& input) const {
    Array<Element, ElementsPerAccess> result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ElementsPerAccess; ++i) {
      // Explicit round-trip through float so non-implicitly-convertible
      // element types still compile.
      result[i] = Element(expf(float(input[i])));
    }

    return result;
  }
};
72 
/// half_t specialization of ArrayExponential (device-only).
///
/// Elements are processed two at a time with the packed `h2exp` intrinsic.
/// When ElementsPerAccess is odd, the final element has no partner to form a
/// __half2, so it is evaluated separately with the scalar `hexp` intrinsic;
/// without that step it would be returned uninitialized.
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
      Array<half_t, ElementsPerAccess> const& input) const {
    Array<half_t, ElementsPerAccess> result;

    // Number of complete element pairs handled by the packed loop.
    int const kVectorCount = ElementsPerAccess / 2;

    // NOTE(review): viewing the storage as __half2 assumes it is 4-byte
    // aligned - TODO confirm for the instantiations in use.
    __half2 const* input_ptr =
        reinterpret_cast<__half2 const*>(input.raw_data());
    __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kVectorCount; ++i) {
      res_ptr[i] = h2exp(input_ptr[i]);
    }

    // Odd tail: the last element is not covered by the packed loop above.
    // ElementsPerAccess is a compile-time constant, so this branch folds away
    // for even sizes.
    if (ElementsPerAccess % 2 != 0) {
      int const last = ElementsPerAccess - 1;
      result[last] = half_t(hexp(input[last].to_half()));
    }

    return result;
  }
};
94 } // namespace detail
95 
96 /////////////////////////////////////////////////////////////////////////////////////////////////
97 
/// Epilogue output functor that maps logits to probabilities using a
/// precomputed log-sum-exp (LSE):
///
///   output[i] = exp(AB[i] - lse[i])
///
/// Both operands are converted to ElementCompute, the subtraction and the
/// exponential are evaluated in ElementCompute, and the result is converted
/// to ElementOutput. The LSE values are supplied through the `bias` argument
/// of operator(); the `scale` argument is ignored.
template <
    typename ElementOutput_,      // element type of the produced fragment
    typename ElementLSE_,         // element type of the log-sum-exp values
    typename ElementAccumulator_, // element type of the matmul accumulator
    typename ElementCompute_,     // element type for the subtraction and exp
    int ElementsPerAccess>        // elements handled per operator() call
class ApplyLogSumExp {
 public:
  using ElementOutput = ElementOutput_;
  using ElementLSE = ElementLSE_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kCount = kElementsPerAccess;

  // Reported to the epilogue as ScaleType::NoBetaScaling.
  static ScaleType::Kind const kScale =
      cutlass::epilogue::thread::ScaleType::NoBetaScaling;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
  using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
  // Alias consumed by epilogue_smem_accumulator.h.
  using FragmentScaleBias = FragmentLSE;

  CUTLASS_HOST_DEVICE
  ApplyLogSumExp() {}

  /// Unconditionally reports that a source operand is required.
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return true;
  }

  /// Required by the epilogue's serial reduction; this functor ignores the
  /// partition information entirely.
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition, int k_partition_count) {}

  /// Element-wise exp(AB - bias).
  ///
  /// \param AB            accumulator fragment from the matmul
  /// \param scale_unused  ignored
  /// \param bias          log-sum-exp values subtracted from AB
  /// \return              exp(AB - bias), converted to ElementOutput
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      FragmentAccumulator const& AB,
      FragmentLSE const& scale_unused,
      FragmentLSE const& bias) const {
    using AccumulatorToCompute = NumericArrayConverter<
        ElementCompute,
        ElementAccumulator,
        kElementsPerAccess>;
    using LseToCompute =
        NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>;
    using ComputeToOutput = NumericArrayConverter<
        ElementOutput,
        ElementCompute,
        kElementsPerAccess>;

    AccumulatorToCompute to_compute;
    LseToCompute lse_to_compute;
    ComputeToOutput to_output;
    minus<FragmentCompute> subtract;
    detail::ArrayExponential<ElementCompute, kElementsPerAccess> exponentiate;

    FragmentCompute const logits = to_compute(AB);
    FragmentCompute const lse = lse_to_compute(bias);
    FragmentCompute const shifted = subtract(logits, lse);

    return to_output(exponentiate(shifted));
  }
};
168 
169 /////////////////////////////////////////////////////////////////////////////////////////////////
170 
171 } // namespace thread
172 } // namespace epilogue
173 } // namespace cutlass
174 
175 /////////////////////////////////////////////////////////////////////////////////////////////////
176