Skip to content

Commit 922adff

Browse files
committed
Preserve Origins in Ternary Simplifications
1 parent 8805dab commit 922adff

3 files changed

Lines changed: 54 additions & 4 deletions

File tree

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,20 @@ private static ValDerivationNode foldUnary(ValDerivationNode node) {
235235
*/
236236
private static ValDerivationNode foldIte(ValDerivationNode node) {
237237
Ite iteExp = (Ite) node.getValue();
238+
DerivationNode parent = node.getOrigin();
238239

239-
ValDerivationNode condNode = fold(new ValDerivationNode(iteExp.getCondition(), null));
240-
ValDerivationNode thenNode = fold(new ValDerivationNode(iteExp.getThen(), null));
241-
ValDerivationNode elseNode = fold(new ValDerivationNode(iteExp.getElse(), null));
240+
ValDerivationNode condNode;
241+
ValDerivationNode thenNode;
242+
ValDerivationNode elseNode;
243+
if (parent instanceof IteDerivationNode iteOrigin) {
244+
condNode = fold(iteOrigin.getCondition());
245+
thenNode = fold(iteOrigin.getThenBranch());
246+
elseNode = fold(iteOrigin.getElseBranch());
247+
} else {
248+
condNode = fold(new ValDerivationNode(iteExp.getCondition(), null));
249+
thenNode = fold(new ValDerivationNode(iteExp.getThen(), null));
250+
elseNode = fold(new ValDerivationNode(iteExp.getElse(), null));
251+
}
242252

243253
Expression condition = condNode.getValue();
244254
Expression thenExp = thenNode.getValue();

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import liquidjava.rj_language.ast.Enum;
55
import liquidjava.rj_language.ast.Expression;
66
import liquidjava.rj_language.ast.FunctionInvocation;
7+
import liquidjava.rj_language.ast.GroupExpression;
8+
import liquidjava.rj_language.ast.Ite;
79
import liquidjava.rj_language.ast.UnaryExpression;
810
import liquidjava.rj_language.ast.Var;
911
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
@@ -101,6 +103,28 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
101103
: new ValDerivationNode(cloned, null);
102104
}
103105

106+
// lift ternary origin
107+
if (exp instanceof Ite ite) {
108+
ValDerivationNode condition = propagateRecursive(ite.getCondition(), subs, varOrigins);
109+
ValDerivationNode thenBranch = propagateRecursive(ite.getThen(), subs, varOrigins);
110+
ValDerivationNode elseBranch = propagateRecursive(ite.getElse(), subs, varOrigins);
111+
Ite cloned = (Ite) ite.clone();
112+
cloned.setChild(0, condition.getValue());
113+
cloned.setChild(1, thenBranch.getValue());
114+
cloned.setChild(2, elseBranch.getValue());
115+
116+
return (condition.getOrigin() != null || thenBranch.getOrigin() != null || elseBranch.getOrigin() != null)
117+
? new ValDerivationNode(cloned, new IteDerivationNode(condition, thenBranch, elseBranch))
118+
: new ValDerivationNode(cloned, null);
119+
}
120+
121+
if (exp instanceof GroupExpression group && group.getChildren().size() == 1) {
122+
ValDerivationNode child = propagateRecursive(group.getExpression(), subs, varOrigins);
123+
GroupExpression cloned = (GroupExpression) group.clone();
124+
cloned.setChild(0, child.getValue());
125+
return new ValDerivationNode(cloned, child.getOrigin());
126+
}
127+
104128
// recursively propagate children
105129
if (exp.hasChildren()) {
106130
Expression propagated = exp.clone();
@@ -163,4 +187,4 @@ private static void extractVarOrigins(ValDerivationNode node, Map<String, Deriva
163187
extractVarOrigins(valOrigin, varOrigins);
164188
}
165189
}
166-
}
190+
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,22 @@ void testIteConditionUsesEqualityFromConjunction() {
467467
"mode == 1 should make the mode == 2 ternary condition false");
468468
}
469469

470+
@Test
471+
void testIteConditionKeepsPropagatedVariableOrigin() {
472+
Expression expr = parse("mode == 1 && (mode == 2 ? explicit(param) : start(param))");
473+
ValDerivationNode result = ExpressionSimplifier.simplify(expr);
474+
475+
assertNotNull(result.getOrigin(), "ITE simplification should record the selected branch");
476+
IteDerivationNode iteOrigin = (IteDerivationNode) result.getOrigin();
477+
ValDerivationNode condition = iteOrigin.getCondition();
478+
BinaryDerivationNode equality = (BinaryDerivationNode) condition.getOrigin();
479+
ValDerivationNode left = equality.getLeft();
480+
481+
assertEquals("1", left.getValue().toString());
482+
assertDerivationEquals(new VarDerivationNode("mode"), left.getOrigin(),
483+
"Propagated condition value should come from the mode parameter");
484+
}
485+
470486
@Test
471487
void testByteAliasExpansion() {
472488
String sut = "Byte(b)";

0 commit comments

Comments
 (0)