Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 10 additions & 72 deletions pyrefly/lib/alt/overload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use pyrefly_util::prelude::SliceExt;
use pyrefly_util::prelude::VecExt;
use ruff_text_size::Ranged;
use ruff_text_size::TextRange;
use starlark_map::small_map::SmallMap;
use vec1::Vec1;
use vec1::vec1;

Expand Down Expand Up @@ -51,8 +50,6 @@ struct CalledOverload<'f> {
func: &'f TargetWithTParams<Function>,
res: Type,
ctor_targs: Option<TArgs>,
/// Mapping from original partial vars to fresh copies used in this overload call.
partial_var_map: SmallMap<Var, Var>,
call_errors: ErrorCollector,
/// Maps each argument's source range to the parameter type it was matched against.
expected_types: HashMap<TextRange, Type>,
Expand Down Expand Up @@ -268,7 +265,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
func: arity_closest_overload.unwrap().0,
res: self.heap.mk_any_error(),
ctor_targs: None,
partial_var_map: SmallMap::new(),
call_errors: self.error_collector(),
expected_types: HashMap::new(),
},
Expand Down Expand Up @@ -324,7 +320,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
closest_overload = CalledOverload {
func,
ctor_targs,
partial_var_map: first_overload.partial_var_map.clone(),
expected_types,
res: self.unions(matched_overloads.into_map(|o| o.res)),
call_errors: self.error_collector(),
Expand All @@ -347,8 +342,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
{
*targs = chosen_targs;
}
self.solver()
.solve_partial_vars_from_fresh(&closest_overload.partial_var_map);
}
// Record the closest overload to power IDE services.
let mut overload_trace = |target: &TargetWithTParams<Function>| {
Expand Down Expand Up @@ -512,6 +505,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
) -> (CalledOverload<'c>, bool) {
let mut matched_overloads = Vec::with_capacity(overloads.len());
let mut closest_unmatched_overload: Option<CalledOverload<'c>> = None;
// Snapshot partial vars so we can restore them after each losing overload.
let partial_vars = self.collect_partial_vars(self_obj, args, keywords);
let snapshot = self.solver().snapshot_partial_vars(&partial_vars);
for callable in overloads {
let called_overload = self.call_overload(
callable,
Expand All @@ -525,8 +521,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
ctor_targs,
);
if called_overload.call_errors.is_empty() {
// Restore partial vars so the next overload starts clean.
self.solver().restore_partial_vars(&snapshot);
matched_overloads.push(called_overload);
} else {
// Restore partial vars — this overload may have pinned them.
self.solver().restore_partial_vars(&snapshot);
match &closest_unmatched_overload {
Some(overload)
if overload.call_errors.len() <= called_overload.call_errors.len() => {}
Expand Down Expand Up @@ -780,18 +780,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
partial_vars
}

/// Substitute fresh vars for originals in a type. This is used to generate fresh partial vars
/// for overload calls.
fn substitute_vars(ty: &mut Type, mapping: &SmallMap<Var, Var>) {
ty.transform_mut(&mut |t| {
if let Type::Var(v) = t
&& let Some(fresh) = mapping.get(v)
{
*t = Type::Var(*fresh);
}
});
}

fn call_overload<'c>(
&self,
callable: &'c TargetWithTParams<Function>,
Expand All @@ -811,63 +799,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let mut overload_ctor_targs = ctor_targs.as_ref().map(|x| (**x).clone());
let tparams = callable.0.as_deref();

// Substitute fresh vars into self_obj and Type-valued arguments. Each overload
// gets its own fresh copies so that a failing overload's constraint solving
// doesn't pin the original partial vars.
let partial_vars = self.collect_partial_vars(self_obj, args, keywords);
let partial_var_map = self
.solver()
.freshen_partial_vars(&partial_vars, self.uniques);
let owner = Owner::new();
let (self_obj, fresh_args, fresh_keywords) = if partial_var_map.is_empty() {
(self_obj.cloned(), None, None)
} else {
let self_obj = self_obj.cloned().map(|mut obj| {
Self::substitute_vars(&mut obj, &partial_var_map);
obj
});
let fresh_args = args
.iter()
.map(|arg| match arg {
CallArg::Arg(TypeOrExpr::Type(ty, range)) => {
let mut ty = (*ty).clone();
Self::substitute_vars(&mut ty, &partial_var_map);
CallArg::Arg(TypeOrExpr::Type(owner.push(ty), *range))
}
CallArg::Star(TypeOrExpr::Type(ty, _), range) => {
let mut ty = (*ty).clone();
Self::substitute_vars(&mut ty, &partial_var_map);
CallArg::Star(TypeOrExpr::Type(owner.push(ty), arg.range()), *range)
}
other => other.clone(),
})
.collect::<Vec<_>>();
let fresh_keywords = keywords
.iter()
.map(|kw| match &kw.value {
TypeOrExpr::Type(ty, range) => {
let mut ty = (*ty).clone();
Self::substitute_vars(&mut ty, &partial_var_map);
CallKeyword {
range: kw.range,
arg: kw.arg,
value: TypeOrExpr::Type(owner.push(ty), *range),
}
}
_ => kw.clone(),
})
.collect::<Vec<_>>();
(self_obj, Some(fresh_args), Some(fresh_keywords))
};

let call_errors = self.error_collector();
let (res, specialization_errors, expected_types) = self.callable_infer(
callable.1.signature.clone(),
Some(&metadata.kind),
tparams,
self_obj,
fresh_args.as_deref().unwrap_or(args),
fresh_keywords.as_deref().unwrap_or(keywords),
self_obj.cloned(),
args,
keywords,
arguments_range,
errors,
&call_errors,
Expand All @@ -886,7 +825,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
func: callable,
res,
ctor_targs: overload_ctor_targs,
partial_var_map,
call_errors,
expected_types,
}
Expand Down
82 changes: 47 additions & 35 deletions pyrefly/lib/solver/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,18 @@ pub enum PinError {
UnfinishedQuantified(Quantified),
}

#[derive(Debug)]
/// Opaque snapshot of partial variable states, used to save and restore
/// partial vars during speculative overload resolution.
pub struct PartialVarSnapshot(Vec<(Var, Variable)>);

impl fmt::Debug for PartialVarSnapshot {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PartialVarSnapshot")
.field("len", &self.0.len())
.finish()
}
}

pub struct Solver {
variables: Mutex<Variables>,
instantiation_errors: RwLock<SmallMap<Var, TypeVarSpecializationError>>,
Expand All @@ -348,6 +359,12 @@ pub struct Solver {
pub spec_compliant_overloads: bool,
}

impl fmt::Debug for Solver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Solver").finish_non_exhaustive()
}
}

impl Display for Solver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for (x, y) in self.variables.lock().iter() {
Expand Down Expand Up @@ -456,44 +473,39 @@ impl Solver {
)
}

/// Create fresh copies of all partial variables in the given list.
/// Returns a mapping from original vars to their fresh copies.
/// Used during overload resolution to prevent one overload's constraint
/// solving from contaminating other overloads' partial variables.
pub fn freshen_partial_vars(
&self,
vars: &[Var],
uniques: &UniqueFactory,
) -> SmallMap<Var, Var> {
let mut fresh = SmallMap::with_capacity(vars.len());
let mut lock = self.variables.lock();
for v in vars {
let cloned = match &*lock.get(*v) {
/// Snapshot the current state of partial variables so they can be restored
/// after a speculative overload call.
pub fn snapshot_partial_vars(&self, vars: &[Var]) -> PartialVarSnapshot {
let lock = self.variables.lock();
let entries = vars
.iter()
.filter_map(|v| {
let var = lock.get(*v);
match &*var {
Variable::PartialContained(range) => {
Some((*v, Variable::PartialContained(*range)))
}
Variable::PartialQuantified(q) => {
Some((*v, Variable::PartialQuantified(q.clone())))
}
_ => None,
}
})
.collect();
PartialVarSnapshot(entries)
}

/// Restore partial variables to a previously snapshotted state.
/// This undoes any pinning/solving that occurred during a speculative overload call.
pub fn restore_partial_vars(&self, snapshot: &PartialVarSnapshot) {
let lock = self.variables.lock();
for (v, state) in &snapshot.0 {
let replacement = match state {
Variable::PartialContained(range) => Variable::PartialContained(*range),
Variable::PartialQuantified(q) => Variable::PartialQuantified(q.clone()),
_ => continue,
};
let fresh_var = Var::new(uniques);
fresh.insert(*v, fresh_var);
lock.insert_fresh(fresh_var, cloned);
}
fresh
}

/// Given an original->fresh mapping of partial vars, transfer solutions
/// from fresh partial vars to the originals. Used during overload
/// resolution to set partial vars to the solutions found during the
/// winning overload call. If a fresh var is still unsolved, leave the
/// original as-is.
pub fn solve_partial_vars_from_fresh(&self, mapping: &SmallMap<Var, Var>) {
let lock = self.variables.lock();
for (original, fresh) in mapping {
let fresh_var = lock.get(*fresh);
if let Variable::Answer(fresh_type) = &*fresh_var {
let fresh_type = fresh_type.clone();
drop(fresh_var);
*lock.get_mut(*original) = Variable::Answer(fresh_type.clone());
}
*lock.get_mut(*v) = replacement;
}
}

Expand Down
Loading