namespace at {
namespace cuda {
// Windows doesn't like large string literals, so split the template in two
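// The strings below hold the CUDA source for the reduction kernel that is compiled at
// runtime. Placeholders of the form ${...} (e.g. ${scalar_type}, ${reduction_accum_type},
// ${result_type}, ${functor}, ${output_vec_size}, ${vt0}) are substituted by the host
// code before compilation.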
const std::string reduction_template_0 = R"ESCAPE(
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#if defined(__clang__) && defined(__HIP__)
#ifndef __forceinline__
#define __forceinline__ inline __attribute__((always_inline))
#endif
// until ROCm support for kernel asserts is restored
#define assert(expr) (static_cast<void>(0))
#endif

template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if defined(__clang__) && defined(__HIP__)
  return __shfl_down(value, delta, width);
#else
  return __shfl_down_sync(mask, value, delta, width);
#endif
}

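// Overload for std::complex: shuffle the real and imaginary parts separately.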
#if ${complex}
template <typename T>
__device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
  return std::complex<T>(
#if defined(__clang__) && defined(__HIP__)
      __shfl_down(value.real(), delta, width),
      __shfl_down(value.imag(), delta, width));
#else
      __shfl_down_sync(mask, value.real(), delta, width),
      __shfl_down_sync(mask, value.imag(), delta, width));
#endif
}
#endif

// aligned vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
  // get GCD of numerator and denominator using Euclid's algorithm.
  // Can replace this with std::gcd if we ever support C++17.
  size_t a = denominator;
  size_t b = numerator;
  while (b != 0) {
    a %= b;
    // swap(a, b)
    size_t tmp = a;
    a = b;
    b = tmp;
  }

  // a is now the GCD
  numerator /= a;
  denominator /= a;
}

struct ReduceConfig {
  // has to match the host-side ReduceConfig in the eager code
  static constexpr int BLOCK_X = 0;
  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;

  static constexpr int input_vec_size = 4;
  int element_size_bytes;
  int num_inputs;
  int num_outputs;
  int step_input = 1;
  int step_output = 1;
  int ctas_per_output = 1;
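  // input_mult maps (threadIdx.x, threadIdx.y, blockIdx.y) onto the reduced (input)
  // dimension and output_mult maps (threadIdx.x, threadIdx.y) onto the output
  // dimension; a zero multiplier means that coordinate does not advance along that
  // dimension (see input_idx() / output_idx() below).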
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

  int block_width;
  int block_height;
  int num_threads;

  bool vectorize_input = false;
  int output_vec_size = 1;

  C10_HOST_DEVICE bool should_block_x_reduce() const {
    return input_mult[BLOCK_X] != 0;
  }

  C10_HOST_DEVICE bool should_block_y_reduce() const {
    return input_mult[BLOCK_Y] != 0;
  }

  C10_HOST_DEVICE bool should_global_reduce() const {
    return input_mult[CTA] != 0;
  }

  C10_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
           (!should_block_x_reduce() || threadIdx.x == 0) &&
           (!should_block_y_reduce() || threadIdx.y == 0);
  }

  C10_DEVICE bool should_reduce_tail() const {
    return (!should_block_y_reduce() || threadIdx.y == 0) &&
           (!should_global_reduce() || blockIdx.y == 0);
  }

  C10_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
    return (lane * input_mult[BLOCK_X] +
            warp * input_mult[BLOCK_Y] +
            cta2 * input_mult[CTA]);
  }

  template <int output_vec_size>
  C10_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
    return (lane * output_mult[BLOCK_X] +
            warp * output_mult[BLOCK_Y] +
            cta1 * step_output) * output_vec_size;
  }

  C10_DEVICE int shared_memory_offset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }

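  // Offset into the per-CTA staging buffer (cta_buf) where the partial result
  // produced by CTA index cta2 for this output is stored during a global reduction.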
  C10_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
    if (!should_block_x_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }

};

// TODO: this will need to be different for more generic reduction functions
namespace reducer {

using scalar_t = ${scalar_type};
using arg_t = ${reduction_accum_type};
using out_scalar_t = ${result_type};


inline __device__ ${functor}

inline __device__ out_scalar_t project(arg_t arg) {
  return (out_scalar_t) arg;
}

inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
  return WARP_SHFL_DOWN(arg, offset);
}

inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
  return acc;
}

// wrap a normal reduction that ignores the index
inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
  return combine(acc, val);
}
} // namespace reducer

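// Device-side reduction op. An instance is filled in on the host and passed by value
// to reduction_${name}_kernel at the end of this template.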
struct ReduceJitOp {
  using scalar_t = ${scalar_type};
  using arg_t = ${reduction_accum_type};
  using out_scalar_t = ${result_type};

  using InputCalculator = OffsetCalculator<1>;
  using OutputCalculator = OffsetCalculator<2>;

  //   static constexpr bool can_accumulate_in_output =
  //     std::is_convertible<arg_t, out_scalar_t>::value
  //     && std::is_convertible<out_scalar_t, arg_t>::value;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;

  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; // it accepts at most two destinations
  // acc_buf is used for accumulation among sub-TensorIterators when accumulation
  // on the output is not permissible
  void* acc_buf;
  // cta_buf is used for accumulation between blocks during a global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;

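  // Main entry point: each thread reduces its strided slice of the input
  // (thread_reduce), the partial results are combined within the block
  // (block_y_reduce / block_x_reduce), and are then either written to the output
  // or combined across blocks via global_reduce.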
  C10_DEVICE void run() const {
    extern __shared__ char shared_memory[];
    uint32_t output_idx = config.output_idx<${output_vec_size}>();
    uint32_t input_idx = config.input_idx();
    auto base_offsets1 = output_calc.get(output_idx)[1];

    using arg_vec_t = Array<arg_t, ${output_vec_size}>;
    arg_vec_t value;

    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);

      value = thread_reduce<${output_vec_size}>(input_slice);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<${output_vec_size}>(value, shared_memory);
    }
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<${output_vec_size}>(value, shared_memory);
    }

    using out_ptr_vec_t = Array<out_scalar_t*, ${output_vec_size}>;
    using offset_vec_t = Array<uint32_t, ${output_vec_size}>;
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < ${output_vec_size}; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    arg_vec_t* acc = nullptr;
    if (acc_buf != nullptr) {
      size_t numerator = sizeof(arg_t);
      size_t denominator = sizeof(out_scalar_t);
      reduce_fraction(numerator, denominator);
      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
    }

    if (config.should_global_reduce()) {
      value = global_reduce<${output_vec_size}>(value, acc, shared_memory);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        #pragma unroll
        for (int i = 0; i < ${output_vec_size}; i++) {
          value[i] = reducer::translate_idx(value[i], base_idx);
        }
      }

      if (acc == nullptr) {
        if (accumulate) {
          value = accumulate_in_output<${output_vec_size}>(out, value);
        }
        if (final_output) {
          set_results_to_output<${output_vec_size}>(value, base_offsets);
        } else {
          #pragma unroll
          for (int i = 0; i < ${output_vec_size}; i++) {
            *(out[i]) = get_accumulated_output(out[i], value[i]);
          }
        }
      } else {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < ${output_vec_size}; i++) {
            value[i] = reducer::combine((*acc)[i], value[i]);
          }
        }
        if (final_output) {
          set_results_to_output<${output_vec_size}>(value, base_offsets);
        } else {
          *acc = value;
        }
      }
    }
  }

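  // Dispatch to the appropriate per-thread reduction: a vectorized path for
  // contiguous inputs, and strided paths driven by an offset calculator otherwise.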
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
    if (config.vectorize_input) {
      assert(output_vec_size == 1);
      // reduce at the head of the input slice, where memory is not aligned,
      // so that thread_reduce has aligned memory to work on.
      return {input_vectorized_thread_reduce_impl(data)};
    } else {
      uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
      if (is_contiguous) {
        return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
      } else if (input_calc.dims == 1) {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
      } else {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
      }
    }
  }

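  // Vectorized per-thread reduction: consume the misaligned head element-by-element,
  // then the aligned body with vector loads of input_vec_size elements, then the tail.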
  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
    uint32_t end = config.num_inputs;

    // Handle the head of input slice where data is not aligned
    arg_t value = ident;
    constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
    constexpr int align_elements = align_bytes / sizeof(scalar_t);
    int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
    if (shift > 0) {
      data -= shift;
      end += shift;
      if (threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()) {
        value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
      }
      end -= align_elements;
      data += align_elements;
      shift = align_elements - shift;
    }

    // Do the vectorized reduction
    using load_t = aligned_vector<scalar_t, input_vec_size>;

    uint32_t idx = config.input_idx();
    const uint32_t stride = config.step_input;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_t value_list[input_vec_size];
    value_list[0] = value;

    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[i] = ident;
    }

    scalar_t values[input_vec_size];

    load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);

    while (idx * input_vec_size + input_vec_size - 1 < end) {
      *values_vector = reinterpret_cast<const load_t*>(data)[idx];
      #pragma unroll
      for (uint32_t i = 0; i < input_vec_size; i++) {
        value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
      }
      idx += stride;
    }

    // tail
    uint32_t tail_start = end - end % input_vec_size;
    if (config.should_reduce_tail()) {
      int idx = tail_start + threadIdx.x;
      if (idx < end) {
        value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
      }
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[0] = reducer::combine(value_list[0], value_list[i]);
    }
    return value_list[0];
  }

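  // Strided per-thread reduction: each iteration loads vt0 values (through the supplied
  // offset calculator) into separate accumulators to hide load latency, then the
  // accumulators are combined at the end.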
  template <int output_vec_size, typename offset_calc_t>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
    uint32_t idx = config.input_idx();
    const uint32_t end = config.num_inputs;
    const uint32_t stride = config.step_input;
    const int vt0 = ${vt0};

    using arg_vec_t = Array<arg_t, output_vec_size>;
    using load_t = aligned_vector<scalar_t, output_vec_size>;
    const load_t* data = reinterpret_cast<const load_t*>(data_);

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_vec_t value_list[vt0];

    #pragma unroll
    for (int i = 0; i < vt0; i++) {
      #pragma unroll
      for (int j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ident;
      }
    }

    load_t values[vt0];

    while (idx + (vt0 - 1) * stride < end) {
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        values[i] = data[calc(idx + i * stride) / output_vec_size];
      }
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        #pragma unroll
        for (uint32_t j = 0; j < output_vec_size; j++) {
          value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
        }
      }
      idx += stride * vt0;
    }

    // tail
    int idx_ = idx;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      values[i] = data[calc(idx) / output_vec_size];
      idx += stride;
    }
    idx = idx_;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
      }
      idx += stride;
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < vt0; i++) {
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
      }
    }
    return value_list[0];
  }

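  // Reduce across the x dimension of the block: wider blocks first fold down to a
  // single warp through shared memory, then the remaining warp is reduced with
  // warp shuffles.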
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    int dim_x = blockDim.x;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    if (dim_x > warpSize) {
      int address_base = threadIdx.x + threadIdx.y * blockDim.x;
      shared[address_base] = value;
      for (int offset = dim_x / 2; offset >= warpSize; offset >>= 1) {
        __syncthreads();
        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
          args_vec_t other = shared[address_base + offset];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], other[i]);
          }
          shared[address_base] = value;
        }
      }
      dim_x = warpSize;
    }

    __syncthreads();

    for (int offset = 1; offset < dim_x; offset <<= 1) {
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);
        value[i] = reducer::combine(value[i], other);
      }
    }
    return value;
  }

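  // Reduce across the y dimension of the block with a tree reduction in shared memory.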
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        args_vec_t other = shared[config.shared_memory_offset(offset)];
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = reducer::combine(value[i], other[i]);
        }
        shared[config.shared_memory_offset(0)] = value;
      }
    }
    return value;
  }
)ESCAPE";

const std::string reduction_template_1 = R"ESCAPE(

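  // Atomically count finished CTAs for this output column via the semaphore; returns
  // true in the CTA that finishes last and therefore performs the final combination.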
  C10_DEVICE bool mark_block_finished() const {
    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }

    __syncthreads();

    return is_last_block_done_shared;
  }

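  // Combine the value already stored in the output with the freshly computed partial
  // result; used when accumulating directly into the output buffer.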
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
    Array<out_scalar_t*, output_vec_size> out,
    Array<arg_t, output_vec_size> value
  ) const {
    Array<arg_t, output_vec_size> ret;
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      ret[i] = reducer::combine(*(out[i]), value[i]);
    }
    return ret;
  }


  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value
  ) const {
    assert(!final_output);
    return (out_scalar_t)value;
  }

  template<class T>
  C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
    assert(noutputs == 1);
    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
    *res = x;
  }

  // TODO: multi-output reduction - we won't be able to use thrust::pair,
  // just explicitly specify typed output reads/writes.
  // Currently implemented for a maximum of two outputs.
  // template<class T1, class T2>
  // C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
  //   if (noutputs >= 1) {
  //     auto res0 = (T1*)((char*)dst[0] + base_offset);
  //     *res0 = x.first;
  //   }
  //   if (noutputs >= 2) {
  //     // base offset is computed assuming element size being sizeof(T1), so we need to make a
  //     // correction to obtain the correct base offset
  //     auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
  //     *res1 = x.second;
  //   }
  // }

  template <int output_vec_size>
  C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
    assert(final_output);
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      set_results(reducer::project(value[i]), base_offset[i]);
    }
  }

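  // Cross-block reduction: every CTA stages its partial result in cta_buf, and the last
  // CTA to finish (per mark_block_finished) loads all staged results, reduces them
  // within its own block, and writes (or accumulates) the final value.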
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
    using arg_vec_t = Array<arg_t, output_vec_size>;
    using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = Array<uint32_t, output_vec_size>;

    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
    uint32_t output_idx = config.output_idx<output_vec_size>();
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    bool should_store = config.should_store(output_idx);
    if (should_store) {
      uint32_t offset = config.staging_memory_offset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
    bool is_last_block_done = mark_block_finished();

    if (is_last_block_done) {
      __threadfence(); // complete acquire pattern
      value = ident;
      if (config.should_block_x_reduce()) {
        uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        uint32_t step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      } else {
        uint32_t input_offset = threadIdx.y;
        uint32_t step = blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      }
      value = block_y_reduce(value, shared_memory);
      if (config.should_block_x_reduce()) {
        value = block_x_reduce<output_vec_size>(value, shared_memory);
      }
      if (should_store) {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::translate_idx(value[i], base_idx);
          }
        }

        if (acc == nullptr) {
          if (accumulate) {
            value = accumulate_in_output<output_vec_size>(out, value);
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              *(out[i]) = get_accumulated_output(out[i], value[i]);
            }
          }
        } else {
          if (accumulate) {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              value[i] = reducer::combine((*acc)[i], value[i]);
            }
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            *acc = value;
          }
        }
      }
    }

    return value;
  }
};

extern "C"
__launch_bounds__(${max_threads_lb}, 4)
__global__ void reduction_${name}_kernel(ReduceJitOp r) {
  r.run();
}
)ESCAPE";

const std::string reduction_template = reduction_template_0 + reduction_template_1;

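// Returns the full CUDA source for the reduction kernel; callers substitute the
// ${...} placeholders (scalar_type, reduction_accum_type, result_type, functor,
// output_vec_size, vt0, complex, max_threads_lb, name) before compiling it.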
const std::string &get_reduction_template() {
  return reduction_template;
}

} // namespace cuda
} // namespace at