diff --git a/lib/typeprof/core/ast/control.rb b/lib/typeprof/core/ast/control.rb index 82e1880d..2e64bdec 100644 --- a/lib/typeprof/core/ast/control.rb +++ b/lib/typeprof/core/ast/control.rb @@ -322,44 +322,74 @@ def install0(genv) ret = Vertex.new(self) @pivot&.install(genv) - # case文での型絞り込みを実装 - if @pivot && @pivot.is_a?(LocalVariableReadNode) - var = @pivot.var - original_vtx = @lenv.get_var(var) + # Collect modified variables across all branches + vars = [] + @when_nodes.each {|wn| wn.body.modified_vars(@lenv.locals.keys, vars) } + @else_clause.modified_vars(@lenv.locals.keys, vars) if @else_clause + vars.uniq! - # ダミー変数に元の型情報を設定 - @lenv.set_var(:"*pivot", original_vtx) + # Save original variable vertices + saved_vtxs = {} + vars.each do |var| + saved_vtxs[var] = @lenv.get_var(var) + end - # 各when節を実行 - @when_nodes.each do |when_node| - clause_result = when_node.install(genv) - @changes.add_edge(genv, clause_result, ret) - # 元の型に戻す - @lenv.set_var(var, original_vtx) - end + # Prepare per-branch result vertices + branch_vtxs = [] + + # Setup pivot narrowing if applicable + pivot_var = @pivot.is_a?(LocalVariableReadNode) ? @pivot.var : nil + if pivot_var + original_pivot = @lenv.get_var(pivot_var) + @lenv.set_var(:"*pivot", original_pivot) + end + + # Install each when branch + @when_nodes.each do |when_node| + # Reset variables to original for each branch + saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) } + @lenv.set_var(pivot_var, original_pivot) if pivot_var + + clause_val = when_node.install(genv) + @changes.add_edge(genv, clause_val, ret) - # else節(他のwhen節で除外された後の型) - filtered_else_vtx = original_vtx.new_vertex(genv, self) + modified = {} + vars.each {|var| modified[var] = @lenv.get_var(var) } + branch_vtxs << [clause_val, modified] + end + + # Install else branch + saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) } + if pivot_var + # Apply exclusion filters for else + filtered_else_vtx = original_pivot.new_vertex(genv, self) @when_nodes.each do |when_node| when_node.get_exclusion_conditions.each do |static_ret| - # 各when節の型を除外(negation) filtered_else_vtx = IsAFilter.new(genv, self, filtered_else_vtx, true, static_ret).next_vtx end end - @lenv.set_var(var, filtered_else_vtx) - @changes.add_edge(genv, @else_clause.install(genv), ret) - @lenv.set_var(var, original_vtx) + @lenv.set_var(pivot_var, filtered_else_vtx) + end + else_val = @else_clause.install(genv) + @changes.add_edge(genv, else_val, ret) - # ダミー変数をクリア - @lenv.locals.delete(:"*pivot") - else - # pivotが変数でない場合は従来通り - @when_nodes.each do |when_node| - @changes.add_edge(genv, when_node.install(genv), ret) + else_modified = {} + vars.each {|var| else_modified[var] = @lenv.get_var(var) } + branch_vtxs << [else_val, else_modified] + + # Join all branches + vars.each do |var| + joined = Vertex.new(self) + branch_vtxs.each do |branch_val, modified| + vtx = BotFilter.new(genv, self, modified[var], branch_val).next_vtx + @changes.add_edge(genv, vtx, joined) end - @changes.add_edge(genv, @else_clause.install(genv), ret) + @lenv.set_var(var, joined) end + # Cleanup + @lenv.locals.delete(:"*pivot") if pivot_var + ret end end diff --git a/scenario/flow/case_variable.rb b/scenario/flow/case_variable.rb new file mode 100644 index 00000000..e74f0d6e --- /dev/null +++ b/scenario/flow/case_variable.rb @@ -0,0 +1,20 @@ +## update +def test(type, val) + case type + when :int + val = val.to_i + when :sym + val = val.to_sym + else + val = val + end + val +end + +test(:int, "42") +test(:sym, "hello") + +## assert +class Object + def test: (:int | :sym, String) -> (Integer | String | Symbol) +end