From d8ba7bc59b2cef4c715ca7c3d0f0ac9f08bdd376 Mon Sep 17 00:00:00 2001 From: Dave Pagurek Date: Sat, 28 Feb 2026 09:56:18 -0500 Subject: [PATCH] Handle strands set() calls in branches and loops --- src/strands/ir_builders.js | 13 +- src/strands/ir_dag.js | 32 +++- src/strands/strands_phi_utils.js | 25 +++- src/strands/strands_transpiler.js | 234 +++++++++++++++++++++++++++++- test/unit/webgl/p5.Shader.js | 70 +++++++++ 5 files changed, 361 insertions(+), 13 deletions(-) diff --git a/src/strands/ir_builders.js b/src/strands/ir_builders.js index 7aef73058c..efe3908d75 100644 --- a/src/strands/ir_builders.js +++ b/src/strands/ir_builders.js @@ -86,8 +86,17 @@ export function binaryOpNode(strandsContext, leftStrandsNode, rightArg, opCode) let finalRightNodeID = rightStrandsNode.id; // Check if we have to cast either node - const leftType = DAG.extractNodeTypeInfo(dag, leftStrandsNode.id); - const rightType = DAG.extractNodeTypeInfo(dag, rightStrandsNode.id); + let leftType = DAG.extractNodeTypeInfo(dag, leftStrandsNode.id); + let rightType = DAG.extractNodeTypeInfo(dag, rightStrandsNode.id); + + // Update ASSIGN_ON_USE nodes to match the type of the other operand + if (leftType.baseType === BaseType.ASSIGN_ON_USE && rightType.baseType !== BaseType.ASSIGN_ON_USE) { + DAG.propagateTypeToAssignOnUse(dag, leftStrandsNode.id, rightType.baseType, rightType.dimension); + leftType = DAG.extractNodeTypeInfo(dag, leftStrandsNode.id); + } else if (rightType.baseType === BaseType.ASSIGN_ON_USE && leftType.baseType !== BaseType.ASSIGN_ON_USE) { + DAG.propagateTypeToAssignOnUse(dag, rightStrandsNode.id, leftType.baseType, leftType.dimension); + rightType = DAG.extractNodeTypeInfo(dag, rightStrandsNode.id); + } const cast = { node: null, toType: leftType }; const bothDeferred = leftType.baseType === rightType.baseType && leftType.baseType === BaseType.DEFER; if (bothDeferred) { diff --git a/src/strands/ir_dag.js b/src/strands/ir_dag.js index 45eadb473a..b5a19c1cf6 100644 --- a/src/strands/ir_dag.js +++ b/src/strands/ir_dag.js @@ -1,4 +1,4 @@ -import { NodeTypeRequiredFields, NodeTypeToName, BasePriority, StatementType } from './ir_types'; +import { NodeTypeRequiredFields, NodeTypeToName, BasePriority, StatementType, BaseType } from './ir_types'; import * as FES from './strands_FES'; ///////////////////////////////// @@ -81,6 +81,36 @@ export function extractNodeTypeInfo(dag, nodeID) { }; } +// Propagate a known type to an ASSIGN_ON_USE node and all its ASSIGN_ON_USE dependencies +export function propagateTypeToAssignOnUse(dag, nodeId, baseType, dimension, visited = new Set()) { + // Avoid infinite loops + if (visited.has(nodeId)) { + return; + } + visited.add(nodeId); + + const node = getNodeDataFromID(dag, nodeId); + + // Only update if this node is ASSIGN_ON_USE + if (node.baseType !== BaseType.ASSIGN_ON_USE) { + return; + } + + // Update this node's type + dag.baseTypes[nodeId] = baseType; + dag.dimensions[nodeId] = dimension; + + // Recursively propagate to any ASSIGN_ON_USE dependencies + if (node.dependsOn && node.dependsOn.length > 0) { + for (const depId of node.dependsOn) { + const dep = getNodeDataFromID(dag, depId); + if (dep.baseType === BaseType.ASSIGN_ON_USE) { + propagateTypeToAssignOnUse(dag, depId, baseType, dimension, visited); + } + } + } +} + ///////////////////////////////// // Private functions ///////////////////////////////// diff --git a/src/strands/strands_phi_utils.js b/src/strands/strands_phi_utils.js index e88804f649..86bcc7b59e 100644 --- a/src/strands/strands_phi_utils.js +++ b/src/strands/strands_phi_utils.js @@ -11,12 +11,27 @@ export function createPhiNode(strandsContext, phiInputs, varName) { // Get dimension and baseType from first valid input, skipping ASSIGN_ON_USE nodes const inputNodes = validInputs.map((input) => DAG.getNodeDataFromID(strandsContext.dag, input.value.id)); - let firstInput = inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE && input.dimension) ?? - inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE) ?? - inputNodes[0]; - const dimension = firstInput.dimension; - const baseType = firstInput.baseType; + // Find first non-ASSIGN_ON_USE input to determine type + let typeSource = inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE && input.dimension) ?? + inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE); + + // If all are ASSIGN_ON_USE, fall back to first input + if (!typeSource) { + typeSource = inputNodes[0]; + } + + const dimension = typeSource.dimension; + const baseType = typeSource.baseType; + + // Propagate the type to all ASSIGN_ON_USE inputs + if (baseType !== BaseType.ASSIGN_ON_USE) { + for (const input of inputNodes) { + if (input.baseType === BaseType.ASSIGN_ON_USE) { + DAG.propagateTypeToAssignOnUse(strandsContext.dag, input.id, baseType, dimension); + } + } + } const nodeData = { nodeType: NodeType.PHI, diff --git a/src/strands/strands_transpiler.js b/src/strands/strands_transpiler.js index fce1f961f2..836a177c6c 100644 --- a/src/strands/strands_transpiler.js +++ b/src/strands/strands_transpiler.js @@ -26,7 +26,7 @@ function replaceBinaryOperator(codeSource) { } } function nodeIsUniform(ancestor) { - return ancestor.type === 'CallExpression' + return ancestor && ancestor.type === 'CallExpression' && ( ( // Global mode @@ -41,7 +41,7 @@ function nodeIsUniform(ancestor) { } function nodeIsVarying(node) { - return node?.type === 'CallExpression' + return node && node.type === 'CallExpression' && ( ( // Global mode @@ -1286,6 +1286,226 @@ function transformHelperFunction(functionNode) { functionNode.body.body.push(finalReturn); } +// Helper function to check if a function body contains .set() calls in control flow +function functionHasSetInControlFlow(functionNode) { + let hasSetInControlFlow = false; + let inControlFlow = 0; + + const checkForSetCalls = { + IfStatement(node, state, c) { + inControlFlow++; + if (node.test) c(node.test, state); + if (node.consequent) c(node.consequent, state); + if (node.alternate) c(node.alternate, state); + inControlFlow--; + }, + ForStatement(node, state, c) { + inControlFlow++; + if (node.init) c(node.init, state); + if (node.test) c(node.test, state); + if (node.update) c(node.update, state); + if (node.body) c(node.body, state); + inControlFlow--; + }, + CallExpression(node) { + // Check if this is a .set() call + if (inControlFlow > 0 && + node.callee?.type === 'MemberExpression' && + node.callee?.property?.name === 'set') { + hasSetInControlFlow = true; + } + } + }; + + if (functionNode.body && functionNode.body.type === 'BlockStatement') { + recursive(functionNode.body, {}, checkForSetCalls); + } + + return hasSetInControlFlow; +} + +// Transform a function to use __setValue pattern instead of .set() calls in branches/loops +function transformFunctionSetCalls(functionNode) { + if (!functionNode.body || functionNode.body.type !== 'BlockStatement') { + return; // Can't transform arrow functions with expression bodies + } + + // Track which hooks have .set() calls, mapping expression string to the actual AST node + const hooksWithSetCalls = new Map(); // exprString -> hookObjectNode + + // First pass: find all hooks that have .set() calls in control flow + const findSetCalls = { + CallExpression(node) { + if (node.callee?.type === 'MemberExpression' && + node.callee?.property?.name === 'set' && + node.callee?.object) { + // This is something like filterColor.set(...) or myp5.filterColor.set(...) + const hookObjectNode = node.callee.object; + const exprString = escodegen.generate(hookObjectNode); + if (!hooksWithSetCalls.has(exprString)) { + hooksWithSetCalls.set(exprString, hookObjectNode); + } + } + } + }; + + recursive(functionNode.body, {}, findSetCalls); + + if (hooksWithSetCalls.size === 0) { + return; // No .set() calls to transform + } + + // For each hook with .set() calls, add intermediate variable and transform + for (const [exprString, hookObjectNode] of hooksWithSetCalls) { + // Create a safe variable name from the expression + const safeVarName = exprString.replace(/[^a-zA-Z0-9_]/g, '_'); + const intermediateVarName = `__${safeVarName}_value`; + + // 1. Find the .begin() call and insert intermediate variable right after it + const intermediateVarDecl = { + type: 'VariableDeclaration', + declarations: [{ + type: 'VariableDeclarator', + id: { type: 'Identifier', name: intermediateVarName }, + init: null + }], + kind: 'let' + }; + + let beginCallIndex = -1; + for (let i = 0; i < functionNode.body.body.length; i++) { + const stmt = functionNode.body.body[i]; + if (stmt.type === 'ExpressionStatement' && + stmt.expression?.type === 'CallExpression' && + stmt.expression?.callee?.type === 'MemberExpression' && + stmt.expression?.callee?.property?.name === 'begin') { + const beginExprString = escodegen.generate(stmt.expression.callee.object); + if (beginExprString === exprString) { + beginCallIndex = i; + break; + } + } + } + + // Insert intermediate variable after .begin() if found, otherwise at the start + if (beginCallIndex !== -1) { + functionNode.body.body.splice(beginCallIndex + 1, 0, intermediateVarDecl); + } else { + functionNode.body.body.unshift(intermediateVarDecl); + } + + // 2. Transform all .set() calls to assignments + const transformSetToAssignment = { + CallExpression(node, state, ancestors) { + // Check if this is a .set() call for this hook + if (node.callee?.type === 'MemberExpression' && + node.callee?.property?.name === 'set' && + node.callee?.object) { + const currentExprString = escodegen.generate(node.callee.object); + if (currentExprString === exprString && node.arguments.length > 0) { + // Find the parent statement + let parentStmt = null; + for (let i = ancestors.length - 1; i >= 0; i--) { + if (ancestors[i].type === 'ExpressionStatement') { + parentStmt = ancestors[i]; + break; + } + } + + if (parentStmt) { + // Replace the .set() call with an assignment + parentStmt.type = 'ExpressionStatement'; + parentStmt.expression = { + type: 'AssignmentExpression', + operator: '=', + left: { type: 'Identifier', name: intermediateVarName }, + right: node.arguments[0] + }; + } + } + } + } + }; + + ancestor(functionNode.body, transformSetToAssignment); + + // 3. Find the .end() call and insert final .set() call right before it + const finalSetCall = { + type: 'ExpressionStatement', + expression: { + type: 'CallExpression', + callee: { + type: 'MemberExpression', + object: JSON.parse(JSON.stringify(hookObjectNode)), // Deep copy the original node + property: { type: 'Identifier', name: 'set' }, + computed: false + }, + arguments: [{ type: 'Identifier', name: intermediateVarName }] + } + }; + + // Find the .end() call for this hook + let endCallIndex = -1; + for (let i = 0; i < functionNode.body.body.length; i++) { + const stmt = functionNode.body.body[i]; + if (stmt.type === 'ExpressionStatement' && + stmt.expression?.type === 'CallExpression' && + stmt.expression?.callee?.type === 'MemberExpression' && + stmt.expression?.callee?.property?.name === 'end') { + const endExprString = escodegen.generate(stmt.expression.callee.object); + if (endExprString === exprString) { + endCallIndex = i; + break; + } + } + } + + // Insert the final .set() call before .end() if found, otherwise at the end + if (endCallIndex !== -1) { + functionNode.body.body.splice(endCallIndex, 0, finalSetCall); + } else { + // If no .end() found, insert before return statement or at the end + const lastStatement = functionNode.body.body[functionNode.body.body.length - 1]; + if (lastStatement && lastStatement.type === 'ReturnStatement') { + functionNode.body.body.splice(functionNode.body.body.length - 1, 0, finalSetCall); + } else { + functionNode.body.body.push(finalSetCall); + } + } + } +} + +// Main transformation pass: find and transform functions with .set() calls in control flow +function transformSetCallsInControlFlow(ast) { + const functionsToTransform = []; + + // Collect functions that have .set() calls in control flow + const collectFunctions = { + ArrowFunctionExpression(node, ancestors) { + if (functionHasSetInControlFlow(node)) { + functionsToTransform.push(node); + } + }, + FunctionExpression(node, ancestors) { + if (functionHasSetInControlFlow(node)) { + functionsToTransform.push(node); + } + }, + FunctionDeclaration(node, ancestors) { + if (functionHasSetInControlFlow(node)) { + functionsToTransform.push(node); + } + } + }; + + ancestor(ast, collectFunctions); + + // Transform each collected function + for (const funcNode of functionsToTransform) { + transformFunctionSetCalls(funcNode); + } +} + // Main transformation pass: find and transform helper functions with early returns function transformHelperFunctionEarlyReturns(ast) { const helperFunctionsToTransform = []; @@ -1329,16 +1549,20 @@ export function transpileStrandsToJS(p5, sourceString, srcLocations, scope) { ecmaVersion: 2021, locations: srcLocations }); - // First pass: transform everything except if/for statements using normal ancestor traversal + + // First pass: transform .set() calls in control flow to use intermediate variables + transformSetCallsInControlFlow(ast); + + // Second pass: transform everything except if/for statements using normal ancestor traversal const nonControlFlowCallbacks = { ...ASTCallbacks }; delete nonControlFlowCallbacks.IfStatement; delete nonControlFlowCallbacks.ForStatement; ancestor(ast, nonControlFlowCallbacks, undefined, { varyings: {} }); - // Second pass: transform helper functions with early returns to use __returnValue pattern + // Third pass: transform helper functions with early returns to use __returnValue pattern transformHelperFunctionEarlyReturns(ast); - // Third pass: transform if/for statements in post-order using recursive traversal + // Fourth pass: transform if/for statements in post-order using recursive traversal const postOrderControlFlowTransform = { IfStatement(node, state, c) { state.inControlFlow++; diff --git a/test/unit/webgl/p5.Shader.js b/test/unit/webgl/p5.Shader.js index 6241c88973..2556c1d25d 100644 --- a/test/unit/webgl/p5.Shader.js +++ b/test/unit/webgl/p5.Shader.js @@ -2057,6 +2057,76 @@ test('returns numbers for builtin globals outside hooks and a strandNode when ca assert.approximately(pixelColor[1], 127, 5); assert.approximately(pixelColor[2], 0, 5); }); + + test('handle .set() in if-else branches with flat API', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + myp5.filterColor.begin(); + let value = 1; + if (value > 0.5) { + myp5.filterColor.set([1, 0, 0, 1]); + } else { + myp5.filterColor.set([0, 1, 0, 1]); + } + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + + test('handle .set() in for loop with flat API', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + myp5.filterColor.begin(); + for (let i = 0; i < 3; i++) { + if (i === 2) { + myp5.filterColor.set([i/2, 0, 0, 1]); + } + } + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); + + test('handle false .set() in if with content afterwards with flat API', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + const testShader = myp5.baseFilterShader().modify(() => { + myp5.filterColor.begin(); + let value = 1; + if (value < 0.5) { + myp5.filterColor.set([1, 0, 0, 1]); + } + + let otherValue = 0.2; + otherValue *= 2; + myp5.filterColor.set([otherValue, 0, 0, 1]); + myp5.filterColor.end(); + }, { myp5 }); + + myp5.background(255, 255, 255); + myp5.filter(testShader); + + const pixelColor = myp5.get(25, 25); + assert.approximately(pixelColor[0], 0.4 * 255, 5); + assert.approximately(pixelColor[1], 0, 5); + assert.approximately(pixelColor[2], 0, 5); + }); }); suite('p5.strands error messages', () => {