From c1b93bdccac0f7fe040311ad013e8dad4647d2ee Mon Sep 17 00:00:00 2001 From: Yusuke Endoh Date: Fri, 3 Apr 2026 19:30:55 +0900 Subject: [PATCH] Isolate variable types across case/when branches Variables modified in one when branch no longer leak into other branches. Each branch starts from the original variable state, and all branches are joined after the case statement, similar to how if/unless handles variable branching. Co-Authored-By: Claude Opus 4.6 (1M context) --- lib/typeprof/core/ast/control.rb | 82 ++++++++++++++++++++++---------- scenario/flow/case_variable.rb | 20 ++++++++ 2 files changed, 76 insertions(+), 26 deletions(-) create mode 100644 scenario/flow/case_variable.rb diff --git a/lib/typeprof/core/ast/control.rb b/lib/typeprof/core/ast/control.rb index 82e1880d1..2e64bdec6 100644 --- a/lib/typeprof/core/ast/control.rb +++ b/lib/typeprof/core/ast/control.rb @@ -322,44 +322,74 @@ def install0(genv) ret = Vertex.new(self) @pivot&.install(genv) - # case文での型絞り込みを実装 - if @pivot && @pivot.is_a?(LocalVariableReadNode) - var = @pivot.var - original_vtx = @lenv.get_var(var) + # Collect modified variables across all branches + vars = [] + @when_nodes.each {|wn| wn.body.modified_vars(@lenv.locals.keys, vars) } + @else_clause.modified_vars(@lenv.locals.keys, vars) if @else_clause + vars.uniq! - # ダミー変数に元の型情報を設定 - @lenv.set_var(:"*pivot", original_vtx) + # Save original variable vertices + saved_vtxs = {} + vars.each do |var| + saved_vtxs[var] = @lenv.get_var(var) + end - # 各when節を実行 - @when_nodes.each do |when_node| - clause_result = when_node.install(genv) - @changes.add_edge(genv, clause_result, ret) - # 元の型に戻す - @lenv.set_var(var, original_vtx) - end + # Prepare per-branch result vertices + branch_vtxs = [] + + # Setup pivot narrowing if applicable + pivot_var = @pivot.is_a?(LocalVariableReadNode) ? @pivot.var : nil + if pivot_var + original_pivot = @lenv.get_var(pivot_var) + @lenv.set_var(:"*pivot", original_pivot) + end + + # Install each when branch + @when_nodes.each do |when_node| + # Reset variables to original for each branch + saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) } + @lenv.set_var(pivot_var, original_pivot) if pivot_var + + clause_val = when_node.install(genv) + @changes.add_edge(genv, clause_val, ret) - # else節(他のwhen節で除外された後の型) - filtered_else_vtx = original_vtx.new_vertex(genv, self) + modified = {} + vars.each {|var| modified[var] = @lenv.get_var(var) } + branch_vtxs << [clause_val, modified] + end + + # Install else branch + saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) } + if pivot_var + # Apply exclusion filters for else + filtered_else_vtx = original_pivot.new_vertex(genv, self) @when_nodes.each do |when_node| when_node.get_exclusion_conditions.each do |static_ret| - # 各when節の型を除外(negation) filtered_else_vtx = IsAFilter.new(genv, self, filtered_else_vtx, true, static_ret).next_vtx end end - @lenv.set_var(var, filtered_else_vtx) - @changes.add_edge(genv, @else_clause.install(genv), ret) - @lenv.set_var(var, original_vtx) + @lenv.set_var(pivot_var, filtered_else_vtx) + end + else_val = @else_clause.install(genv) + @changes.add_edge(genv, else_val, ret) - # ダミー変数をクリア - @lenv.locals.delete(:"*pivot") - else - # pivotが変数でない場合は従来通り - @when_nodes.each do |when_node| - @changes.add_edge(genv, when_node.install(genv), ret) + else_modified = {} + vars.each {|var| else_modified[var] = @lenv.get_var(var) } + branch_vtxs << [else_val, else_modified] + + # Join all branches + vars.each do |var| + joined = Vertex.new(self) + branch_vtxs.each do |branch_val, modified| + vtx = BotFilter.new(genv, self, modified[var], branch_val).next_vtx + @changes.add_edge(genv, vtx, joined) end - @changes.add_edge(genv, @else_clause.install(genv), ret) + @lenv.set_var(var, joined) end + # Cleanup + @lenv.locals.delete(:"*pivot") if pivot_var + ret end end diff --git a/scenario/flow/case_variable.rb b/scenario/flow/case_variable.rb new file mode 100644 index 000000000..e74f0d6e8 --- /dev/null +++ b/scenario/flow/case_variable.rb @@ -0,0 +1,20 @@ +## update +def test(type, val) + case type + when :int + val = val.to_i + when :sym + val = val.to_sym + else + val = val + end + val +end + +test(:int, "42") +test(:sym, "hello") + +## assert +class Object + def test: (:int | :sym, String) -> (Integer | String | Symbol) +end