In a high-performance GEMM kernel, the number of LDS read requests, and therefore of data transfers from LDS to registers, is pretty high (around 8 thousand per tile). In this post, we describe a way to reduce it by around 3 thousand (to about 5 thousand).
Writing a high-performance GPU kernel is a challenge in itself: a simple mistake like putting instructions in the wrong order can degrade performance drastically. Fortunately, we are not going to discuss that here. If you are interested, follow this article.
If you have already written a GEMM kernel based on that article, this post optimizes it further by reducing the number of LDS requests. The idea is simple: fewer requests mean fewer data transfers from LDS to registers.
The article explains how to implement different block sizes effectively (I assume you have read it); here we analyze a block size of 128x128. To operate on a 128x128 block, a 128x8 tile of both the A and B matrices is written from global memory to LDS. The 256 workitems are partitioned into a 16x16 grid (let's call its two dimensions x and y); workitems load data from A along the x dimension and from B along the y dimension. For each row of the tile, each workitem loads 2 float4s from A and 2 float4s from B, and the two float4s it reads from a given matrix are 64 floats apart.
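The read pattern can be checked with a short Python sketch. It assumes (our reading of the description, not code from the article) that workitem x reads its first float4 at float offset 4*x and its second one 64 floats (256 bytes) later; the names `row_offsets`, `TILE_WIDTH` and friends are hypothetical.

```python
# Sketch: which columns of one 128-float tile row a workitem reads.
# Assumption: workitem x reads float4 #0 at float offset 4*x and
# float4 #1 at 64 + 4*x (64 floats = 256 bytes apart).
TILE_WIDTH = 128    # columns in one row of the 128x8 tile
NUM_WI = 16         # workitems along one dimension of the 16x16 grid
FLOAT4 = 4          # floats per float4
SECOND_OFFSET = 64  # distance between the two float4s, in floats


def row_offsets(x: int) -> list:
    """Float offsets covered by workitem x's two float4 loads (hypothetical helper)."""
    first = [FLOAT4 * x + i for i in range(FLOAT4)]
    second = [SECOND_OFFSET + FLOAT4 * x + i for i in range(FLOAT4)]
    return first + second


# All 16 workitems together cover every column of the row exactly once.
covered = sorted(o for x in range(NUM_WI) for o in row_offsets(x))
print(covered == list(range(TILE_WIDTH)))  # True
print(row_offsets(1))  # [4, 5, 6, 7, 68, 69, 70, 71]
```

Placing the second float4 64 floats away lets 16 workitems cover a 128-wide row with two contiguous 64-float sweeps.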
So, the total number of LDS read requests for one row of the tile is 16 (workitems along x) * 16 (workitems along y) * (2 (float4s from A) + 2 (float4s from B)) = 1024 requests. For the whole tile, it is 1024 * 8 (unroll factor) = 8192 requests.
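The same count as a few lines of Python, with each factor named (the constant names are ours). Note that the per-workitem reads from A and B add up; they do not multiply.

```python
# Baseline LDS read requests for one 128x8 tile, mirroring the arithmetic above.
WI_X, WI_Y = 16, 16    # 16x16 grid of workitems
A_FLOAT4_PER_ROW = 2   # float4 reads from the A sub-tile, per workitem and row
B_FLOAT4_PER_ROW = 2   # float4 reads from the B sub-tile, per workitem and row
UNROLL = 8             # rows in the 128x8 tile

requests_per_row = WI_X * WI_Y * (A_FLOAT4_PER_ROW + B_FLOAT4_PER_ROW)
requests_per_tile = requests_per_row * UNROLL
print(requests_per_row, requests_per_tile)  # 1024 8192
```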
Is there a way to reduce the requests? Yes, there is. Let me introduce a feature called DPP.
What is DPP?
DPP (Data Parallel Primitives) is an instruction modifier that lets a vector ALU instruction read an operand from neighboring lanes, without a separate data transfer through LDS. We only use a very specific set of DPP instructions; if you are interested in learning more, this article provides a good introduction to the different functionality. Here we focus on the quad permute modifier.
A quad permute modifier lets each lane read an operand from any of the 4 lanes in its group of four (a "quad"). The DPP swizzle is applied to the first source operand (src0) of the instruction, so the register to be shared is written first. Here is what a mac instruction with a quad permute modifier looks like:

v_mac_f32 v0, v2, v1 quad_perm:[2, 2, 0, 1]

The numbers given to the quad_perm identifier are the lane numbers (within the quad) from which v2 is read; in other words, each lane computes v0 += v2[quad_perm[current_lane]] * v1.

For example, suppose v2 holds [1.0f, 2.0f, 3.0f, 4.0f] across 4 lanes. The values of v2 that the mac instruction above actually uses are shown in the 3rd column of the table below.

|Lane Id (l)|Quad Perm Id (q)|v2 (v2[q[l]])|
|---|---|---|
|0|2|3.0f|
|1|2|3.0f|
|2|0|1.0f|
|3|1|2.0f|

Note that the modifier only fetches v2 from the neighboring lanes; it does not overwrite v2 in the current lane. Only the destination v0 is written.
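To make the quad_perm semantics concrete, here is a small host-side Python emulation. It is a model, not AMD code: the function names are hypothetical, and it assumes the swizzle is applied to the first source operand (src0) within each group of 4 lanes. It reproduces the quad_perm:[2, 2, 0, 1] example on [1.0f, 2.0f, 3.0f, 4.0f].

```python
def quad_perm(src, perm):
    """Emulate DPP quad_perm: lane i reads src from lane perm[i % 4] of its own quad."""
    assert len(src) % 4 == 0 and len(perm) == 4 and all(0 <= p < 4 for p in perm)
    return [src[(lane // 4) * 4 + perm[lane % 4]] for lane in range(len(src))]


def v_mac_f32_dpp(dst, src0, src1, perm):
    """Per lane: dst += swizzled(src0) * src1. Returns a new dst; src0 is not modified."""
    swizzled = quad_perm(src0, perm)
    return [d + s0 * s1 for d, s0, s1 in zip(dst, swizzled, src1)]


v2 = [1.0, 2.0, 3.0, 4.0]  # v2 across the 4 lanes of one quad
v1 = [1.0, 1.0, 1.0, 1.0]  # multiply by 1 so the result shows the fetched values
v0 = v_mac_f32_dpp([0.0] * 4, v2, v1, (2, 2, 0, 1))
print(v0)  # [3.0, 3.0, 1.0, 2.0]
print(v2)  # [1.0, 2.0, 3.0, 4.0]  (unchanged)
```

With more than 4 lanes, each quad is permuted independently; for example, lanes 4 to 7 only ever read from lanes 4 to 7.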
How to use DPP with GEMM?
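In the first iteration of the optimization, instead of loading 2 float4s from the A or B sub-tile (we use B below) in every workitem, we load the first float4 into the even workitem and the second float4 into the odd workitem of each pair of neighboring lanes. This relies on neighboring workitems needing the same B values, which holds if lanes are numbered along x first (an assumption about the thread layout), since B is indexed by y. Then, we rewrite the mac operations to read b from the even or odd lane using quad_perm.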
```asm
; Before optimization
; a0, a1, b0, b1 -> each input float
; c0, c1, c2, c3 -> each output float
ds_read_b32 a0, sA
ds_read_b32 a1, sA offset:256   ; 64 * 4 bytes
ds_read_b32 b0, sB
ds_read_b32 b1, sB offset:256
; wait for lds reads
s_waitcnt lgkmcnt(0)
v_mac_f32 c0, a0, b0
v_mac_f32 c1, a0, b1
v_mac_f32 c2, a1, b0
v_mac_f32 c3, a1, b1
; we can represent these instructions using quad_perm as
; (DPP swizzles src0, so b is written first; [0, 1, 2, 3] is the identity)
v_mac_f32 c0, b0, a0 quad_perm:[0, 1, 2, 3]
v_mac_f32 c1, b1, a0 quad_perm:[0, 1, 2, 3]
v_mac_f32 c2, b0, a1 quad_perm:[0, 1, 2, 3]
v_mac_f32 c3, b1, a1 quad_perm:[0, 1, 2, 3]
```
We assume we are loading just one b per workitem.
```asm
; After optimization
; a0, a1, b -> each input float
; c0, c1, c2, c3 -> each output float
; we assume b0 is in the even workitem's register and b1 in the odd workitem's register
ds_read_b32 a0, sA
ds_read_b32 a1, sA offset:256
ds_read_b32 b, newSB   ; lds pointer changed per workitem to load the even/odd float4
; wait for lds reads
s_waitcnt lgkmcnt(0)
; DPP swizzles src0, so the shared b is written first
v_mac_f32 c0, b, a0 quad_perm:[0, 0, 2, 2]   ; b0 from the even lane
v_mac_f32 c1, b, a0 quad_perm:[1, 1, 3, 3]   ; b1 from the odd lane
v_mac_f32 c2, b, a1 quad_perm:[0, 0, 2, 2]
v_mac_f32 c3, b, a1 quad_perm:[1, 1, 3, 3]
```
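One way to convince yourself that the rewrite computes the same c0 to c3 is to emulate both versions on the host. The Python sketch below is a model, not GPU code: it assumes lanes 2j and 2j+1 need the same b0 and b1 (same y index), uses made-up input values, and applies the quad_perm swizzle to b.

```python
def quad_perm(src, perm):
    """Lane i reads src from lane perm[i % 4] of its own group of 4 lanes."""
    return [src[(lane // 4) * 4 + perm[lane % 4]] for lane in range(len(src))]


def mac(c, x, y):
    """Per-lane c += x * y."""
    return [ci + xi * yi for ci, xi, yi in zip(c, x, y)]


LANES = 8  # two quads
# Made-up inputs: a0/a1 differ per lane (different x), b0/b1 are shared (same y).
a0 = [float(lane + 1) for lane in range(LANES)]
a1 = [10.0 * (lane + 1) for lane in range(LANES)]
B0, B1 = 0.5, -2.0
zero = [0.0] * LANES

# Before: every lane reads both b0 and b1 from LDS (4 reads per lane).
b0, b1 = [B0] * LANES, [B1] * LANES
before = [mac(zero, a0, b0), mac(zero, a0, b1), mac(zero, a1, b0), mac(zero, a1, b1)]

# After: one b read per lane; even lanes hold b0, odd lanes hold b1 (3 reads per lane).
b = [B0 if lane % 2 == 0 else B1 for lane in range(LANES)]
EVEN, ODD = (0, 0, 2, 2), (1, 1, 3, 3)
after = [
    mac(zero, quad_perm(b, EVEN), a0),  # c0: b0 * a0
    mac(zero, quad_perm(b, ODD), a0),   # c1: b1 * a0
    mac(zero, quad_perm(b, EVEN), a1),  # c2: b0 * a1
    mac(zero, quad_perm(b, ODD), a1),   # c3: b1 * a1
]
print(before == after)  # True
```

quad_perm:[0, 0, 2, 2] broadcasts each even lane's b0 to its odd neighbor, and [1, 1, 3, 3] does the same for b1, so every lane still sees both values.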
The number of mac ops remains the same, but the number of LDS requests per workitem per row drops from 4 to 3, i.e., by a quarter.
Total LDS requests = (2 (LDS loads of A) * 256 (workitems) * 8 (unroll factor)) + (1 (LDS load of B) * 256 (workitems) * 8 (unroll factor)) = 6144 requests.
We can reduce the loads further by loading B values for 2 rows of the tile per LDS request instead of one. In this case, the mac ops for both rows follow a single load, which doubles the unrolled mac code between loads and can cause instruction cache misses. On the other hand, it makes the code more compute-bound (which is good).
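The post does not spell out the lane layout for the 2-row variant. One possible assignment, which is purely our assumption, gives each lane of a quad a different (row, float4) slot of B, so that one read per lane covers rows k and k+1, and then broadcasts slot j to the whole quad with quad_perm:[j, j, j, j]. The Python model below (hypothetical names, made-up values) checks that every lane can still see all four B values.

```python
def quad_perm(src, perm):
    """Lane i reads src from lane perm[i % 4] of its own group of 4 lanes."""
    return [src[(lane // 4) * 4 + perm[lane % 4]] for lane in range(len(src))]


LANES = 8  # two quads
# Hypothetical slot per lane of a quad: (tile row, which float4 of B).
SLOTS = [("k", "b0"), ("k", "b1"), ("k+1", "b0"), ("k+1", "b1")]
# Made-up B values for those four slots (same y for the whole quad).
B = {SLOTS[0]: 1.0, SLOTS[1]: 2.0, SLOTS[2]: 3.0, SLOTS[3]: 4.0}

# One LDS read per lane covers two rows: lane i holds the value of slot i % 4.
b = [B[SLOTS[lane % 4]] for lane in range(LANES)]

# Broadcasting lane j of each quad gives every lane the value in slot j.
for j, slot in enumerate(SLOTS):
    assert quad_perm(b, (j, j, j, j)) == [B[slot]] * LANES
print(b)  # [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
```

Under this layout each lane issues four broadcast macs per A value for every two rows, which is where the larger block of straight-line mac code comes from.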
Total LDS requests = (2 (LDS loads of A) * 256 (workitems) * 8 (unroll factor)) + (1 (LDS load of B) * 256 (workitems) * 4 (8 rows / 2 rows per load)) = 5120 requests.
By using DPP modifiers, we brought the number of LDS requests down from 8192 to 5120, saving 37.5% of LDS requests.
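The three request counts and the final saving can be reproduced with one small Python function. It is a sketch of the arithmetic above only; `lds_requests` and its parameters are names we made up.

```python
WORKITEMS = 16 * 16
ROWS = 8  # unroll factor: rows in the 128x8 tile


def lds_requests(a_per_row, b_per_row, b_rows_per_load=1):
    """LDS read requests per tile: A reads every row, B reads once per b_rows_per_load rows."""
    a = a_per_row * WORKITEMS * ROWS
    b = b_per_row * WORKITEMS * (ROWS // b_rows_per_load)
    return a + b


baseline = lds_requests(2, 2)                     # 2 float4 of A + 2 of B per row
even_odd = lds_requests(2, 1)                     # b split across even/odd lanes
two_rows = lds_requests(2, 1, b_rows_per_load=2)  # each b read covers 2 rows
saving = 1 - two_rows / baseline
print(baseline, even_odd, two_rows, f"{saving:.1%}")  # 8192 6144 5120 37.5%
```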