ORippler committed · Commit c768824 · 1 Parent(s): c8284f2

CUDA: Optimize `reduce_rows_f32` kernel, giving up to a 25x kernel-level perf improvement and a 10% perf increase for Gemma3n (llama/15132)


* Factor out `reduce_rows_f32` from common.cuh

This shortens the iteration cycle, since every kernel no longer has to be
recompiled whenever this kernel changes

* Hide memory latency by loop unrolling in `reduce_rows_f32`

* Further optimizations to `reduce_rows_f32`

1. Increase the threadblock size to better hide the latency of memory requests.
   As a consequence of the bigger threadblocks, perform a 2-step summation,
   using shared memory to communicate partial results between warps
   (see the sketch after this list)
2. Use a sum_temp array to reduce waits on sum
3. Adjust num_unroll to reflect the bigger threadblock
4. Improve the default block_dims and extend support to more block_dims
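
For orientation, a minimal self-contained sketch of this pattern follows. It is not the actual `reduce_rows_f32` kernel from the diff below; the names `row_sum_sketch` and `warp_sum`, the hard-coded `WARP_SIZE`, and the shuffle-based warp reduction are assumptions of the sketch, which also assumes `blockDim.x` is a multiple of the warp size.

```cpp
// Illustrative sketch only: one threadblock per row; unrolled, bounds-checked
// loads into independent partial sums; then a warp-level reduction followed by
// a shared-memory reduction across warps.
#define WARP_SIZE 32

static __device__ float warp_sum(float v) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset);
    }
    return v;
}

// Assumes blockDim.x is a multiple of WARP_SIZE (as in the real launch configs).
static __global__ void row_sum_sketch(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;

    // Unrolled accumulation into num_unroll independent sums, so the loads of
    // one unrolled step can overlap with the adds of the previous one.
    constexpr int num_unroll = 8;
    float partial[num_unroll] = { 0.0f };
    for (int i = threadIdx.x; i < ncols; ) {
        for (int j = 0; j < num_unroll; ++j) {
            partial[j] += (i < ncols) ? x[row * ncols + i] : 0.0f;
            i += blockDim.x;
        }
    }
    float sum = 0.0f;
    for (int j = 0; j < num_unroll; ++j) {
        sum += partial[j];
    }

    // Step 1: reduce within each warp.
    sum = warp_sum(sum);

    // Step 2: lane 0 of each warp publishes its result; warp 0 reduces them.
    __shared__ float s_sum[32];   // enough for up to 1024 threads per block
    const int lane = threadIdx.x % WARP_SIZE;
    const int warp = threadIdx.x / WARP_SIZE;
    if (lane == 0) {
        s_sum[warp] = sum;
    }
    __syncthreads();
    if (warp == 0) {
        sum = (lane < blockDim.x / WARP_SIZE) ? s_sum[lane] : 0.0f;
        sum = warp_sum(sum);
        if (lane == 0) {
            dst[row] = sum;
        }
    }
}
```

The actual kernel additionally skips the shared-memory step entirely when the block consists of a single warp.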

* Add perf tests for `reduce_rows_f32` kernel

* Add a heuristic to toggle between 128 and 512 threads based on the SM count

The break-even point was the minimum of the following nrows / SM count multiples (see the sketch after the table).

| GPU Model | nrows / SM count multiple |
| ----------- | ----------- |
| RTX 4000 SFF ADA | 2.0x |
| RTX 6000 ADA | 2.5x |
| RTX PRO 6000 Blackwell Max-Q | 3.04x |
| RTX PRO 4500 Blackwell | 3.15x |
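
On the host side, the heuristic boils down to a dispatch like the following sketch. Assumptions of the sketch: `nsm` is the device's SM count, and `row_sum_sketch` / `launch_row_sum` are placeholder names; the real launch sites are in `sumrows.cu` and `mean.cu` in the diff below.

```cpp
#include <cuda_runtime.h>

// A row-reduction kernel such as the sketch above (one block per row).
__global__ void row_sum_sketch(const float * x, float * dst, const int ncols);

// Below ~2 rows per SM there are too few blocks in flight to hide memory
// latency across blocks, so use large 512-thread blocks that hide it within
// each block; otherwise use small blocks (32 or 128 threads depending on the
// row length) so the scheduler can balance the many rows across SMs.
static void launch_row_sum(const float * x, float * dst, const int ncols, const int nrows,
                           const int nsm, cudaStream_t stream) {
    const dim3 block_nums(nrows, 1, 1);
    const int  n_threads = (nrows / nsm < 2) ? 512 : (ncols < 1024 ? 32 : 128);
    const dim3 block_dims(n_threads, 1, 1);
    row_sum_sketch<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
```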

* Ensure perf gains also for small ncols and large nrows

As an alternative, the number of unrollings could have been made a template
parameter, but that would require compiling the kernel multiple times,
unnecessarily increasing binary size

* Modify perf and unit-tests

* Apply auto-formatting by clang

* Fix CI build failure

See https://github.com/ggml-org/llama.cpp/actions/runs/16798370266/job/47573716079?pr=15132#step:7:486
Building with the VS generator worked, though.

* Remove sm_count property from `ggml_backend_cuda_context`

Requested by @JohannesGaessler; as a side effect, this should also fix the
remaining CI issues.

* Add CUB-based implementation for GGML_OP_MEAN

Currently this branch is only executed for nrows == 1 (a sketch of the path follows the hardware list below).

* Add heuristics to execute the CUB branch only when it improves performance

Heuristics were determined on the following HW:

* RTX 4000 SFF ADA
* RTX 6000 ADA
* RTX PRO 6000 Blackwell Max-Q
* RTX PRO 4500 Blackwell
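
Sketched, the CUB path has the following shape. The two-phase `cub::DeviceReduce::Sum` call pattern (a first call with a null scratch pointer only queries the scratch size, the second call performs the reduction) is standard CUB; `mean_single_row_sketch`, `divide_by_count_sketch`, and the `cudaMallocAsync` scratch allocation are placeholders for this sketch, as the real code uses ggml's CUDA pool and a templated `divide_by_count` kernel (see `mean.cu` below).

```cpp
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Turn the device-wide sum into a mean in place (stand-in for the real divide_by_count).
static __global__ void divide_by_count_sketch(float * result, const size_t count) {
    *result /= static_cast<float>(count);
}

// Sketch of the nrows == 1 path of GGML_OP_MEAN.
static void mean_single_row_sketch(const float * src, float * dst, const int ncols, cudaStream_t stream) {
    size_t tmp_size = 0;
    cub::DeviceReduce::Sum(nullptr, tmp_size, src, dst, ncols, stream);  // query scratch size only

    void * tmp = nullptr;
    cudaMallocAsync(&tmp, tmp_size, stream);                             // real code: ggml_cuda_pool_alloc
    cub::DeviceReduce::Sum(tmp, tmp_size, src, dst, ncols, stream);      // device-wide sum into dst[0]
    cudaFreeAsync(tmp, stream);

    divide_by_count_sketch<<<1, 1, 0, stream>>>(dst, (size_t) ncols);
}
```

For shorter rows the fixed launch and scratch-allocation overhead of this path outweighs its benefit, which is what the `ncols` thresholds in the heuristic (65536 with CUDA Graphs effectively disabled, 32768 with them enabled) guard against.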

* Add unit-test for CUB-based mean

Tests should run with CUDA Graphs enabled by default on NVIDIA GPUs.

* Rename `USE_CUB` to `GGML_CUDA_USE_CUB`

Suggested by @JohannesGaessler.

* Unindent preprocessor directives

See https://github.com/ggml-org/llama.cpp/pull/15132#discussion_r2269213506

ggml/src/ggml-cuda/common.cuh CHANGED
@@ -87,6 +87,10 @@
 #define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
 #define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
 
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define GGML_CUDA_USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
     return false;
@@ -420,26 +424,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
-// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
-template<bool norm>
-static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col != 0) {
-        return;
-    }
-
-    dst[row] = norm ? sum / ncols : sum;
-}
-
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
 #ifdef GGML_USE_HIP
ggml/src/ggml-cuda/mean.cu CHANGED
@@ -1,4 +1,14 @@
 #include "mean.cuh"
+#include "reduce_rows.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // GGML_CUDA_USE_CUB
+
+template <typename T> __global__ void divide_by_count(T * result, size_t count) {
+    *result /= static_cast<T>(count);
+}
 
 void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
@@ -13,7 +23,45 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    // Special case for reducing vectors
+#ifdef GGML_CUDA_USE_CUB
+    cudaStreamCaptureStatus iscapturing;
+    CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
+    if ((nrows == 1) &&
+        // CUDA_GRAPHS_DISABLED
+        ((ncols > 65536) &&
+         ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+          ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+          ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
+        // CUDA_GRAPHS ENABLED
+        ((ncols > 32768) &&
+         !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+           ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+           ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
+        // Single row - use device-wide reduction
+        size_t tmp_size = 0;
+        ggml_cuda_pool & pool = ctx.pool();
+
+        DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+        DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        // Divide by ncols
+        divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
+        return;
+    }
+#endif
+
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+
+    const int id = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
ggml/src/ggml-cuda/reduce_rows.cuh ADDED
@@ -0,0 +1,53 @@
+#include "common.cuh"
+
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template <bool norm>
+static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    const int num_unroll = 8;
+    float temp[num_unroll];
+    float sum_temp[num_unroll] = { 0.0f };
+    for (int i = col; i < ncols;) {
+        for (int j = 0; j < num_unroll; ++j) {
+            if (i < ncols) {
+                temp[j] = x[row * ncols + i];
+            } else {
+                temp[j] = 0;
+            }
+            i += blockDim.x;
+        }
+        for (int j = 0; j < num_unroll; ++j) {
+            sum_temp[j] += temp[j];
+        }
+    }
+    for (int j = 0; j < num_unroll; ++j) {
+        sum += sum_temp[j];
+    }
+
+    // sum up partial sums
+    sum = warp_reduce_sum(sum);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = 0.0f;
+        if (lane_id < (blockDim.x / WARP_SIZE)) {
+            sum = s_sum[lane_id];
+        }
+        sum = warp_reduce_sum(sum);
+    }
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
ggml/src/ggml-cuda/sum.cu CHANGED
@@ -1,19 +1,15 @@
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
-#define USE_CUB
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#include "sum.cuh"
+#include "sumrows.cuh"
 
-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
 #include <cub/cub.cuh>
 using namespace cub;
-#endif // USE_CUB
-
-#include "sumrows.cuh"
-#include "sum.cuh"
+#endif // GGML_CUDA_USE_CUB
 
 #include <cstdint>
 
 void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
     size_t tmp_size = 0;
     DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
     ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
@@ -23,7 +19,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
     // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
     sum_rows_f32_cuda(x, dst, ne, 1, stream);
     GGML_UNUSED(pool);
-#endif // USE_CUB
+#endif // GGML_CUDA_USE_CUB
 }
 
 void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml/src/ggml-cuda/sumrows.cu CHANGED
@@ -1,9 +1,17 @@
+#include "reduce_rows.cuh"
 #include "sumrows.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const int id = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -19,8 +27,17 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
 
-    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    const int id = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        // Increase num threads to 512 for small nrows to better hide the latency
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        // Enough active SMs to hide latency, use smaller blocks to allow better scheduling
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }