Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 56 additions & 26 deletions lib/typeprof/core/ast/control.rb
Original file line number Diff line number Diff line change
Expand Up @@ -322,44 +322,74 @@ def install0(genv)
ret = Vertex.new(self)
@pivot&.install(genv)

# case文での型絞り込みを実装
if @pivot && @pivot.is_a?(LocalVariableReadNode)
var = @pivot.var
original_vtx = @lenv.get_var(var)
# Collect modified variables across all branches
vars = []
@when_nodes.each {|wn| wn.body.modified_vars(@lenv.locals.keys, vars) }
@else_clause.modified_vars(@lenv.locals.keys, vars) if @else_clause
vars.uniq!

# ダミー変数に元の型情報を設定
@lenv.set_var(:"*pivot", original_vtx)
# Save original variable vertices
saved_vtxs = {}
vars.each do |var|
saved_vtxs[var] = @lenv.get_var(var)
end

# 各when節を実行
@when_nodes.each do |when_node|
clause_result = when_node.install(genv)
@changes.add_edge(genv, clause_result, ret)
# 元の型に戻す
@lenv.set_var(var, original_vtx)
end
# Prepare per-branch result vertices
branch_vtxs = []

# Setup pivot narrowing if applicable
pivot_var = @pivot.is_a?(LocalVariableReadNode) ? @pivot.var : nil
if pivot_var
original_pivot = @lenv.get_var(pivot_var)
@lenv.set_var(:"*pivot", original_pivot)
end

# Install each when branch
@when_nodes.each do |when_node|
# Reset variables to original for each branch
saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) }
@lenv.set_var(pivot_var, original_pivot) if pivot_var

clause_val = when_node.install(genv)
@changes.add_edge(genv, clause_val, ret)

# else節(他のwhen節で除外された後の型)
filtered_else_vtx = original_vtx.new_vertex(genv, self)
modified = {}
vars.each {|var| modified[var] = @lenv.get_var(var) }
branch_vtxs << [clause_val, modified]
end

# Install else branch
saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) }
if pivot_var
# Apply exclusion filters for else
filtered_else_vtx = original_pivot.new_vertex(genv, self)
@when_nodes.each do |when_node|
when_node.get_exclusion_conditions.each do |static_ret|
# 各when節の型を除外(negation)
filtered_else_vtx = IsAFilter.new(genv, self, filtered_else_vtx, true, static_ret).next_vtx
end
end
@lenv.set_var(var, filtered_else_vtx)
@changes.add_edge(genv, @else_clause.install(genv), ret)
@lenv.set_var(var, original_vtx)
@lenv.set_var(pivot_var, filtered_else_vtx)
end
else_val = @else_clause.install(genv)
@changes.add_edge(genv, else_val, ret)

# ダミー変数をクリア
@lenv.locals.delete(:"*pivot")
else
# pivotが変数でない場合は従来通り
@when_nodes.each do |when_node|
@changes.add_edge(genv, when_node.install(genv), ret)
else_modified = {}
vars.each {|var| else_modified[var] = @lenv.get_var(var) }
branch_vtxs << [else_val, else_modified]

# Join all branches
vars.each do |var|
joined = Vertex.new(self)
branch_vtxs.each do |branch_val, modified|
vtx = BotFilter.new(genv, self, modified[var], branch_val).next_vtx
@changes.add_edge(genv, vtx, joined)
end
@changes.add_edge(genv, @else_clause.install(genv), ret)
@lenv.set_var(var, joined)
end

# Cleanup
@lenv.locals.delete(:"*pivot") if pivot_var

ret
end
end
Expand Down
20 changes: 20 additions & 0 deletions scenario/flow/case_variable.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
## update
def test(type, val)
case type
when :int
val = val.to_i
when :sym
val = val.to_sym
else
val = val
end
val
end

test(:int, "42")
test(:sym, "hello")

## assert
class Object
def test: (:int | :sym, String) -> (Integer | String | Symbol)
end
Loading