From 4e5f1b43ac4b1bc24bc602e9e19f078fda1cc042 Mon Sep 17 00:00:00 2001 From: aashu2006 Date: Thu, 2 Apr 2026 06:24:58 +0530 Subject: [PATCH 1/2] Add auto spreading for WebGPU compute dispatches and fix strands void return handling --- src/strands/strands_api.js | 14 ++++++---- src/strands/strands_codegen.js | 5 +--- src/webgpu/p5.RendererWebGPU.js | 43 ++++++++++++++++++++++++++++--- src/webgpu/shaders/compute.js | 16 +++++++----- src/webgpu/strands_wgslBackend.js | 10 ++++--- 5 files changed, 66 insertions(+), 22 deletions(-) diff --git a/src/strands/strands_api.js b/src/strands/strands_api.js index ab1453eb1c..18fac44ecd 100644 --- a/src/strands/strands_api.js +++ b/src/strands/strands_api.js @@ -219,7 +219,7 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) { const nodeData = DAG.createNodeData({ nodeType: NodeType.STATEMENT, statementType: StatementType.EARLY_RETURN, - dependsOn: [valueNode.id] + dependsOn: value !== undefined ? [valueNode.id] : [] }); const earlyReturnID = DAG.getOrCreateNode(dag, nodeData); CFG.recordInBasicBlock(cfg, cfg.currentBlock, earlyReturnID); @@ -786,17 +786,21 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) { return newStruct.id; } } + else if (!expectedReturnType.dataType || expectedReturnType.typeName?.trim() === 'void') { + return null; + } else /*if(isNativeType(expectedReturnType.typeName))*/ { - if (!expectedReturnType.dataType) { - throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`); - } const expectedTypeInfo = expectedReturnType.dataType; return enforceReturnTypeMatch(strandsContext, expectedTypeInfo, retNode, hookType.name); } } for (const { valueNode, earlyReturnID } of hook.earlyReturns) { const id = handleRetVal(valueNode); - dag.dependsOn[earlyReturnID] = [id]; + if (id !== null) { + dag.dependsOn[earlyReturnID] = [id]; + } else { + dag.dependsOn[earlyReturnID] = []; + } } rootNodeID = userReturned ? handleRetVal(userReturned) : undefined; const fullHookName = `${hookType.returnType.typeName} ${hookType.name}`; diff --git a/src/strands/strands_codegen.js b/src/strands/strands_codegen.js index bf1ea61b4d..38e24c511e 100644 --- a/src/strands/strands_codegen.js +++ b/src/strands/strands_codegen.js @@ -64,12 +64,9 @@ export function generateShaderCode(strandsContext) { let returnType; if (hookType.returnType.properties) { returnType = structType(hookType.returnType); - } else if (hookType.returnType.typeName === 'void') { + } else if (!hookType.returnType.dataType || hookType.returnType.typeName?.trim() === 'void') { returnType = null; } else { - if (!hookType.returnType.dataType) { - throw new Error(`Missing dataType for return type ${hookType.returnType.typeName}`); - } returnType = hookType.returnType.dataType; } diff --git a/src/webgpu/p5.RendererWebGPU.js b/src/webgpu/p5.RendererWebGPU.js index 61f8580ac6..05f7167a91 100644 --- a/src/webgpu/p5.RendererWebGPU.js +++ b/src/webgpu/p5.RendererWebGPU.js @@ -3813,10 +3813,45 @@ ${hookUniformFields}} const WORKGROUP_SIZE_Y = 8; const WORKGROUP_SIZE_Z = 1; - // Calculate number of workgroups needed - const workgroupCountX = Math.ceil(x / WORKGROUP_SIZE_X); - const workgroupCountY = Math.ceil(y / WORKGROUP_SIZE_Y); - const workgroupCountZ = Math.ceil(z / WORKGROUP_SIZE_Z); + // auto spreading: if any dimension is too large or for performance optimization, + // spread total iteration count across dimensions + const totalIterations = x * y * z; + const MAX_THREADS_PER_DIM = 65535 * 8; + + let px = x; + let py = y; + let pz = z; + + // we spread if we exceed GPU limits OR if it involves a large 1D dispatch + const exceedsLimits = x > MAX_THREADS_PER_DIM || y > MAX_THREADS_PER_DIM || z > MAX_THREADS_PER_DIM; + const isLarge1D = totalIterations > 1024 && y === 1 && z === 1; + + if (exceedsLimits || isLarge1D) { + if (totalIterations > 1000000) { + // 3D cube type for extreme large counts + px = Math.ceil(Math.pow(totalIterations, 1 / 3)); + py = Math.ceil(Math.pow(totalIterations, 1 / 3)); + pz = Math.ceil(totalIterations / (px * py)); + } else { + // 2D square type for moderate large counts + px = Math.ceil(Math.sqrt(totalIterations)); + py = Math.ceil(totalIterations / px); + pz = 1; + } + + if (p5.debug || exceedsLimits) { + console.warn( + `p5.js: Compute dispatch (${x}, ${y}, ${z}) auto-spread to (${px}, ${py}, ${pz}) ` + + `to ${exceedsLimits ? 'stay within GPU limits' : 'optimize performance'}.` + ); + } + } + + shader.setUniform('uPhysicalCount', [px, py, pz]); + + const workgroupCountX = Math.ceil(px / WORKGROUP_SIZE_X); + const workgroupCountY = Math.ceil(py / WORKGROUP_SIZE_Y); + const workgroupCountZ = Math.ceil(pz / WORKGROUP_SIZE_Z); const commandEncoder = this.device.createCommandEncoder(); const passEncoder = commandEncoder.beginComputePass(); diff --git a/src/webgpu/shaders/compute.js b/src/webgpu/shaders/compute.js index 39e6146f4e..dafe356ee6 100644 --- a/src/webgpu/shaders/compute.js +++ b/src/webgpu/shaders/compute.js @@ -1,6 +1,7 @@ export const baseComputeShader = ` struct ComputeUniforms { uTotalCount: vec3, + uPhysicalCount: vec3, } @group(0) @binding(0) var uniforms: ComputeUniforms; @@ -11,16 +12,19 @@ fn main( @builtin(workgroup_id) workgroupId: vec3, @builtin(local_invocation_index) localIndex: u32 ) { - var index = vec3(globalId); + let totalIterations = u32(uniforms.uTotalCount.x) * u32(uniforms.uTotalCount.y) * u32(uniforms.uTotalCount.z); + let physicalId = globalId.x + globalId.y * (u32(uniforms.uPhysicalCount.x)) + globalId.z * (u32(uniforms.uPhysicalCount.x) * u32(uniforms.uPhysicalCount.y)); - if ( - index.x >= uniforms.uTotalCount.x || - index.y >= uniforms.uTotalCount.y || - index.z >= uniforms.uTotalCount.z - ) { + if (physicalId >= totalIterations) { return; } + var index = vec3(0); + index.x = i32(physicalId % u32(uniforms.uTotalCount.x)); + let remainingY = physicalId / u32(uniforms.uTotalCount.x); + index.y = i32(remainingY % u32(uniforms.uTotalCount.y)); + index.z = i32(remainingY / u32(uniforms.uTotalCount.y)); + HOOK_iteration(index); } `; diff --git a/src/webgpu/strands_wgslBackend.js b/src/webgpu/strands_wgslBackend.js index 32226fa83d..1a5182cde8 100644 --- a/src/webgpu/strands_wgslBackend.js +++ b/src/webgpu/strands_wgslBackend.js @@ -301,9 +301,13 @@ export const wgslBackend = { // Generate just a semicolon (unless suppressed) generationContext.write(semicolon); } else if (node.statementType === StatementType.EARLY_RETURN) { - const exprNodeID = node.dependsOn[0]; - const expr = this.generateExpression(generationContext, dag, exprNodeID); - generationContext.write(`return ${expr}${semicolon}`); + if (node.dependsOn && node.dependsOn.length > 0) { + const exprNodeID = node.dependsOn[0]; + const expr = this.generateExpression(generationContext, dag, exprNodeID); + generationContext.write(`return ${expr}${semicolon}`); + } else { + generationContext.write(`return${semicolon}`); + } } }, generateAssignment(generationContext, dag, nodeID) { From fad59601332055e7c8263f9583168d96b2ab2324 Mon Sep 17 00:00:00 2001 From: aashu2006 Date: Sat, 4 Apr 2026 11:25:57 +0530 Subject: [PATCH 2/2] Add tests for void compute hook early returns --- test/unit/webgpu/p5.Shader.js | 41 +++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/test/unit/webgpu/p5.Shader.js b/test/unit/webgpu/p5.Shader.js index 9c4a1fa559..6571d92edb 100644 --- a/test/unit/webgpu/p5.Shader.js +++ b/test/unit/webgpu/p5.Shader.js @@ -1228,5 +1228,46 @@ suite('WebGPU p5.Shader', function() { }); } }); + + suite('compute shaders', () => { + test('handle early return in void compute hook', async () => { + await myp5.createCanvas(5, 5, myp5.WEBGPU); + + // This test verifies that buildComputeShader and p5.compute + // correctly handle void hooks with early returns without crashing + // the strands compiler or hitting type errors. + expect(() => { + const computeShader = myp5.buildComputeShader(() => { + const id = myp5.index.x; + if (id > 10) { + return; // Early return in void hook + } + }, { myp5 }); + + myp5.compute(computeShader, 1); + }).not.toThrow(); + }); + + test('early return in void compute hook stops execution', async () => { + await myp5.createCanvas(5, 5, myp5.WEBGPU); + const data = myp5.createStorage([0]); + + const computeShader = myp5.buildComputeShader(() => { + const buf = myp5.uniformStorage(); + const id = myp5.index.x; + if (id == 0) { + buf[0] = 1.0; + return; + buf[0] = 2.0; // Should not execute + } + }, { myp5 }); + + computeShader.setUniform('buf', data); + + expect(() => { + myp5.compute(computeShader, 1); + }).not.toThrow(); + }); + }); }); });