Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/strands/ir_builders.js
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,17 @@ export function binaryOpNode(strandsContext, leftStrandsNode, rightArg, opCode)
let finalRightNodeID = rightStrandsNode.id;

// Check if we have to cast either node
const leftType = DAG.extractNodeTypeInfo(dag, leftStrandsNode.id);
const rightType = DAG.extractNodeTypeInfo(dag, rightStrandsNode.id);
let leftType = DAG.extractNodeTypeInfo(dag, leftStrandsNode.id);
let rightType = DAG.extractNodeTypeInfo(dag, rightStrandsNode.id);

// Update ASSIGN_ON_USE nodes to match the type of the other operand
if (leftType.baseType === BaseType.ASSIGN_ON_USE && rightType.baseType !== BaseType.ASSIGN_ON_USE) {
DAG.propagateTypeToAssignOnUse(dag, leftStrandsNode.id, rightType.baseType, rightType.dimension);
leftType = DAG.extractNodeTypeInfo(dag, leftStrandsNode.id);
} else if (rightType.baseType === BaseType.ASSIGN_ON_USE && leftType.baseType !== BaseType.ASSIGN_ON_USE) {
DAG.propagateTypeToAssignOnUse(dag, rightStrandsNode.id, leftType.baseType, leftType.dimension);
rightType = DAG.extractNodeTypeInfo(dag, rightStrandsNode.id);
}
const cast = { node: null, toType: leftType };
const bothDeferred = leftType.baseType === rightType.baseType && leftType.baseType === BaseType.DEFER;
if (bothDeferred) {
Expand Down
32 changes: 31 additions & 1 deletion src/strands/ir_dag.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { NodeTypeRequiredFields, NodeTypeToName, BasePriority, StatementType } from './ir_types';
import { NodeTypeRequiredFields, NodeTypeToName, BasePriority, StatementType, BaseType } from './ir_types';
import * as FES from './strands_FES';

/////////////////////////////////
Expand Down Expand Up @@ -81,6 +81,36 @@ export function extractNodeTypeInfo(dag, nodeID) {
};
}

// Propagate a known type to an ASSIGN_ON_USE node and all its ASSIGN_ON_USE dependencies
export function propagateTypeToAssignOnUse(dag, nodeId, baseType, dimension, visited = new Set()) {
// Avoid infinite loops
if (visited.has(nodeId)) {
return;
}
visited.add(nodeId);

const node = getNodeDataFromID(dag, nodeId);

// Only update if this node is ASSIGN_ON_USE
if (node.baseType !== BaseType.ASSIGN_ON_USE) {
return;
}

// Update this node's type
dag.baseTypes[nodeId] = baseType;
dag.dimensions[nodeId] = dimension;

// Recursively propagate to any ASSIGN_ON_USE dependencies
if (node.dependsOn && node.dependsOn.length > 0) {
for (const depId of node.dependsOn) {
const dep = getNodeDataFromID(dag, depId);
if (dep.baseType === BaseType.ASSIGN_ON_USE) {
propagateTypeToAssignOnUse(dag, depId, baseType, dimension, visited);
}
}
}
}

/////////////////////////////////
// Private functions
/////////////////////////////////
Expand Down
25 changes: 20 additions & 5 deletions src/strands/strands_phi_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,27 @@ export function createPhiNode(strandsContext, phiInputs, varName) {

// Get dimension and baseType from first valid input, skipping ASSIGN_ON_USE nodes
const inputNodes = validInputs.map((input) => DAG.getNodeDataFromID(strandsContext.dag, input.value.id));
let firstInput = inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE && input.dimension) ??
inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE) ??
inputNodes[0];

const dimension = firstInput.dimension;
const baseType = firstInput.baseType;
// Find first non-ASSIGN_ON_USE input to determine type
let typeSource = inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE && input.dimension) ??
inputNodes.find((input) => input.baseType !== BaseType.ASSIGN_ON_USE);

// If all are ASSIGN_ON_USE, fall back to first input
if (!typeSource) {
typeSource = inputNodes[0];
}

const dimension = typeSource.dimension;
const baseType = typeSource.baseType;

// Propagate the type to all ASSIGN_ON_USE inputs
if (baseType !== BaseType.ASSIGN_ON_USE) {
for (const input of inputNodes) {
if (input.baseType === BaseType.ASSIGN_ON_USE) {
DAG.propagateTypeToAssignOnUse(strandsContext.dag, input.id, baseType, dimension);
}
}
}

const nodeData = {
nodeType: NodeType.PHI,
Expand Down
234 changes: 229 additions & 5 deletions src/strands/strands_transpiler.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function replaceBinaryOperator(codeSource) {
}
}
function nodeIsUniform(ancestor) {
return ancestor.type === 'CallExpression'
return ancestor && ancestor.type === 'CallExpression'
&& (
(
// Global mode
Expand All @@ -41,7 +41,7 @@ function nodeIsUniform(ancestor) {
}

function nodeIsVarying(node) {
return node?.type === 'CallExpression'
return node && node.type === 'CallExpression'
&& (
(
// Global mode
Expand Down Expand Up @@ -1286,6 +1286,226 @@ function transformHelperFunction(functionNode) {
functionNode.body.body.push(finalReturn);
}

// Helper function to check if a function body contains .set() calls in control flow
function functionHasSetInControlFlow(functionNode) {
let hasSetInControlFlow = false;
let inControlFlow = 0;

const checkForSetCalls = {
IfStatement(node, state, c) {
inControlFlow++;
if (node.test) c(node.test, state);
if (node.consequent) c(node.consequent, state);
if (node.alternate) c(node.alternate, state);
inControlFlow--;
},
ForStatement(node, state, c) {
inControlFlow++;
if (node.init) c(node.init, state);
if (node.test) c(node.test, state);
if (node.update) c(node.update, state);
if (node.body) c(node.body, state);
inControlFlow--;
},
CallExpression(node) {
// Check if this is a .set() call
if (inControlFlow > 0 &&
node.callee?.type === 'MemberExpression' &&
node.callee?.property?.name === 'set') {
hasSetInControlFlow = true;
}
}
};

if (functionNode.body && functionNode.body.type === 'BlockStatement') {
recursive(functionNode.body, {}, checkForSetCalls);
}

return hasSetInControlFlow;
}

// Transform a function to use __setValue pattern instead of .set() calls in branches/loops
function transformFunctionSetCalls(functionNode) {
if (!functionNode.body || functionNode.body.type !== 'BlockStatement') {
return; // Can't transform arrow functions with expression bodies
}

// Track which hooks have .set() calls, mapping expression string to the actual AST node
const hooksWithSetCalls = new Map(); // exprString -> hookObjectNode

// First pass: find all hooks that have .set() calls in control flow
const findSetCalls = {
CallExpression(node) {
if (node.callee?.type === 'MemberExpression' &&
node.callee?.property?.name === 'set' &&
node.callee?.object) {
// This is something like filterColor.set(...) or myp5.filterColor.set(...)
const hookObjectNode = node.callee.object;
const exprString = escodegen.generate(hookObjectNode);
if (!hooksWithSetCalls.has(exprString)) {
hooksWithSetCalls.set(exprString, hookObjectNode);
}
}
}
};

recursive(functionNode.body, {}, findSetCalls);

if (hooksWithSetCalls.size === 0) {
return; // No .set() calls to transform
}

// For each hook with .set() calls, add intermediate variable and transform
for (const [exprString, hookObjectNode] of hooksWithSetCalls) {
// Create a safe variable name from the expression
const safeVarName = exprString.replace(/[^a-zA-Z0-9_]/g, '_');
const intermediateVarName = `__${safeVarName}_value`;

// 1. Find the .begin() call and insert intermediate variable right after it
const intermediateVarDecl = {
type: 'VariableDeclaration',
declarations: [{
type: 'VariableDeclarator',
id: { type: 'Identifier', name: intermediateVarName },
init: null
}],
kind: 'let'
};

let beginCallIndex = -1;
for (let i = 0; i < functionNode.body.body.length; i++) {
const stmt = functionNode.body.body[i];
if (stmt.type === 'ExpressionStatement' &&
stmt.expression?.type === 'CallExpression' &&
stmt.expression?.callee?.type === 'MemberExpression' &&
stmt.expression?.callee?.property?.name === 'begin') {
const beginExprString = escodegen.generate(stmt.expression.callee.object);
if (beginExprString === exprString) {
beginCallIndex = i;
break;
}
}
}

// Insert intermediate variable after .begin() if found, otherwise at the start
if (beginCallIndex !== -1) {
functionNode.body.body.splice(beginCallIndex + 1, 0, intermediateVarDecl);
} else {
functionNode.body.body.unshift(intermediateVarDecl);
}

// 2. Transform all .set() calls to assignments
const transformSetToAssignment = {
CallExpression(node, state, ancestors) {
// Check if this is a .set() call for this hook
if (node.callee?.type === 'MemberExpression' &&
node.callee?.property?.name === 'set' &&
node.callee?.object) {
const currentExprString = escodegen.generate(node.callee.object);
if (currentExprString === exprString && node.arguments.length > 0) {
// Find the parent statement
let parentStmt = null;
for (let i = ancestors.length - 1; i >= 0; i--) {
if (ancestors[i].type === 'ExpressionStatement') {
parentStmt = ancestors[i];
break;
}
}

if (parentStmt) {
// Replace the .set() call with an assignment
parentStmt.type = 'ExpressionStatement';
parentStmt.expression = {
type: 'AssignmentExpression',
operator: '=',
left: { type: 'Identifier', name: intermediateVarName },
right: node.arguments[0]
};
}
}
}
}
};

ancestor(functionNode.body, transformSetToAssignment);

// 3. Find the .end() call and insert final .set() call right before it
const finalSetCall = {
type: 'ExpressionStatement',
expression: {
type: 'CallExpression',
callee: {
type: 'MemberExpression',
object: JSON.parse(JSON.stringify(hookObjectNode)), // Deep copy the original node
property: { type: 'Identifier', name: 'set' },
computed: false
},
arguments: [{ type: 'Identifier', name: intermediateVarName }]
}
};

// Find the .end() call for this hook
let endCallIndex = -1;
for (let i = 0; i < functionNode.body.body.length; i++) {
const stmt = functionNode.body.body[i];
if (stmt.type === 'ExpressionStatement' &&
stmt.expression?.type === 'CallExpression' &&
stmt.expression?.callee?.type === 'MemberExpression' &&
stmt.expression?.callee?.property?.name === 'end') {
const endExprString = escodegen.generate(stmt.expression.callee.object);
if (endExprString === exprString) {
endCallIndex = i;
break;
}
}
}

// Insert the final .set() call before .end() if found, otherwise at the end
if (endCallIndex !== -1) {
functionNode.body.body.splice(endCallIndex, 0, finalSetCall);
} else {
// If no .end() found, insert before return statement or at the end
const lastStatement = functionNode.body.body[functionNode.body.body.length - 1];
if (lastStatement && lastStatement.type === 'ReturnStatement') {
functionNode.body.body.splice(functionNode.body.body.length - 1, 0, finalSetCall);
} else {
functionNode.body.body.push(finalSetCall);
}
}
}
}

// Main transformation pass: find and transform functions with .set() calls in control flow
function transformSetCallsInControlFlow(ast) {
const functionsToTransform = [];

// Collect functions that have .set() calls in control flow
const collectFunctions = {
ArrowFunctionExpression(node, ancestors) {
if (functionHasSetInControlFlow(node)) {
functionsToTransform.push(node);
}
},
FunctionExpression(node, ancestors) {
if (functionHasSetInControlFlow(node)) {
functionsToTransform.push(node);
}
},
FunctionDeclaration(node, ancestors) {
if (functionHasSetInControlFlow(node)) {
functionsToTransform.push(node);
}
}
};

ancestor(ast, collectFunctions);

// Transform each collected function
for (const funcNode of functionsToTransform) {
transformFunctionSetCalls(funcNode);
}
}

// Main transformation pass: find and transform helper functions with early returns
function transformHelperFunctionEarlyReturns(ast) {
const helperFunctionsToTransform = [];
Expand Down Expand Up @@ -1329,16 +1549,20 @@ export function transpileStrandsToJS(p5, sourceString, srcLocations, scope) {
ecmaVersion: 2021,
locations: srcLocations
});
// First pass: transform everything except if/for statements using normal ancestor traversal

// First pass: transform .set() calls in control flow to use intermediate variables
transformSetCallsInControlFlow(ast);

// Second pass: transform everything except if/for statements using normal ancestor traversal
const nonControlFlowCallbacks = { ...ASTCallbacks };
delete nonControlFlowCallbacks.IfStatement;
delete nonControlFlowCallbacks.ForStatement;
ancestor(ast, nonControlFlowCallbacks, undefined, { varyings: {} });

// Second pass: transform helper functions with early returns to use __returnValue pattern
// Third pass: transform helper functions with early returns to use __returnValue pattern
transformHelperFunctionEarlyReturns(ast);

// Third pass: transform if/for statements in post-order using recursive traversal
// Fourth pass: transform if/for statements in post-order using recursive traversal
const postOrderControlFlowTransform = {
IfStatement(node, state, c) {
state.inControlFlow++;
Expand Down
Loading
Loading