diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_gemm.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_gemm.glsl index 42f7e5a85cc..6eebdb6b25f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_gemm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_gemm.glsl @@ -9,23 +9,30 @@ /* * conv2d_gemm: GEMM step of im2col-backed conv2d. * - * Reads the im2col'd input produced by conv2d_im2col.glsl as a 2D matrix - * of shape [M, K_total] (M = H_out * W_out, K_total = Kh*Kw*Cin_padded) - * and writes the conv2d output as texture3D channels-packed - * logical shape [1, C_out, H_out, W_out]. + * Reads one tile of the im2col'd input produced by conv2d_im2col.glsl — a 2D + * matrix of shape [M_TILE, K_total] holding OH_TILE output-height rows + * (M_TILE = OH_TILE * W_out, K_total = Kh*Kw*Cin_padded) starting at output + * row OH_OFFSET — and writes the corresponding output rows as texture3D + * channels-packed, logical shape [1, C_out, H_out, W_out]. The full im2col + * matrix is processed OH_TILE rows per dispatch so the scratch tensor is bounded + * to a fixed byte budget regardless of resolution; with tiling disabled the + * caller passes OH_OFFSET = 0 and OH_TILE = H_out (one dispatch covers all M). * - * The im2col input can be any of: - * - texture2d, width-packed: texel at (k4, m) holds 4 K values for row m. - * IN_STORAGE=texture2d codegen. - * - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values - * for output spatial position (oh, ow). Used when M would exceed - * max_texture2d_dim. IN_STORAGE=texture3d codegen. - * - buffer: vec4 at offset m*K4 + k4, same K packing. + * The im2col input tile can be any of: + * - texture2d, width-packed: texel at (k4, r) holds 4 K values for tile-local + * row r. IN_STORAGE=texture2d codegen. + * - texture3d, channels-packed: texel at (ow, oh_local, k4) holds 4 K values + * for tile-local row r = oh_local * W_out + ow. Used when the per-tile 2D + * extent (M_TILE = OH_TILE * W_out) would exceed max_texture2d_dim — rare, + * since OH_TILE is capped by the scratch byte budget. IN_STORAGE=texture3d + * codegen. + * - buffer: vec4 at offset r*K4 + k4, same K packing. * IN_STORAGE=buffer codegen. * - * The matmul interpretation is: - * out[m, n] = sum_k im2col[m, k] * weight[n, k] + bias[n] - * with M = H_out * W_out, K = K_total, N = C_out. + * The matmul interpretation (over this tile's rows) is: + * out[r, n] = sum_k im2col[r, k] * weight[n, k] + bias[n] + * with K = K_total, N = C_out, and r the tile-local row whose global output + * spatial position is (OH_OFFSET + r / W_out, r % W_out). */ #version 450 core @@ -85,14 +92,23 @@ ${layout_declare_ubo(B, "ivec4", "out_sizes")} // dims), so it is safe to bake at build time even under dynamic shapes. // M = H_out * W_out IS shape-dependent, so it is derived at runtime from the // refreshed out_sizes UBO in main() rather than read from here. +// +// This dispatch consumes one tile of the im2col matrix: OH_TILE output-height +// rows starting at output-height row OH_OFFSET. The im2col scratch (t_in) holds +// OH_TILE * W_out tile-local rows; the GEMM reads tile-local rows and writes the +// output at the corresponding global spatial position (OH_OFFSET + oh_local, +// ow). OH_OFFSET / OH_TILE are shape-independent (fixed at build time); the +// global W_out / H_out come from the refreshed out_sizes UBO. layout(push_constant) uniform restrict Block { - ivec4 gemm_dims; // (K4_total, _unused, _unused, _unused) + ivec4 gemm_dims; // (K4_total, OH_OFFSET, OH_TILE, _unused) vec4 clamp_vals; // (out_min, out_max, _unused, _unused) }; -#define K4_TOTAL gemm_dims.x -#define OUT_MIN clamp_vals.x -#define OUT_MAX clamp_vals.y +#define K4_TOTAL gemm_dims.x +#define OH_OFFSET gemm_dims.y +#define OH_TILE gemm_dims.z +#define OUT_MIN clamp_vals.x +#define OUT_MAX clamp_vals.y layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -104,20 +120,22 @@ ${layout_declare_spec_const(C, "int", "activation_type", "0")} /* * Load TILE_M rows × TILE_K4 K-tiles of the im2col'd input. - * The im2col output is a contiguous (M, K_total/4) matrix of vec4s, so the - * load is a plain 2D fetch — no spatial decomposition. + * The im2col scratch holds M_TILE tile-local rows in a contiguous + * (M_TILE, K_total/4) matrix of vec4s; row here is the tile-local index, so the + * load is a plain 2D fetch — no spatial decomposition. (The output store, not + * this load, maps the tile-local row to its global spatial position.) */ void load_input_tile_with_checks( out FPInputTile tile, const int k4_start, const int m_start, const int K4, - const int M, + const int M_TILE, const int W_out) { // W_out is only consumed by the texture3d variant below. [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { - if (k4_start + k4 < K4 && m_start + m < M) { + if (k4_start + k4 < K4 && m_start + m < M_TILE) { const int row = m_start + m; const int col = k4_start + k4; #if defined(INPUT_BUFFER) @@ -125,9 +143,9 @@ void load_input_tile_with_checks( // float). tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T(t_in[row * K4 + col]); #elif defined(INPUT_TEXTURE3D) - // texture3d layout: row (the flat M index) decomposes into (ow, oh) - // and K4 is along the Z axis. texelFetch returns vec4 (fp32); cast to - // the input tile type. + // texture3d scratch [1, K_total, OH_TILE, W_out]: the tile-local row + // decomposes into (ow, oh_local) and K4 is along the Z axis. texelFetch + // returns vec4 (fp32); cast to the input tile type. tile.data[m][k4] = LINEAR_FP_INPUT_TILE_VEC4_T( texelFetch(t_in, ivec3(row % W_out, row / W_out, col), 0)); #else @@ -141,17 +159,28 @@ void load_input_tile_with_checks( } } +// m_start is a tile-local row offset; the scratch read uses it directly, but +// the output store maps it to the GLOBAL spatial position via oh_global = +// OH_OFFSET + (m_local / W_out). Rows whose global oh lands past H_out (the +// partial trailing tile, or a dynamic shape that shrinks H_out) are skipped by +// the `oh < H_out` guard. The companion `m_local < M_TILE` guard enforces the +// tile's UPPER oh bound: since M_TILE = OH_TILE * W_out, m_local < M_TILE means +// oh_local < OH_TILE, i.e. oh < OH_OFFSET + OH_TILE — so no row this dispatch +// writes can leak into a neighboring tile's output-row range. void store_output_tile_with_checks( const FPOutTile out_tile, const int n4_start, const int m_start, const int N4, - const int M, + const int M_TILE, + const int H_out, const int W_out) { [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { - if (m_start + m < M && n4_start + n4 < N4) { - const int spatial = m_start + m; + const int m_local = m_start + m; + const int ow = m_local % W_out; + const int oh = OH_OFFSET + m_local / W_out; + if (m_local < M_TILE && oh < H_out && n4_start + n4 < N4) { // Cast the accumulator (f16vec4 for the buffer/half path) to the // texture3d output surface type for the activation clamp and store. OUT_VEC4_T texel = OUT_VEC4_T(out_tile.data[m][n4]); @@ -160,8 +189,7 @@ void store_output_tile_with_checks( } else if (activation_type == 2) { texel = clamp(texel, OUT_VEC4_T(OUT_MIN), OUT_VEC4_T(OUT_MAX)); } - imageStore( - t_out, ivec3(spatial % W_out, spatial / W_out, n4_start + n4), texel); + imageStore(t_out, ivec3(ow, oh, n4_start + n4), texel); } } } @@ -176,14 +204,16 @@ void main() { const int W_out = out_sizes.x; const int H_out = out_sizes.y; - // M = H_out * W_out is derived from the refreshed out_sizes UBO so it tracks + // W_out / H_out are derived from the refreshed out_sizes UBO so they track // dynamic output shapes (out_sizes is virtual_resize'd on trigger_resize). - const int M = W_out * H_out; + // M_TILE = OH_TILE * W_out is the tile-local row count materialized in the + // im2col scratch (t_in); the GEMM reads scratch rows in [0, M_TILE). + const int M_TILE = OH_TILE * W_out; const int K4 = K4_TOTAL; const int N = out_sizes.z; const int N4 = div_up_4(N); - if (n4_start >= N4 || m_start >= M) { + if (n4_start >= N4 || m_start >= M_TILE) { return; } @@ -194,7 +224,7 @@ void main() { FPWeightTile w_tile; for (int k4 = 0; k4 < K4; k4 += TILE_K4) { - load_input_tile_with_checks(in_tile, k4, m_start, K4, M, W_out); + load_input_tile_with_checks(in_tile, k4, m_start, K4, M_TILE, W_out); load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4); fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile); } @@ -213,5 +243,6 @@ void main() { } } - store_output_tile_with_checks(out_tile, n4_start, m_start, N4, M, W_out); + store_output_tile_with_checks( + out_tile, n4_start, m_start, N4, M_TILE, H_out, W_out); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_im2col.glsl index 84bd77ab3a6..702a0671964 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_im2col.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_im2col.glsl @@ -9,23 +9,37 @@ /* * Im2col transformation for FP32 / FP16 conv2d. * - * The output is a 2D matrix of shape [M, K_total] where + * One dispatch materializes a tile of OH_TILE output-height rows (full W_out + * each) of the im2col matrix, starting at output-height row OH_OFFSET. The full + * matrix has logical shape [M, K_total] where * M = H_out * W_out (number of output spatial positions) * K_total = Kh * Kw * align_up_4(C_in) (flattened receptive field) * + * Tiling by output-height rows bounds the scratch tensor to a fixed byte budget + * regardless of resolution: the scratch holds OH_TILE * W_out rows, not M. A + * tile-local row m_local decodes to oh_local = m_local / W_out, + * ow = m_local % W_out; the SOURCE spatial position uses the global output row + * oh = OH_OFFSET + oh_local. Tiling by H rows (rather than flat M rows) keeps + * this row->(oh,ow) decode exact for the spatial texture3d layout too. When + * tiling is disabled the caller passes OH_OFFSET = 0 and OH_TILE = H_out. + * * K layout (so a 4-tile in K — one vec4 — holds the same kernel position): * K = (ki * Kw + kj) * Cin_padded + ci * * Three codegen'd storage variants of the output tensor: - * - texture2d, width-packed: texel at (k4, m) holds 4 K values for spatial - * position m. Extents = (K_total/4, M). - * - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values - * for output spatial position (oh, ow). Extents = (W_out, H_out, K4). - * Used as a fallback when M would exceed max_texture2d_dim. - * - buffer: vec4 at offset (m * K4 + k4), same K packing. + * - texture2d, width-packed: texel at (k4, m_local) holds 4 K values for + * tile-local row m_local. Extents = (K_total/4, OH_TILE * W_out). + * - texture3d, channels-packed: texel at (ow, oh_local, k4) holds 4 K values + * for output spatial position (OH_OFFSET + oh_local, ow). Extents = + * (W_out, OH_TILE, K4). Used as a fallback when the per-tile 2D extent + * (OH_TILE * W_out, K4) would exceed max_texture2d_dim. + * - buffer: vec4 at offset (m_local * K4 + k4), same K packing. * - * The caller picks storage per device (Mali → buffer; others → texture2d - * when its 2D extents fit, texture3d when its 3D extents fit, else buffer). + * The caller selects storage on the TILED scratch extent, not the full M. + * Per device: Mali → buffer; others → texture2d when the per-tile 2D extent + * (OH_TILE * W_out, K4) fits, texture3d when its 3D extents fit, else buffer. + * Because OH_TILE is capped by the scratch byte budget, the tiled extent rarely + * exceeds max_texture2d_dim, so texture2d is the common case. */ #version 450 core @@ -62,7 +76,7 @@ ${layout_declare_ubo(B, "ivec4", "in_sizes")} layout(push_constant) uniform restrict Block { ivec4 kernel_stride; // (Kh, Kw, Sh, Sw) ivec4 padding_dil; // (Ph, Pw, Dh, Dw) - ivec4 dims; // (Cin_padded, _unused, _unused, K4_total) + ivec4 dims; // (Cin_padded, OH_OFFSET, OH_TILE, K4_total) }; #define KERNEL_H kernel_stride.x @@ -74,13 +88,17 @@ layout(push_constant) uniform restrict Block { #define DILATION_H padding_dil.z #define DILATION_W padding_dil.w #define CIN_PADDED dims.x +#define OH_OFFSET dims.y +#define OH_TILE dims.z #define K4_TOTAL dims.w layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const int k4 = int(gl_GlobalInvocationID.x); - const int m = int(gl_GlobalInvocationID.y); + // gl_GlobalInvocationID.y is the tile-local row m_local within this tile's + // OH_TILE * W_out rows; it maps to the global output row via OH_OFFSET. + const int m_local = int(gl_GlobalInvocationID.y); // Derive the spatial output extents from the (refreshed-on-resize) input // sizes UBO so the im2col mapping tracks dynamic input shapes. in_sizes is @@ -92,12 +110,20 @@ void main() { const int H_OUT = (in_sizes.y + 2 * PADDING_H - DILATION_H * (KERNEL_H - 1) - 1) / STRIDE_H + 1; - const int M = H_OUT * W_OUT; + // Rows materialized by this tile (capped to the scratch extent). + const int M_TILE = OH_TILE * W_OUT; - if (k4 >= K4_TOTAL || m >= M) { + if (k4 >= K4_TOTAL || m_local >= M_TILE) { return; } + // Tile-local row m_local -> (oh_local, ow); global output row oh = OH_OFFSET + + // oh_local. Rows past the real H_OUT (in a partial trailing tile, or when a + // dynamic shape shrinks H_OUT below the build-time OH_OFFSET) write zeros. + const int oh_local = m_local / W_OUT; + const int ow = m_local % W_OUT; + const int oh = OH_OFFSET + oh_local; + const int k_start = k4 * 4; // K = (ki * Kw + kj) * Cin_padded + ci ; since Cin_padded % 4 == 0, all 4 @@ -109,24 +135,20 @@ void main() { const int ki = krow_idx / KERNEL_W; const int ci_blk = ci_start >> 2; // ci_start / 4 - // Decompose flat output position m back into (oh, ow). - const int ow = m % W_OUT; - const int oh = m / W_OUT; - // Compute the input spatial position for this (oh, ow, ki, kj). const int ih = oh * STRIDE_H - PADDING_H + ki * DILATION_H; const int iw = ow * STRIDE_W - PADDING_W + kj * DILATION_W; VEC4_T out_texel = VEC4_T(0); - if (ih >= 0 && ih < in_sizes.y && iw >= 0 && iw < in_sizes.x) { + if (oh < H_OUT && ih >= 0 && ih < in_sizes.y && iw >= 0 && iw < in_sizes.x) { out_texel = texelFetch(t_in, ivec3(iw, ih, ci_blk), 0); } #if defined(OUTPUT_BUFFER) - t_out[m * K4_TOTAL + k4] = VEC4_BUF_T(out_texel); + t_out[m_local * K4_TOTAL + k4] = VEC4_BUF_T(out_texel); #elif defined(OUTPUT_TEXTURE3D) - imageStore(t_out, ivec3(ow, oh, k4), out_texel); + imageStore(t_out, ivec3(ow, oh_local, k4), out_texel); #else - imageStore(t_out, ivec2(k4, m), out_texel); + imageStore(t_out, ivec2(k4, m_local), out_texel); #endif } diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.cpp index 21113d79c01..923c7167876 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.cpp @@ -16,12 +16,30 @@ #include #include +#include #include namespace vkcompute { namespace { +// Byte budget for the im2col scratch tensor. The full im2col matrix is +// [M, K_total] = M * K_total * elem bytes (M = H_out * W_out); at high +// resolution this reaches hundreds of MB (e.g. 144 MB FP32 for a +// [1,64,256,256] 3x3 conv), a non-reclaimable device-local allocation resident +// for the whole model lifetime — an OOM risk on memory-constrained mobile GPUs. +// Materializing the im2col in tiles of output-height rows caps the scratch to +// this budget regardless of resolution while preserving GEMM throughput (the +// GEMM inner loop is unchanged; only the live row count is bounded). Tunable: +// a larger budget means fewer tiles / dispatches per conv but more peak memory. +// +// This is a LOGICAL-size budget (oh_tile is derived from W_out * K_total * +// elem bytes). The physical texture2d / texture3d allocation rounds the packed +// dim up to whole texels (vec4) and adds image row / layer alignment, so actual +// device memory for the scratch can modestly exceed this figure. Treat it as a +// soft tuning knob, not a hard allocation ceiling. +constexpr int64_t kIm2colScratchBudgetBytes = 16 * 1024 * 1024; + // // Weight handling // @@ -140,34 +158,60 @@ vkapi::ShaderInfo pick_conv2d_gemm_shader( return VK_KERNEL_FROM_STR(kernel_name); } +// resize_args = { in, weight_data, stride, padding, dilation, oh_tile, +// oh_offset } +// resize_args[5] / [6] carry the raw oh_tile / oh_offset VALUES (not ValueRef +// handles): both are build-time constants, so packing the ints directly into +// the slots avoids materializing graph Values for them. Read with static_cast, +// never get_int. utils::uvec3 pick_conv2d_gemm_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { (void)shader; - (void)resize_args; const ValueRef out = args.at(0).refs.at(0); const uint32_t W = graph->size_at(-1, out); - const uint32_t H = graph->size_at(-2, out); + const uint32_t H_out = graph->size_at(-2, out); const uint32_t C_out = graph->size_at(-3, out); - const uint32_t M = H * W; + const int32_t oh_tile = static_cast(resize_args.at(5)); + const int32_t oh_offset = static_cast(resize_args.at(6)); + // Dead-tile skip: when a runtime down-resize shrinks H_out so this tile's + // first output row (oh_offset) is already past the live H_out, the whole tile + // is dead. Return a zero-sized global wg so DispatchNode::encode() emits zero + // workgroups for it (it explicitly skips any dispatch whose global wg has a 0 + // component) — no GEMM work, no im2col read. trigger_resize() recomputes this + // and re-encodes the command buffer when the dispatch grid changes, so the + // skip tracks the dynamic shape. (The num_tiles count is still fixed at build + // time; this only zeroes the work of tiles that fall off the live region.) + if (oh_offset >= static_cast(H_out)) { + return {0u, 0u, 0u}; + } + // Every live tile dispatches oh_tile output-height rows (oh_tile * W per-tile + // M); trailing threads past the real H_out no-op in the shader. + const uint32_t M_tile = static_cast(oh_tile) * W; const uint32_t N4 = utils::div_up_4(C_out); // TILE_N4=1, TILE_M=4 - return {N4, utils::div_up(M, 4u), 1}; + return {N4, utils::div_up(M_tile, 4u), 1}; } // Recompute the conv output sizes from the current input shape and resize the // output tensor. This is the load-bearing resize for the im2col/GEMM path: // under dynamic shapes the graph is built for the upper-bound input, so on // trigger_resize() the output must be recomputed from the real input or it -// stays frozen at the upper bound (producing garbage downstream). +// stays frozen at the upper bound (producing garbage downstream). Every tile's +// GEMM node shares this resize (each writes a different oh-row window of the +// same full output tensor). // -// The GEMM shader derives M = H_out * W_out and the spatial store coordinates -// from the (now-refreshed) out_sizes UBO, so resizing `out` here is sufficient -// to make the GEMM track the dynamic shape — no push-constant update is needed. +// The GEMM shader derives W_out / H_out and the spatial store coordinates from +// the (now-refreshed) out_sizes UBO, so resizing `out` here is sufficient to +// make every tile track the dynamic shape — no push-constant update is needed +// (oh_offset / oh_tile are shape-independent). The global workgroup picker +// reads oh_tile (resize_args[5]) and oh_offset (resize_args[6]) to size each +// tile's dispatch and to zero-size dead trailing tiles after a down-resize. // -// resize_args = { in, weight_data, stride, padding, dilation } +// resize_args = { in, weight_data, stride, padding, dilation, oh_tile, +// oh_offset } void resize_conv2d_gemm_node( ComputeGraph* graph, const std::vector& args, @@ -220,15 +264,31 @@ void add_conv2d_gemm_node( const int32_t K_total, const bool clamp_out, const float out_min_val, - const float out_max_val) { + const float out_max_val, + const int32_t oh_offset, + const int32_t oh_tile) { const int32_t K4_total = K_total / 4; - // gemm_dims carries only the shape-independent K4_total. M is derived in the - // shader from the refreshed out_sizes UBO, so it is not baked here (a baked - // plain-data push constant cannot be updated on resize). - const utils::ivec4 gemm_dims{K4_total, 0, 0, 0}; + // gemm_dims = (K4_total, oh_offset, oh_tile, _unused). All shape-independent: + // this tile reads scratch rows for oh_tile output-height rows and writes the + // output rows starting at oh_offset. W_out / H_out are derived in the shader + // from the refreshed out_sizes UBO (a baked plain-data push constant cannot + // be updated on resize, but oh_offset / oh_tile never change with shape). + const utils::ivec4 gemm_dims{K4_total, oh_offset, oh_tile, 0}; const utils::vec4 clamp_vals{out_min_val, out_max_val, 0.0f, 0.0f}; + // The last two resize_args slots carry the raw oh_tile / oh_offset VALUES, + // not ValueRef handles: both are build-time constants, so packing the ints + // directly avoids materializing graph Values for them. The global-wg picker + // reads them back with static_cast (never get_int) — oh_tile to size the + // dispatch, oh_offset to zero-size a dead trailing tile after a down-resize. + // ExecuteNode's resize dirty-tracker treats these slots as value indices: if + // a packed int happens to collide with a real ValueList index, + // was_value_updated can RECURSE through toConstValueList() to walk that + // list's members, so the spurious over-trigger may be a deeper walk than a + // single lookup. Still memory-safe and read-only (was_value_updated guards + // out-of-range and in-range idx alike) and benign (the resize recomputes + // correctly regardless). graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, pick_conv2d_gemm_shader, @@ -245,8 +305,15 @@ void add_conv2d_gemm_node( // Specialization constants // activation_type: 0=none, 1=relu, 2=clamp {clamp_out ? 2 : 0}, - // Resize args - {in, weight_data, stride, padding, dilation}, + // Resize args (last two slots = raw oh_tile / oh_offset values, see note + // above) + {in, + weight_data, + stride, + padding, + dilation, + static_cast(oh_tile), + static_cast(oh_offset)}, // Resizing logic resize_conv2d_gemm_node)); } @@ -303,16 +370,50 @@ void conv2d_gemm_impl( dilation_w = utils::safe_downcast(dilation_list->at(1)); } - const int64_t M = H_out * W_out; const int64_t K4_total = K_total / 4; + // Tile the im2col by output-height rows so the scratch is bounded to the + // fixed kIm2colScratchBudgetBytes byte budget regardless of resolution. One + // H-row of im2col is W_out * K_total * elem bytes; oh_tile is the most H-rows + // whose im2col fits the budget (>= 1), clamped to H_out. The full conv is + // then materialized in num_tiles = ceil(H_out / oh_tile) tiles. When oh_tile + // >= H_out this reduces to a single untiled dispatch pair. (Why a fixed + // num_tiles is safe under dynamic shapes is documented at the dispatch loop + // below.) + // + // Computed BEFORE storage selection because the chosen storage gates on the + // TILED scratch extent (oh_tile * W_out rows, oh_tile deep), not the full M / + // H_out. oh_tile depends only on W_out, K_total, and elem_size (the input + // dtype) — never on the storage type — so there is no circular dependency in + // ordering it first. + const int64_t elem_size = + utils::safe_downcast(vkapi::element_size(graph.dtype_of(in))); + const int64_t bytes_per_h_row = W_out * K_total * elem_size; + int64_t oh_tile = kIm2colScratchBudgetBytes / bytes_per_h_row; + oh_tile = std::max(oh_tile, 1); + oh_tile = std::min(oh_tile, H_out); + const int64_t num_tiles = utils::div_up(H_out, oh_tile); + + // The per-tile scratch holds only oh_tile output-height rows, so its extents + // are M_tile = oh_tile * W_out (texture2d / buffer) or oh_tile-deep + // (texture3d), not the full M / H_out. With the budget capping oh_tile, the + // tiled extent rarely exceeds max_texture2d_dim, so texture2d is selected in + // the common case. + + const int64_t M_tile = oh_tile * W_out; + + // oh_tile reaches the resize fn / wg pickers as the raw int packed into the + // last resize_args slot (see add_conv2d_*_node) — no materialized graph + // Value. The scratch's W_out-dependent extent tracks dynamic shapes while + // oh_tile stays fixed (it is a build-time constant). + // Pick im2col storage. When an explicit override is provided (test-only), // honor it and skip auto-selection. Otherwise run the production // auto-selection per device: // - Mali: always buffer (texture sampling on Mali is comparatively slow). - // - Others: prefer texture2d (M × K4_total). If that doesn't fit the - // device's max texture2d dim, fall back to texture3d laid out as - // (W_out, H_out, K4_total). Buffer is the last-resort fallback. + // - Others: prefer texture2d (M_tile × K4_total). If the tiled extent + // doesn't fit the device's max texture2d dim, fall back to texture3d laid + // out as (W_out, oh_tile, K4_total). Buffer is the last-resort fallback. utils::StorageType im2col_storage; if (im2col_storage_override.has_value()) { im2col_storage = im2col_storage_override.value(); @@ -326,9 +427,9 @@ void conv2d_gemm_impl( const uint32_t max_2d = graph.context()->adapter_ptr()->max_texture2d_dim(); const uint32_t max_3d = graph.context()->adapter_ptr()->max_texture3d_dim(); const bool fits_2d = utils::safe_downcast(K4_total) <= max_2d && - utils::safe_downcast(M) <= max_2d; + utils::safe_downcast(M_tile) <= max_2d; const bool fits_3d = utils::safe_downcast(W_out) <= max_3d && - utils::safe_downcast(H_out) <= max_3d && + utils::safe_downcast(oh_tile) <= max_3d && utils::safe_downcast(K4_total) <= max_3d; if (fits_2d) { im2col_storage = utils::kTexture2D; @@ -339,29 +440,30 @@ void conv2d_gemm_impl( } } - // Allocate the im2col intermediate as a scoped scratch tensor. The im2col - // value is produced by the im2col node and consumed immediately by the GEMM - // node, both below, and is dead afterwards. Using a TmpTensor lets the memory - // planner alias one backing buffer across the (non-overlapping) im2col - // lifetimes of every conv2d layer, so peak memory tracks the largest single - // im2col rather than the sum of all of them. The TmpTensor must outlive - // add_conv2d_gemm_node (its last consumer), so it lives to the end of this - // function. + // Allocate ONE im2col scratch tensor sized for a single tile (oh_tile rows), + // reused across all tiles. The im2col value is produced by each tile's im2col + // node and consumed immediately by that tile's GEMM node; reusing the same + // TmpTensor across tiles serializes them via the backend's automatic + // read/write barriers (tile t's GEMM finishes reading before tile t+1's + // im2col overwrites). Using a TmpTensor also lets the memory planner alias + // one backing buffer across the non-overlapping im2col lifetimes of every + // conv2d layer, so peak memory tracks the largest single tile's scratch (<= + // budget) rather than the sum. The TmpTensor must outlive the last GEMM node, + // so it lives to the end of this function. // - // The 2D and buffer variants use a flat [M, K_total] kWidthPacked shape; the - // texture3d variant uses the natural [1, K_total, H_out, W_out] - // kChannelsPacked shape so K4 lays along Z. Hoist the per-storage differences - // into locals so the TmpTensor is constructed exactly once and never needs to - // be copied or moved. + // The 2D and buffer variants use a flat [oh_tile * W_out, K_total] + // kWidthPacked shape; the texture3d variant uses [1, K_total, oh_tile, W_out] + // kChannelsPacked so K4 lays along Z. Hoist the per-storage differences into + // locals so the TmpTensor is constructed exactly once and never copied/moved. std::vector im2col_sizes; utils::StorageType im2col_tmp_storage; utils::GPUMemoryLayout im2col_layout; if (im2col_storage == utils::kTexture3D) { - im2col_sizes = {1, K_total, H_out, W_out}; + im2col_sizes = {1, K_total, oh_tile, W_out}; im2col_tmp_storage = utils::kTexture3D; im2col_layout = utils::kChannelsPacked; } else { - im2col_sizes = {M, K_total}; + im2col_sizes = {oh_tile * W_out, K_total}; im2col_tmp_storage = im2col_storage; im2col_layout = utils::kWidthPacked; } @@ -373,30 +475,11 @@ void conv2d_gemm_impl( im2col_layout); const ValueRef im2col_tensor = im2col_tmp.vref; - // Step 1: im2col - add_conv2d_im2col_node( - graph, - in, - im2col_tensor, - weight_data, - stride, - padding, - dilation, - utils::safe_downcast(K_h), - utils::safe_downcast(K_w), - stride_h, - stride_w, - padding_h, - padding_w, - dilation_h, - dilation_w, - utils::safe_downcast(Cin_padded)); - - // Step 2: prepack weight for the GEMM directly from the serialized - // [C_out, C_in, K_h, K_w] weight on the GPU. The serialized data is read - // as-is (never CPU-repacked); the prepack shader does the im2col K-axis - // reorder + 4x4 transpose into the layout conv2d_gemm.glsl loads via - // load_packed_weight_tile_with_checks. + // Prepack weight for the GEMM directly from the serialized + // [C_out, C_in, K_h, K_w] weight on the GPU (shared across all tiles). The + // serialized data is read as-is (never CPU-repacked); the prepack shader does + // the im2col K-axis reorder + 4x4 transpose into the layout conv2d_gemm.glsl + // loads via load_packed_weight_tile_with_checks. ValueRef packed_weight = prepack_conv2d_gemm_weight(graph, weight_data); // Bias prepack: matches the bias format conv2d_gemm expects. prepack_biases @@ -412,22 +495,74 @@ void conv2d_gemm_impl( check_conv_args(graph, in, out); - // Step 3: GEMM - add_conv2d_gemm_node( - graph, - in, - weight_data, - stride, - padding, - dilation, - im2col_tensor, - packed_weight, - packed_bias, - out, - utils::safe_downcast(K_total), - clamp_out, - out_min_val, - out_max_val); + // Emit one (im2col -> GEMM) dispatch pair per tile, interleaved so each + // tile's GEMM reads the scratch its im2col just wrote before the next tile + // overwrites it. + // + // num_tiles (and oh_tile) are fixed here at graph-build time: the per-tile + // dispatch count must be static (DynamicDispatchNode does not add/remove + // nodes on resize). This is correct ONLY because ET-VK builds these tensors + // at the dynamic UPPER BOUND, so trigger_resize() can only shrink H_out/W_out + // (runtime <= build-time). Three consequences make the fixed tiling safe: + // - num_tiles, from the build-time (max) H_out, always covers the runtime + // row count; + // - a smaller runtime H_out just leaves trailing tiles (oh_offset >= the + // current H_out) with no live output rows; the GEMM global-wg picker + // zero-sizes the dispatch for such a tile (DispatchNode::encode() skips a + // 0-component global wg), so a dead trailing tile costs no GEMM work + // after a down-resize. (Its im2col node still dispatches — a cheap + // per-thread gather that writes zeros into the unused scratch rows; the + // static node-count constraint means the im2col node cannot be removed, + // only the dominant GEMM work is elided.) + // - the scratch, sized from the build-time (max) shape, is an upper bound, + // so memory stays capped with no reallocation on resize. + // Load-bearing assumption: if a runtime shape ever EXCEEDED the build-time + // bound, this fixed num_tiles would under-cover and silently drop the extra + // output rows. ET-VK's upper-bound build guarantees this cannot happen; a + // future change that breaks the upper-bound invariant would need a runtime + // tile-count guard here. + for (int64_t t = 0; t < num_tiles; ++t) { + const int32_t oh_offset = utils::safe_downcast(t * oh_tile); + const int32_t oh_tile_i32 = utils::safe_downcast(oh_tile); + + add_conv2d_im2col_node( + graph, + in, + im2col_tensor, + weight_data, + stride, + padding, + dilation, + utils::safe_downcast(K_h), + utils::safe_downcast(K_w), + stride_h, + stride_w, + padding_h, + padding_w, + dilation_h, + dilation_w, + utils::safe_downcast(Cin_padded), + oh_offset, + oh_tile_i32); + + add_conv2d_gemm_node( + graph, + in, + weight_data, + stride, + padding, + dilation, + im2col_tensor, + packed_weight, + packed_bias, + out, + utils::safe_downcast(K_total), + clamp_out, + out_min_val, + out_max_val, + oh_offset, + oh_tile_i32); + } } // diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h b/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h index b0a273b51f4..fc82edb3647 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2dGemm.h @@ -21,28 +21,32 @@ namespace vkcompute { * Dataflow (input t_in [1, C_in, H_in, W_in] -> output t_out [1, C_out, H_out, * W_out]): * - * 1. im2col: add_conv2d_im2col_node (conv2d_im2col.glsl) expands t_in into an - * im2col matrix. The K (reduction) axis is laid out as - * K = (ki * Kw + kj) * Cin_padded + ci, with K_total = Kh * Kw * - * align_up_4(C_in). The im2col storage type selects its shape (see - * conv2d_gemm_impl): - * - buffer / texture2d: flat [M, K_total], width-packed, with - * M = H_out * W_out. - * - texture3d: [1, K_total, H_out, W_out], channels-packed, so the K4 + * The im2col matrix is materialized and consumed in tiles of output-height rows + * (see kIm2colScratchBudgetBytes in conv2d_gemm_impl): each tile covers oh_tile + * output-height rows, and the conv is computed in ceil(H_out / oh_tile) tiles. + * This bounds the im2col scratch to a fixed byte budget regardless of + * resolution. Per tile: + * 1. im2col: add_conv2d_im2col_node (conv2d_im2col.glsl) expands t_in's + * oh_tile-row window into the im2col scratch. The K (reduction) axis is + * laid out as K = (ki * Kw + kj) * Cin_padded + ci, with K_total = Kh * Kw + * * align_up_4(C_in). The im2col storage type selects the scratch shape: + * - buffer / texture2d: flat [oh_tile * W_out, K_total], width-packed. + * - texture3d: [1, K_total, oh_tile, W_out], channels-packed, so the K4 * tiles lay along Z. - * 2. GEMM: add_conv2d_gemm_node (conv2d_gemm.glsl) multiplies the im2col - * matrix by the packed weight [C_out, K_total] to produce t_out. The - * packed weight is prepacked on the GPU directly from the serialized - * [C_out, C_in, Kh, Kw] weight (no CPU repack), applying the im2col K-axis - * decode plus a 4OC x 4IC blocked transpose. + * 2. GEMM: add_conv2d_gemm_node (conv2d_gemm.glsl) multiplies the im2col tile + * by the packed weight [C_out, K_total] and writes that tile's output rows + * of t_out. The packed weight is prepacked on the GPU directly from the + * serialized [C_out, C_in, Kh, Kw] weight (no CPU repack), applying the + * im2col K-axis decode plus a 4OC x 4IC blocked transpose. * - * This function performs both dispatch and prepack registration. The im2col - * intermediate is allocated as a scoped TmpTensor scratch tensor, so the memory - * planner can alias one backing buffer across the non-overlapping im2col - * lifetimes of every conv2d layer (peak memory tracks the largest single im2col - * rather than the sum). The packed weight is produced by a GPU prepack node - * (PrepackNode running pack_conv2d_gemm_weight.glsl) from the serialized - * weight. + * This function performs both dispatch and prepack registration. A single + * im2col scratch TmpTensor (sized for one tile) is reused across all tiles — + * reusing it serializes the tiles via the backend's read/write barriers and + * lets the memory planner alias one backing buffer across the non-overlapping + * im2col lifetimes of every conv2d layer (peak memory tracks the largest single + * tile's scratch — at most the budget — rather than the sum). The packed weight + * (shared across tiles) is produced once by a GPU prepack node (PrepackNode + * running pack_conv2d_gemm_weight.glsl) from the serialized weight. * * Constraints (asserted internally): * - input batch == 1 diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.cpp index 0da72607334..062c85c9edd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.cpp @@ -15,14 +15,17 @@ namespace vkcompute { namespace { -// Compute the im2col output extents (M = H_out * W_out, K4_total) from the -// im2col_out tensor's current sizes. The tensor is virtually resized on +// Compute the im2col scratch extents (M_tile = OH_tile * W_out, K4_total) from +// the im2col_out tensor's current sizes. The tensor is virtually resized on // trigger_resize (see resize_conv2d_im2col_node), so reading from it tracks -// dynamic shapes. +// dynamic shapes. With H-row tiling the scratch holds OH_tile (not H_out) +// output-height rows, so M here is OH_tile * W_out, the per-tile dispatch +// extent (every tile dispatches the same OH_tile rows; trailing threads past +// the real H_out no-op in the shader). // // Two layouts are possible: -// - flat [M, K_total] (buffer / texture2d) -// - [1, K_total, H_out, W_out] (texture3d) +// - flat [OH_tile * W_out, K_total] (buffer / texture2d) +// - [1, K_total, OH_tile, W_out] (texture3d) struct Im2colExtents { uint32_t m; uint32_t k4_total; @@ -33,13 +36,13 @@ Im2colExtents im2col_extents_of(ComputeGraph* graph, const ValueRef im2col) { uint32_t m; uint32_t k_total; if (sizes.size() == 4) { - // texture3d [1, K_total, H_out, W_out] - const int64_t h_out = sizes.at(2); + // texture3d [1, K_total, OH_tile, W_out] + const int64_t oh_tile = sizes.at(2); const int64_t w_out = sizes.at(3); - m = utils::safe_downcast(h_out * w_out); + m = utils::safe_downcast(oh_tile * w_out); k_total = utils::safe_downcast(sizes.at(1)); } else { - // flat [M, K_total] + // flat [OH_tile * W_out, K_total] m = utils::safe_downcast(sizes.at(0)); k_total = utils::safe_downcast(sizes.at(1)); } @@ -75,14 +78,22 @@ utils::uvec3 pick_conv2d_im2col_local_wg_size( return {16u, 4u, 1u}; } -// Recompute the im2col output spatial extents from the current input shape and -// virtually resize the im2col tensor. Both possible layouts must be handled: -// - flat [M, K_total] -> resize dim 0 (M = H_out * W_out) -// - [1, K_total, H_out, W_out] -> resize dims 2/3 (H_out, W_out) -// K_total / Cin_padded are shape-independent, so the K dimension is preserved -// from the current tensor sizes. +// Recompute the im2col scratch extents from the current input shape and +// virtually resize the im2col tensor. The scratch holds OH_tile output-height +// rows (not H_out) — H-row tiling bounds it to a fixed byte budget. Only the +// W_out-dependent extent tracks the dynamic shape; OH_tile is a build-time +// constant (the row capacity). Both layouts: +// - flat [OH_tile * W_out, K_total] -> resize dim 0 (= OH_tile * W_out) +// - [1, K_total, OH_tile, W_out] -> resize dim 3 (= W_out); dim 2 fixed +// K_total / Cin_padded / OH_tile are shape-independent. // -// resize_args = { in, weight_data, stride, padding, dilation } +// resize_args = { in, weight_data, stride, padding, dilation, oh_tile } +// resize_args[5] carries the raw oh_tile VALUE (not a ValueRef handle): oh_tile +// is a build-time constant, so packing the int directly into the slot avoids +// materializing a graph Value for it. Read it with static_cast, never get_int. +// ExecuteNode's resize dirty-tracker treats this slot as a value index and may +// spuriously over-trigger this resize fn — deliberate and benign (it recomputes +// correctly regardless) and memory-safe (was_value_updated guards the lookup). void resize_conv2d_im2col_node( ComputeGraph* graph, const std::vector& args, @@ -93,13 +104,14 @@ void resize_conv2d_im2col_node( const ValueRef stride = resize_args.at(2); const ValueRef padding = resize_args.at(3); const ValueRef dilation = resize_args.at(4); + const int64_t oh_tile = static_cast(resize_args.at(5)); const std::vector in_sizes = graph->sizes_of(in); - // Height / Width from the current input, via the shared conv-output helper - // (same H/W split + formula the direct-conv resize uses). kernel_size is read - // from the weight dims; stride/padding/dilation from the original IntList - // ValueRefs. All are shape-independent — only H_in / W_in change at runtime. + // Width from the current input, via the shared conv-output helper (same H/W + // split + formula the direct-conv resize uses). kernel_size is read from the + // weight dims; stride/padding/dilation from the original IntList ValueRefs. + // All are shape-independent — only H_in / W_in change at runtime. // transposed=false, and the args[3] slot (consulted only as an optional // ceil_mode) is a non-bool ValueRef, so ceil_mode resolves to false. const std::vector out_hw = calc_out_sizes_hw( @@ -109,18 +121,17 @@ void resize_conv2d_im2col_node( /*kernel_size_only=*/false, {stride, padding, dilation, dilation}, /*transposed=*/false); - const int64_t H_out = out_hw.at(0); const int64_t W_out = out_hw.at(1); const std::vector cur_sizes = graph->sizes_of(im2col_out); std::vector new_sizes = cur_sizes; if (cur_sizes.size() == 4) { - // texture3d [1, K_total, H_out, W_out]: K_total (dim 1) is preserved. - new_sizes.at(2) = H_out; + // texture3d [1, K_total, OH_tile, W_out]: K_total (dim 1) and OH_tile + // (dim 2) are preserved; only W_out tracks the dynamic shape. new_sizes.at(3) = W_out; } else { - // flat [M, K_total]: K_total (dim 1) is preserved. - new_sizes.at(0) = H_out * W_out; + // flat [OH_tile * W_out, K_total]: K_total (dim 1) is preserved. + new_sizes.at(0) = oh_tile * W_out; } graph->virtual_resize(im2col_out, new_sizes); } @@ -130,8 +141,17 @@ void resize_conv2d_im2col_node( // Push constants are uploaded in 16-byte chunks (one ivec4 each) to comply // with the per-entry size limit. Layout matches conv2d_im2col.glsl: // { ivec4 kernel_stride, ivec4 padding_dil, ivec4 dims } -// All fields are shape-independent; W_out / H_out / M are derived in the shader -// from the (resize-refreshed) in_sizes UBO. +// dims carries the per-tile output-height window (oh_offset, oh_tile); W_out / +// H_out are derived in the shader from the (resize-refreshed) in_sizes UBO. All +// dims fields are shape-independent (oh_offset / oh_tile are build-time tile +// constants). +// +// `oh_offset` / `oh_tile` define which output-height rows this dispatch +// materializes: the scratch holds oh_tile rows, written from source rows +// [oh_offset, oh_offset + oh_tile). oh_tile is also packed (as a raw int value, +// not a ValueRef handle) into the last resize_args slot for the resize fn — see +// the resize_conv2d_im2col_node note for the rationale and the benign +// dirty-tracker over-trigger this implies. void add_conv2d_im2col_node( ComputeGraph& graph, @@ -149,7 +169,9 @@ void add_conv2d_im2col_node( const int32_t padding_w, const int32_t dilation_h, const int32_t dilation_w, - const int32_t Cin_padded) { + const int32_t Cin_padded, + const int32_t oh_offset, + const int32_t oh_tile) { const utils::StorageType out_storage = graph.storage_type_of(im2col_out); VK_CHECK_COND( out_storage == utils::kBuffer || out_storage == utils::kTexture2D || @@ -168,10 +190,9 @@ void add_conv2d_im2col_node( const utils::ivec4 kernel_stride{kernel_h, kernel_w, stride_h, stride_w}; const utils::ivec4 padding_dil{padding_h, padding_w, dilation_h, dilation_w}; - // dims.y / dims.z (formerly W_out / H_out) are unused by the shader now — - // the spatial extents are derived at runtime from in_sizes. Only Cin_padded - // and K4_total (both shape-independent) are consumed. - const utils::ivec4 dims{Cin_padded, 0, 0, K4_total}; + // dims = (Cin_padded, oh_offset, oh_tile, K4_total). W_out / H_out are + // derived at runtime from in_sizes; the rest are shape-independent. + const utils::ivec4 dims{Cin_padded, oh_offset, oh_tile, K4_total}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -188,8 +209,13 @@ void add_conv2d_im2col_node( PushConstantDataInfo(&dims, sizeof(dims))}, // Specialization constants {}, - // Resize args - {in, weight_data, stride, padding, dilation}, + // Resize args (last slot = raw oh_tile value, not a ValueRef handle) + {in, + weight_data, + stride, + padding, + dilation, + static_cast(oh_tile)}, // Resizing logic resize_conv2d_im2col_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.h b/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.h index 1f81c29d1e1..89546c72217 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.h +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2dIm2Col.h @@ -13,36 +13,42 @@ namespace vkcompute { /* - * Dispatch a single im2col transformation node for an FP32 / FP16 conv2d. + * Dispatch a single im2col transformation node for one H-row tile of an + * FP32 / FP16 conv2d. * - * Produces a 2D tensor of logical shape - * [M, K_total] - * where - * M = H_out * W_out + * Materializes a 2D tensor of logical shape + * [oh_tile * W_out, K_total] + * holding `oh_tile` output-height rows starting at output-height row + * `oh_offset`, where * K_total = kernel_h * kernel_w * align_up_4(C_in) * + * Tiling by output-height rows bounds the scratch to a fixed byte budget + * regardless of resolution (the full conv covers H_out rows across multiple + * such dispatches); `oh_offset = 0`, `oh_tile = H_out` reproduces the untiled + * single-dispatch case. + * * The K dimension is laid out so that consecutive 4-tiles of K hold 4 * consecutive ci values for the same (ki, kj) kernel position. This is the * layout `conv2d_gemm` consumes for the GEMM step. * - * The im2col output tensor's storage type (texture2d width-packed or - * buffer) is determined by the caller; this function picks the matching - * shader variant based on `graph.storage_type_of(im2col_out)`. + * The im2col output tensor's storage type (texture2d width-packed, buffer, or + * texture3d channels-packed) is determined by the caller; this function picks + * the matching shader variant based on `graph.storage_type_of(im2col_out)`. * - * Dynamic shapes: the spatial output extents (W_out / H_out / M) are derived in - * the shader from the refreshed in_sizes UBO, and the im2col_out tensor is - * virtually resized on every trigger_resize() from the current input shape, so - * this node tracks dynamic input shapes. Cin_padded / K4_total are - * shape-independent and remain baked into the push constant. `stride`, - * `padding`, `dilation` are the original graph ValueRefs (used by the resize - * function to recompute output extents); `weight_data` is the original 4D - * weight (used only for its kernel dims during resize). + * Dynamic shapes: W_out / H_out are derived in the shader from the refreshed + * in_sizes UBO, and the im2col_out tensor is virtually resized on every + * trigger_resize() (its W_out-dependent extent tracks the current input shape; + * oh_tile is fixed). Cin_padded / K4_total / oh_offset / oh_tile are + * shape-independent and baked into the push constant. `stride`, `padding`, + * `dilation` are the original graph ValueRefs (used by the resize function); + * `weight_data` is the original 4D weight (used only for its kernel dims during + * resize). * * Inputs: * in : input texture3D channels-packed [1, C_in, H_in, W_in] - * im2col_out : output tensor (caller allocates), [M, K_total] for - * buffer/texture2d (kWidthPacked) or [1, K_total, H_out, W_out] - * for texture3d (kChannelsPacked) + * im2col_out : scratch tensor (caller allocates), [oh_tile * W_out, K_total] + * for buffer/texture2d (kWidthPacked) or + * [1, K_total, oh_tile, W_out] for texture3d (kChannelsPacked) * weight_data : original [C_out, C_in, kernel_h, kernel_w] weight * stride/padding/dilation : original conv param IntList ValueRefs * kernel_h/w : conv kernel dimensions @@ -50,6 +56,10 @@ namespace vkcompute { * padding_* : conv paddings * dilation_* : conv dilations * Cin_padded : align_up_4(C_in) + * oh_offset : first output-height row this tile materializes + * oh_tile : number of output-height rows in this tile (scratch capacity); + * also packed as a raw int into the last resize_args slot for + * the resize fn (not a materialized ValueRef handle) */ void add_conv2d_im2col_node( ComputeGraph& graph, @@ -67,6 +77,8 @@ void add_conv2d_im2col_node( const int32_t padding_w, const int32_t dilation_h, const int32_t dilation_w, - const int32_t Cin_padded); + const int32_t Cin_padded, + const int32_t oh_offset, + const int32_t oh_tile); } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/conv2d_gemm_dynamic_test.cpp b/backends/vulkan/test/op_tests/conv2d_gemm_dynamic_test.cpp new file mode 100644 index 00000000000..668ed7fc83c --- /dev/null +++ b/backends/vulkan/test/op_tests/conv2d_gemm_dynamic_test.cpp @@ -0,0 +1,541 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using torch::executor::testing::TensorFactory; +using namespace vkcompute; + +// +// Dynamic-shape (resize) test for the tiled im2col conv2d path. +// +// The im2col + GEMM conv path (conv2d_gemm_impl) materializes its im2col matrix +// in tiles of output-height rows bounded to a byte budget. The static op-test +// suite (test/custom_ops/test_conv2d.cpp) only validates a single shape per +// graph. This test exercises the path under an actual trigger_resize: it builds +// the graph at an upper-bound input shape that forces MULTIPLE tiles, then +// resizes the input across tile boundaries (so trailing tiles must no-op via +// the shader's `oh < H_out` guard and the scratch must track the smaller shape) +// and verifies the output against a reference at every shape. +// +// Build-time upper bound: storage and the fixed num_tiles are built at the +// initial input shape, so every resized shape MUST be <= the initial shape per +// dim (resize-down is the supported direction; resizing back up toward — never +// above — the bound is fine). The harness asserts this so misuse fails loudly. +// +// Multi-tiling is forced by a real shape at the production 16 MB budget (no +// test knob): C_in=64, 3x3, 128x128 gives K_total = 9*64 = 576, so one +// output-height row of im2col is W_out * K_total * elem = 128 * 576 * 4 = 288 +// KB; oh_tile = 16 MB / 288 KB = 56 rows, and H_out=128 needs ceil(128/56) = 3 +// tiles. The resize sweep (H = 64 / 56 / 112 / 128) needs 2 / 1 / 2 / 3 active +// tiles, crossing tile boundaries down to a single tile and back up, leaving +// surplus build-time tiles to no-op. The reference is computed by XNNPACK so +// this medium shape stays cheap. +// +// The test sweeps the full io_storage x im2col_storage matrix (3 x 3 = 9 +// combos): io_storage = the input/output tensor storage; im2col_storage = the +// scratch storage forced into conv2d_gemm_impl. + +namespace { + +// +// Test specification +// + +struct Conv2dTestConfig { + // Conv params. + int64_t in_channels; + int64_t out_channels; + int64_t kernel_h; + int64_t kernel_w; + int64_t stride_h; + int64_t stride_w; + int64_t padding_h; + int64_t padding_w; + int64_t dilation_h; + int64_t dilation_w; + int64_t groups; // only groups == 1 is supported by conv2d_gemm_impl + bool has_bias; + + // Initial (BUILD-time, upper-bound) input spatial extents. The input is + // always [1, in_channels, init_h, init_w]; weight is + // [out_channels, in_channels, kernel_h, kernel_w]. + int64_t init_h; + int64_t init_w; + + // Storage type for the conv INPUT and OUTPUT tensors. + utils::StorageType io_storage; + + // Storage type for the im2col SCRATCH tensor (distinct from io_storage). + utils::StorageType im2col_storage; + + // Resized input (h, w) shapes to sweep after the initial build. Each MUST be + // <= (init_h, init_w) per dim (build-time upper bound). + std::vector> resize_hw; +}; + +const char* storage_type_name(utils::StorageType st) { + switch (st) { + case utils::kBuffer: + return "buffer"; + case utils::kTexture2D: + return "texture2d"; + case utils::kTexture3D: + return "texture3d"; + default: + return "unknown"; + } +} + +// Human-readable dump of a Conv2dTestConfig, used both as per-test SCOPED_TRACE +// context and in per-shape mismatch messages so a failing run names the exact +// config + storage that failed. +std::string to_string(const Conv2dTestConfig& cfg) { + std::ostringstream os; + os << "Conv2dTestConfig{in_ch=" << cfg.in_channels + << " out_ch=" << cfg.out_channels << " kernel=" << cfg.kernel_h << "x" + << cfg.kernel_w << " stride=" << cfg.stride_h << "x" << cfg.stride_w + << " pad=" << cfg.padding_h << "x" << cfg.padding_w + << " dilation=" << cfg.dilation_h << "x" << cfg.dilation_w + << " groups=" << cfg.groups << " has_bias=" << (cfg.has_bias ? "1" : "0") + << " init=" << cfg.init_h << "x" << cfg.init_w + << " io_storage=" << storage_type_name(cfg.io_storage) + << " im2col_storage=" << storage_type_name(cfg.im2col_storage) + << " resize_hw=["; + for (size_t i = 0; i < cfg.resize_hw.size(); ++i) { + if (i > 0) { + os << ", "; + } + os << cfg.resize_hw[i].first << "x" << cfg.resize_hw[i].second; + } + os << "]}"; + return os.str(); +} + +int64_t conv_out_dim( + int64_t in, + int64_t kernel, + int64_t stride, + int64_t padding, + int64_t dilation) { + return (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1; +} + +std::vector rand_floats(size_t n, unsigned seed) { + std::mt19937 gen(seed); + std::uniform_real_distribution dist(-1.0f, 1.0f); + std::vector data(n); + std::generate(data.begin(), data.end(), [&]() { return dist(gen); }); + return data; +} + +size_t numel(const std::vector& sizes) { + size_t n = 1; + for (auto s : sizes) { + n *= static_cast(s); + } + return n; +} + +std::vector to_int32(const std::vector& v) { + return std::vector(v.begin(), v.end()); +} + +// +// Reference: XNNPACK f32 conv2d (fast). XNNPACK is NHWC with OHWI weights; the +// graph I/O is NCHW (channels-packed), so this converts in -> NHWC and weight +// -> OHWI before the op and the output NHWC -> NCHW after, returning an +// NCHW-order reference to compare against the staging read-back. +// +// (A naive nested-loop reference was the reason large shapes couldn't be +// tested; XNNPACK runs the reference fast enough that the build shape can be a +// true upper bound without the CPU reference dominating runtime.) +// +std::vector conv2d_ref_xnnpack( + const std::vector& input_nchw, // [C_in, H, W] + const std::vector& weight_nchw, // [C_out, C_in, K_h, K_w] + const std::vector& bias, // [C_out] or empty + const Conv2dTestConfig& cfg, + int64_t H_in, + int64_t W_in) { + const int64_t C_in = cfg.in_channels; + const int64_t C_out = cfg.out_channels; + const int64_t K_h = cfg.kernel_h; + const int64_t K_w = cfg.kernel_w; + const int64_t H_out = + conv_out_dim(H_in, K_h, cfg.stride_h, cfg.padding_h, cfg.dilation_h); + const int64_t W_out = + conv_out_dim(W_in, K_w, cfg.stride_w, cfg.padding_w, cfg.dilation_w); + + // NCHW -> NHWC input. + std::vector input_nhwc(static_cast(C_in * H_in * W_in)); + for (int64_t c = 0; c < C_in; ++c) { + for (int64_t h = 0; h < H_in; ++h) { + for (int64_t w = 0; w < W_in; ++w) { + input_nhwc[static_cast((h * W_in + w) * C_in + c)] = + input_nchw[static_cast(c * (H_in * W_in) + h * W_in + w)]; + } + } + } + + // NCHW -> OHWI weight ([C_out, C_in, K_h, K_w] -> [C_out, K_h, K_w, C_in]). + std::vector weight_ohwi(static_cast(C_out * K_h * K_w * C_in)); + for (int64_t co = 0; co < C_out; ++co) { + for (int64_t ci = 0; ci < C_in; ++ci) { + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + weight_ohwi[static_cast( + ((co * K_h + kh) * K_w + kw) * C_in + ci)] = + weight_nchw[static_cast( + ((co * C_in + ci) * K_h + kh) * K_w + kw)]; + } + } + } + } + + EXPECT_EQ(xnn_initialize(/*allocator=*/nullptr), xnn_status_success); + + // XNNPACK failures throw rather than return a zeroed result: a silent zero + // output would masquerade as a Vulkan-vs-reference mismatch downstream. The + // operator handle is deleted before throwing so it does not leak on the + // error path. + xnn_operator_t op = nullptr; + const float out_min = -std::numeric_limits::infinity(); + const float out_max = std::numeric_limits::infinity(); + const xnn_status create_status = xnn_create_convolution2d_nhwc_f32( + static_cast(cfg.padding_h), + static_cast(cfg.padding_w), + static_cast(cfg.padding_h), + static_cast(cfg.padding_w), + static_cast(K_h), + static_cast(K_w), + static_cast(cfg.stride_h), + static_cast(cfg.stride_w), + static_cast(cfg.dilation_h), + static_cast(cfg.dilation_w), + /*groups=*/1, + /*group_input_channels=*/static_cast(C_in), + /*group_output_channels=*/static_cast(C_out), + /*input_channel_stride=*/static_cast(C_in), + /*output_channel_stride=*/static_cast(C_out), + weight_ohwi.data(), + bias.empty() ? nullptr : bias.data(), + out_min, + out_max, + /*flags=*/0, + /*code_cache=*/nullptr, + /*weights_cache=*/nullptr, + &op); + if (create_status != xnn_status_success || op == nullptr) { + xnn_delete_operator(op); + throw std::runtime_error( + "xnn_create_convolution2d_nhwc_f32 failed with status " + + std::to_string(static_cast(create_status))); + } + + size_t workspace_size = 0; + size_t workspace_alignment = 0; + size_t out_h = 0; + size_t out_w = 0; + const xnn_status reshape_status = xnn_reshape_convolution2d_nhwc_f32( + op, + /*batch_size=*/1, + static_cast(H_in), + static_cast(W_in), + &workspace_size, + &workspace_alignment, + &out_h, + &out_w, + /*threadpool=*/nullptr); + if (reshape_status != xnn_status_success) { + xnn_delete_operator(op); + throw std::runtime_error( + "xnn_reshape_convolution2d_nhwc_f32 failed with status " + + std::to_string(static_cast(reshape_status))); + } + + std::vector output_nhwc(static_cast(C_out * H_out * W_out)); + // XNN_ALLOCATION_ALIGNMENT-aligned workspace (a bare vector is only + // max_align_t aligned). + std::vector workspace(workspace_size + workspace_alignment); + void* ws_ptr = workspace.data(); + if (workspace_alignment > 0) { + const uintptr_t addr = reinterpret_cast(ws_ptr); + const uintptr_t aligned = + (addr + workspace_alignment - 1) & ~(workspace_alignment - 1); + ws_ptr = reinterpret_cast(aligned); + } + + const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32( + op, ws_ptr, input_nhwc.data(), output_nhwc.data()); + if (setup_status != xnn_status_success) { + xnn_delete_operator(op); + throw std::runtime_error( + "xnn_setup_convolution2d_nhwc_f32 failed with status " + + std::to_string(static_cast(setup_status))); + } + const xnn_status run_status = xnn_run_operator(op, /*threadpool=*/nullptr); + if (run_status != xnn_status_success) { + xnn_delete_operator(op); + throw std::runtime_error( + "xnn_run_operator failed with status " + + std::to_string(static_cast(run_status))); + } + xnn_delete_operator(op); + + // NHWC -> NCHW output. + std::vector output_nchw(static_cast(C_out * H_out * W_out)); + for (int64_t c = 0; c < C_out; ++c) { + for (int64_t h = 0; h < H_out; ++h) { + for (int64_t w = 0; w < W_out; ++w) { + output_nchw[static_cast(c * (H_out * W_out) + h * W_out + w)] = + output_nhwc[static_cast((h * W_out + w) * C_out + c)]; + } + } + } + return output_nchw; +} + +// +// Graph construction +// + +// Handles needed to drive a built conv graph through resizes. ComputeGraph is +// move-only; hold it by value alongside the input/output handles. +struct ConvGraph { + ComputeGraph graph; + IOValueRef input; + ValueRef staging_out; + std::vector weight_nchw; // kept alive for the reference + std::vector bias; // empty if cfg.has_bias == false +}; + +// Build the tiled-im2col conv graph at cfg's INITIAL (upper-bound) input shape. +ConvGraph build_graph(const Conv2dTestConfig& cfg) { + GraphConfig graph_config; + // Force resize fns to run on every execute() (they also run because the input + // shape changes; this just matches how the runtime exercises the path). + graph_config.force_resize = true; + + std::vector weight_nchw = rand_floats( + numel({cfg.out_channels, cfg.in_channels, cfg.kernel_h, cfg.kernel_w}), + 11); + std::vector bias = cfg.has_bias + ? rand_floats(static_cast(cfg.out_channels), 22) + : std::vector{}; + + ConvGraph cg{ + ComputeGraph(graph_config), + IOValueRef{}, + kDummyValueRef, + std::move(weight_nchw), + std::move(bias)}; + ComputeGraph& graph = cg.graph; + + // Conv requires channels-packed input/output (check_conv_args). io_storage + // selects the input/output tensor storage. Build at the upper-bound (init) + // shape so the fixed build-time tile count covers every resized shape. + cg.input = graph.add_input_tensor( + {1, cfg.in_channels, cfg.init_h, cfg.init_w}, + vkapi::kFloat, + cfg.io_storage, + utils::kChannelsPacked); + const ValueRef r_weight = graph.add_tensorref( + {cfg.out_channels, cfg.in_channels, cfg.kernel_h, cfg.kernel_w}, + vkapi::kFloat, + cg.weight_nchw.data()); + + ValueRef r_bias = kDummyValueRef; + if (cfg.has_bias) { + r_bias = + graph.add_tensorref({cfg.out_channels}, vkapi::kFloat, cg.bias.data()); + } else { + r_bias = graph.add_none(); + } + + const ValueRef r_stride = + graph.add_scalar_list({cfg.stride_h, cfg.stride_w}); + const ValueRef r_padding = + graph.add_scalar_list({cfg.padding_h, cfg.padding_w}); + const ValueRef r_dilation = + graph.add_scalar_list({cfg.dilation_h, cfg.dilation_w}); + + const int64_t H_out_max = conv_out_dim( + cfg.init_h, cfg.kernel_h, cfg.stride_h, cfg.padding_h, cfg.dilation_h); + const int64_t W_out_max = conv_out_dim( + cfg.init_w, cfg.kernel_w, cfg.stride_w, cfg.padding_w, cfg.dilation_w); + const ValueRef r_out = graph.add_tensor( + {1, cfg.out_channels, H_out_max, W_out_max}, + vkapi::kFloat, + cfg.io_storage, + utils::kChannelsPacked); + + // Route straight to conv2d_gemm_impl with a forced im2col storage so the test + // is device-independent (the registered op auto-selects storage per device). + conv2d_gemm_impl( + graph, + cg.input.value, + r_weight, + r_bias, + r_stride, + r_padding, + r_dilation, + r_out, + /*clamp_out=*/false, + /*out_min_val=*/0.0f, + /*out_max_val=*/0.0f, + cfg.im2col_storage); + + cg.staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + return cg; +} + +// +// Test driver +// + +void run_dynamic_conv2d_resize_test(const Conv2dTestConfig& cfg) { + // Attach the config to every assertion in this test (gtest prints active + // SCOPED_TRACEs on failure) so a failing run names the exact config + + // storage. + SCOPED_TRACE(to_string(cfg)); + ASSERT_EQ(cfg.groups, 1) << "conv2d_gemm_impl only supports groups == 1"; + + ConvGraph cg = build_graph(cfg); + TensorFactory tf; + + // The build shape is the upper bound; the first run is at the initial shape, + // followed by each resized shape (each <= init per dim, asserted below). + std::vector> shapes; + shapes.emplace_back(cfg.init_h, cfg.init_w); + for (const auto& hw : cfg.resize_hw) { + shapes.push_back(hw); + } + + unsigned seed = 100; + for (const auto& hw : shapes) { + const int64_t H = hw.first; + const int64_t W = hw.second; + ASSERT_LE(H, cfg.init_h) + << "resized H must be <= build-time upper bound init_h"; + ASSERT_LE(W, cfg.init_w) + << "resized W must be <= build-time upper bound init_w"; + + const int64_t H_out = conv_out_dim( + H, cfg.kernel_h, cfg.stride_h, cfg.padding_h, cfg.dilation_h); + const int64_t W_out = conv_out_dim( + W, cfg.kernel_w, cfg.stride_w, cfg.padding_w, cfg.dilation_w); + const std::vector in_shape = {1, cfg.in_channels, H, W}; + const std::vector out_shape = {1, cfg.out_channels, H_out, W_out}; + const size_t in_n = numel(in_shape); + const size_t out_n = numel(out_shape); + + std::vector x_data = rand_floats(in_n, seed++); + std::vector ref = + conv2d_ref_xnnpack(x_data, cg.weight_nchw, cg.bias, cfg, H, W); + + cg.graph.resize_input(0, in_shape); + cg.graph.propagate_resize(); + cg.graph.maybe_cast_and_copy_into_staging( + cg.input.staging, x_data.data(), in_n, vkapi::kFloat); + + cg.graph.execute(); + + std::vector vk_data(out_n); + cg.graph.maybe_cast_and_copy_from_staging( + cg.staging_out, vk_data.data(), out_n, vkapi::kFloat); + + Tensor ref_t = tf.make(to_int32(out_shape), ref); + Tensor vk_t = tf.make(to_int32(out_shape), vk_data); + EXPECT_TENSOR_CLOSE_WITH_TOL(ref_t, vk_t, 1e-3, 1e-3) + << "Mismatch at resized H=" << H << " W=" << W << " (H_out=" << H_out + << ", W_out=" << W_out << ") for " << to_string(cfg); + } +} + +// Sweep the resize test over the three im2col SCRATCH storage variants (buffer +// / texture2d / texture3d), reusing the caller's config (taken by value) and +// overriding only im2col_storage each iteration. SCOPED_TRACE / to_string +// identifies the im2col_storage of any failing variant. All three are +// supported, so this is a safe single in-process loop (no variant crashes). +void test_wrapper(Conv2dTestConfig cfg) { + for (const auto im2col_storage : + {utils::kBuffer, utils::kTexture2D, utils::kTexture3D}) { + cfg.im2col_storage = im2col_storage; + run_dynamic_conv2d_resize_test(cfg); + } +} + +} // namespace + +TEST(VulkanConv2dGemmDynamicTest, im2col_storage_sweep_resize) { + // A 128x128 build shape with C_in=64, 3x3 s1p1 at the production 16 MB budget + // tiles into 3 output-height tiles (oh_tile=56). The resize sweep + // 64/56/112/128 needs 2/1/2/3 active tiles — crossing tile boundaries down to + // a single tile and back up to the bound, all <= 128. XNNPACK computes the + // reference, so the medium shape stays cheap. + // + // io_storage (the conv input/output tensor storage) is pinned to kTexture3D: + // conv2d_gemm I/O is texture3d-only. The conv shaders declare t_in / t_out as + // texture3d and use ivec3 addressing (conv2d_im2col reads texelFetch(t_in, + // ivec3); conv2d_gemm writes imageStore(t_out, ivec3)), so a buffer-backed + // I/O tensor bound to a texture descriptor crashes the driver, and a + // texture2d channels-packed 4D tensor has a different physical layout than + // the ivec3 addressing assumes, producing numerically wrong output (both were + // confirmed on Mali + Adreno). Only the im2col SCRATCH storage is + // parameterized — the im2col / GEMM shaders DO have buffer / tex2d / tex3d + // codegen variants for it. io_storage is therefore pinned to kTexture3D (and + // the buffer / tex2d I/O combinations are intentionally NOT exercised) + // until/unless buffer / tex2d conv-I/O shader variants are added; it stays an + // explicit config field (printed by to_string) so the test is trivial to + // extend when they land. + // + // im2col_storage is overwritten by test_wrapper for each scratch-storage + // variant; the value set here is just a placeholder. + const Conv2dTestConfig cfg{ + /*in_channels=*/64, + /*out_channels=*/64, + /*kernel_h=*/3, + /*kernel_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1, + /*dilation_h=*/1, + /*dilation_w=*/1, + /*groups=*/1, + /*has_bias=*/true, + /*init_h=*/128, + /*init_w=*/128, + /*io_storage=*/utils::kTexture3D, + /*im2col_storage=*/utils::kTexture3D, + /*resize_hw=*/{{64, 128}, {56, 128}, {112, 128}, {128, 128}}}; + test_wrapper(cfg); +} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 383a2d67eaa..0a9b0743415 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -198,3 +198,10 @@ def define_common_targets(is_fbcode = False): ":test_utils", ] ) + define_test_targets( + "conv2d_gemm_dynamic_test", + extra_deps = [ + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + "fbsource//xplat/third-party/XNNPACK:XNNPACK", + ] + )