Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 66 additions & 35 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_gemm.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,30 @@
/*
* conv2d_gemm: GEMM step of im2col-backed conv2d.
*
* Reads the im2col'd input produced by conv2d_im2col.glsl as a 2D matrix
* of shape [M, K_total] (M = H_out * W_out, K_total = Kh*Kw*Cin_padded)
* and writes the conv2d output as texture3D channels-packed
* logical shape [1, C_out, H_out, W_out].
* Reads one tile of the im2col'd input produced by conv2d_im2col.glsl — a 2D
* matrix of shape [M_TILE, K_total] holding OH_TILE output-height rows
* (M_TILE = OH_TILE * W_out, K_total = Kh*Kw*Cin_padded) starting at output
* row OH_OFFSET — and writes the corresponding output rows as texture3D
* channels-packed, logical shape [1, C_out, H_out, W_out]. The full im2col
* matrix is processed OH_TILE rows per dispatch so the scratch tensor is bounded
* to a fixed byte budget regardless of resolution; with tiling disabled the
* caller passes OH_OFFSET = 0 and OH_TILE = H_out (one dispatch covers all M).
*
* The im2col input can be any of:
* - texture2d, width-packed: texel at (k4, m) holds 4 K values for row m.
* IN_STORAGE=texture2d codegen.
* - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values
* for output spatial position (oh, ow). Used when M would exceed
* max_texture2d_dim. IN_STORAGE=texture3d codegen.
* - buffer: vec4 at offset m*K4 + k4, same K packing.
* The im2col input tile can be any of:
* - texture2d, width-packed: texel at (k4, r) holds 4 K values for tile-local
* row r. IN_STORAGE=texture2d codegen.
* - texture3d, channels-packed: texel at (ow, oh_local, k4) holds 4 K values
* for tile-local row r = oh_local * W_out + ow. Used when the per-tile 2D
* extent (M_TILE = OH_TILE * W_out) would exceed max_texture2d_dim — rare,
* since OH_TILE is capped by the scratch byte budget. IN_STORAGE=texture3d
* codegen.
* - buffer: vec4 at offset r*K4 + k4, same K packing.
* IN_STORAGE=buffer codegen.
*
* The matmul interpretation is:
* out[m, n] = sum_k im2col[m, k] * weight[n, k] + bias[n]
* with M = H_out * W_out, K = K_total, N = C_out.
* The matmul interpretation (over this tile's rows) is:
* out[r, n] = sum_k im2col[r, k] * weight[n, k] + bias[n]
* with K = K_total, N = C_out, and r the tile-local row whose global output
* spatial position is (OH_OFFSET + r / W_out, r % W_out).
*/

#version 450 core
Expand Down Expand Up @@ -85,14 +92,23 @@ ${layout_declare_ubo(B, "ivec4", "out_sizes")}
// dims), so it is safe to bake at build time even under dynamic shapes.
// M = H_out * W_out IS shape-dependent, so it is derived at runtime from the
// refreshed out_sizes UBO in main() rather than read from here.
//
// This dispatch consumes one tile of the im2col matrix: OH_TILE output-height
// rows starting at output-height row OH_OFFSET. The im2col scratch (t_in) holds
// OH_TILE * W_out tile-local rows; the GEMM reads tile-local rows and writes the
// output at the corresponding global spatial position (OH_OFFSET + oh_local,
// ow). OH_OFFSET / OH_TILE are shape-independent (fixed at build time); the
// global W_out / H_out come from the refreshed out_sizes UBO.
layout(push_constant) uniform restrict Block {
ivec4 gemm_dims; // (K4_total, _unused, _unused, _unused)
ivec4 gemm_dims; // (K4_total, OH_OFFSET, OH_TILE, _unused)
vec4 clamp_vals; // (out_min, out_max, _unused, _unused)
};

#define K4_TOTAL gemm_dims.x
#define OUT_MIN clamp_vals.x
#define OUT_MAX clamp_vals.y
#define K4_TOTAL gemm_dims.x
#define OH_OFFSET gemm_dims.y
#define OH_TILE gemm_dims.z
#define OUT_MIN clamp_vals.x
#define OUT_MAX clamp_vals.y

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

Expand All @@ -104,30 +120,32 @@ ${layout_declare_spec_const(C, "int", "activation_type", "0")}

/*
* Load TILE_M rows × TILE_K4 K-tiles of the im2col'd input.
* The im2col output is a contiguous (M, K_total/4) matrix of vec4s, so the
* load is a plain 2D fetch — no spatial decomposition.
* The im2col scratch holds M_TILE tile-local rows in a contiguous
* (M_TILE, K_total/4) matrix of vec4s; row here is the tile-local index, so the
* load is a plain 2D fetch — no spatial decomposition. (The output store, not
* this load, maps the tile-local row to its global spatial position.)
*/
void load_input_tile_with_checks(
out FPInputTile tile,
const int k4_start,
const int m_start,
const int K4,
const int M,
const int M_TILE,
const int W_out) {
// W_out is only consumed by the texture3d variant below.
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
if (k4_start + k4 < K4 && m_start + m < M) {
if (k4_start + k4 < K4 && m_start + m < M_TILE) {
const int row = m_start + m;
const int col = k4_start + k4;
#if defined(INPUT_BUFFER)
// Cast SSBO texel into the input tile type (f16vec4 for half, vec4 for
// float).
tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(t_in[row * K4 + col]);
#elif defined(INPUT_TEXTURE3D)
// texture3d layout: row (the flat M index) decomposes into (ow, oh)
// and K4 is along the Z axis. texelFetch returns vec4 (fp32); cast to
// the input tile type.
// texture3d scratch [1, K_total, OH_TILE, W_out]: the tile-local row
// decomposes into (ow, oh_local) and K4 is along the Z axis. texelFetch
// returns vec4 (fp32); cast to the input tile type.
tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(
texelFetch(t_in, ivec3(row % W_out, row / W_out, col), 0));
#else
Expand All @@ -141,17 +159,28 @@ void load_input_tile_with_checks(
}
}

// m_start is a tile-local row offset; the scratch read uses it directly, but
// the output store maps it to the GLOBAL spatial position via oh_global =
// OH_OFFSET + (m_local / W_out). Rows whose global oh lands past H_out (the
// partial trailing tile, or a dynamic shape that shrinks H_out) are skipped by
// the `oh < H_out` guard. The companion `m_local < M_TILE` guard enforces the
// tile's UPPER oh bound: since M_TILE = OH_TILE * W_out, m_local < M_TILE means
// oh_local < OH_TILE, i.e. oh < OH_OFFSET + OH_TILE — so no row this dispatch
// writes can leak into a neighboring tile's output-row range.
void store_output_tile_with_checks(
const FPOutTile out_tile,
const int n4_start,
const int m_start,
const int N4,
const int M,
const int M_TILE,
const int H_out,
const int W_out) {
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
if (m_start + m < M && n4_start + n4 < N4) {
const int spatial = m_start + m;
const int m_local = m_start + m;
const int ow = m_local % W_out;
const int oh = OH_OFFSET + m_local / W_out;
if (m_local < M_TILE && oh < H_out && n4_start + n4 < N4) {
// Cast the accumulator (f16vec4 for the buffer/half path) to the
// texture3d output surface type for the activation clamp and store.
OUT_VEC4_T texel = OUT_VEC4_T(out_tile.data[m][n4]);
Expand All @@ -160,8 +189,7 @@ void store_output_tile_with_checks(
} else if (activation_type == 2) {
texel = clamp(texel, OUT_VEC4_T(OUT_MIN), OUT_VEC4_T(OUT_MAX));
}
imageStore(
t_out, ivec3(spatial % W_out, spatial / W_out, n4_start + n4), texel);
imageStore(t_out, ivec3(ow, oh, n4_start + n4), texel);
}
}
}
Expand All @@ -176,14 +204,16 @@ void main() {

const int W_out = out_sizes.x;
const int H_out = out_sizes.y;
// M = H_out * W_out is derived from the refreshed out_sizes UBO so it tracks
// W_out / H_out are derived from the refreshed out_sizes UBO so they track
// dynamic output shapes (out_sizes is virtual_resize'd on trigger_resize).
const int M = W_out * H_out;
// M_TILE = OH_TILE * W_out is the tile-local row count materialized in the
// im2col scratch (t_in); the GEMM reads scratch rows in [0, M_TILE).
const int M_TILE = OH_TILE * W_out;
const int K4 = K4_TOTAL;
const int N = out_sizes.z;
const int N4 = div_up_4(N);

if (n4_start >= N4 || m_start >= M) {
if (n4_start >= N4 || m_start >= M_TILE) {
return;
}

Expand All @@ -194,7 +224,7 @@ void main() {
FPWeightTile w_tile;

for (int k4 = 0; k4 < K4; k4 += TILE_K4) {
load_input_tile_with_checks(in_tile, k4, m_start, K4, M, W_out);
load_input_tile_with_checks(in_tile, k4, m_start, K4, M_TILE, W_out);
load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4);
fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile);
}
Expand All @@ -213,5 +243,6 @@ void main() {
}
}

store_output_tile_with_checks(out_tile, n4_start, m_start, N4, M, W_out);
store_output_tile_with_checks(
out_tile, n4_start, m_start, N4, M_TILE, H_out, W_out);
}
64 changes: 43 additions & 21 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_im2col.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,37 @@
/*
* Im2col transformation for FP32 / FP16 conv2d.
*
* The output is a 2D matrix of shape [M, K_total] where
* One dispatch materializes a tile of OH_TILE output-height rows (full W_out
* each) of the im2col matrix, starting at output-height row OH_OFFSET. The full
* matrix has logical shape [M, K_total] where
* M = H_out * W_out (number of output spatial positions)
* K_total = Kh * Kw * align_up_4(C_in) (flattened receptive field)
*
* Tiling by output-height rows bounds the scratch tensor to a fixed byte budget
* regardless of resolution: the scratch holds OH_TILE * W_out rows, not M. A
* tile-local row m_local decodes to oh_local = m_local / W_out,
* ow = m_local % W_out; the SOURCE spatial position uses the global output row
* oh = OH_OFFSET + oh_local. Tiling by H rows (rather than flat M rows) keeps
* this row->(oh,ow) decode exact for the spatial texture3d layout too. When
* tiling is disabled the caller passes OH_OFFSET = 0 and OH_TILE = H_out.
*
* K layout (so a 4-tile in K — one vec4 — holds the same kernel position):
* K = (ki * Kw + kj) * Cin_padded + ci
*
* Three codegen'd storage variants of the output tensor:
* - texture2d, width-packed: texel at (k4, m) holds 4 K values for spatial
* position m. Extents = (K_total/4, M).
* - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values
* for output spatial position (oh, ow). Extents = (W_out, H_out, K4).
* Used as a fallback when M would exceed max_texture2d_dim.
* - buffer: vec4 at offset (m * K4 + k4), same K packing.
* - texture2d, width-packed: texel at (k4, m_local) holds 4 K values for
* tile-local row m_local. Extents = (K_total/4, OH_TILE * W_out).
* - texture3d, channels-packed: texel at (ow, oh_local, k4) holds 4 K values
* for output spatial position (OH_OFFSET + oh_local, ow). Extents =
* (W_out, OH_TILE, K4). Used as a fallback when the per-tile 2D extent
* (OH_TILE * W_out, K4) would exceed max_texture2d_dim.
* - buffer: vec4 at offset (m_local * K4 + k4), same K packing.
*
* The caller picks storage per device (Mali → buffer; others → texture2d
* when its 2D extents fit, texture3d when its 3D extents fit, else buffer).
* The caller selects storage on the TILED scratch extent, not the full M.
* Per device: Mali → buffer; others → texture2d when the per-tile 2D extent
* (OH_TILE * W_out, K4) fits, texture3d when its 3D extents fit, else buffer.
* Because OH_TILE is capped by the scratch byte budget, the tiled extent rarely
* exceeds max_texture2d_dim, so texture2d is the common case.
*/

#version 450 core
Expand Down Expand Up @@ -62,7 +76,7 @@ ${layout_declare_ubo(B, "ivec4", "in_sizes")}
layout(push_constant) uniform restrict Block {
ivec4 kernel_stride; // (Kh, Kw, Sh, Sw)
ivec4 padding_dil; // (Ph, Pw, Dh, Dw)
ivec4 dims; // (Cin_padded, _unused, _unused, K4_total)
ivec4 dims; // (Cin_padded, OH_OFFSET, OH_TILE, K4_total)
};

#define KERNEL_H kernel_stride.x
Expand All @@ -74,13 +88,17 @@ layout(push_constant) uniform restrict Block {
#define DILATION_H padding_dil.z
#define DILATION_W padding_dil.w
#define CIN_PADDED dims.x
#define OH_OFFSET dims.y
#define OH_TILE dims.z
#define K4_TOTAL dims.w

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const int k4 = int(gl_GlobalInvocationID.x);
const int m = int(gl_GlobalInvocationID.y);
// gl_GlobalInvocationID.y is the tile-local row m_local within this tile's
// OH_TILE * W_out rows; it maps to the global output row via OH_OFFSET.
const int m_local = int(gl_GlobalInvocationID.y);

// Derive the spatial output extents from the (refreshed-on-resize) input
// sizes UBO so the im2col mapping tracks dynamic input shapes. in_sizes is
Expand All @@ -92,12 +110,20 @@ void main() {
const int H_OUT =
(in_sizes.y + 2 * PADDING_H - DILATION_H * (KERNEL_H - 1) - 1) / STRIDE_H +
1;
const int M = H_OUT * W_OUT;
// Rows materialized by this tile (capped to the scratch extent).
const int M_TILE = OH_TILE * W_OUT;

if (k4 >= K4_TOTAL || m >= M) {
if (k4 >= K4_TOTAL || m_local >= M_TILE) {
return;
}

// Tile-local row m_local -> (oh_local, ow); global output row oh = OH_OFFSET +
// oh_local. Rows past the real H_OUT (in a partial trailing tile, or when a
// dynamic shape shrinks H_OUT below the build-time OH_OFFSET) write zeros.
const int oh_local = m_local / W_OUT;
const int ow = m_local % W_OUT;
const int oh = OH_OFFSET + oh_local;

const int k_start = k4 * 4;

// K = (ki * Kw + kj) * Cin_padded + ci ; since Cin_padded % 4 == 0, all 4
Expand All @@ -109,24 +135,20 @@ void main() {
const int ki = krow_idx / KERNEL_W;
const int ci_blk = ci_start >> 2; // ci_start / 4

// Decompose flat output position m back into (oh, ow).
const int ow = m % W_OUT;
const int oh = m / W_OUT;

// Compute the input spatial position for this (oh, ow, ki, kj).
const int ih = oh * STRIDE_H - PADDING_H + ki * DILATION_H;
const int iw = ow * STRIDE_W - PADDING_W + kj * DILATION_W;

VEC4_T out_texel = VEC4_T(0);
if (ih >= 0 && ih < in_sizes.y && iw >= 0 && iw < in_sizes.x) {
if (oh < H_OUT && ih >= 0 && ih < in_sizes.y && iw >= 0 && iw < in_sizes.x) {
out_texel = texelFetch(t_in, ivec3(iw, ih, ci_blk), 0);
}

#if defined(OUTPUT_BUFFER)
t_out[m * K4_TOTAL + k4] = VEC4_BUF_T(out_texel);
t_out[m_local * K4_TOTAL + k4] = VEC4_BUF_T(out_texel);
#elif defined(OUTPUT_TEXTURE3D)
imageStore(t_out, ivec3(ow, oh, k4), out_texel);
imageStore(t_out, ivec3(ow, oh_local, k4), out_texel);
#else
imageStore(t_out, ivec2(k4, m), out_texel);
imageStore(t_out, ivec2(k4, m_local), out_texel);
#endif
}
Loading
Loading