diff --git a/crates/emmylua_code_analysis/resources/std/global.lua b/crates/emmylua_code_analysis/resources/std/global.lua index 022b0684a..0b5c93c76 100644 --- a/crates/emmylua_code_analysis/resources/std/global.lua +++ b/crates/emmylua_code_analysis/resources/std/global.lua @@ -253,10 +253,11 @@ function pairs(t) end --- boolean), which is true if the call succeeds without errors. In such case, --- `pcall` also returns all results from the call, after this first result. In --- case of any error, `pcall` returns **false** plus the error message. ----@generic T, R, R1 ----@param f sync fun(...: T...): R1, R... +---@generic T, R +---@param f sync fun(...: T...): R... ---@param ... T... ----@return boolean, R1|string, R... +---@return_overload true, R... +---@return_overload false, string function pcall(f, ...) end --- diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs index 7b97fa9b4..ac0cbc741 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/attribute_tags.rs @@ -34,6 +34,9 @@ pub fn analyze_tag_attribute_use( (LuaAst::LuaDocTagReturn(_), LuaSemanticDeclId::Signature(_)) => { return Some(()); } + (LuaAst::LuaDocTagReturnOverload(_), LuaSemanticDeclId::Signature(_)) => { + return Some(()); + } _ => {} } } @@ -147,7 +150,8 @@ fn attribute_find_doc(comment: &LuaSyntaxNode) -> Option { LuaKind::Syntax( LuaSyntaxKind::DocTagField | LuaSyntaxKind::DocTagParam - | LuaSyntaxKind::DocTagReturn, + | LuaSyntaxKind::DocTagReturn + | LuaSyntaxKind::DocTagReturnOverload, ) => { if let Some(node) = sibling.as_node() { return Some(node.clone()); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs index 34196fc6a..84eff53be 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/tags.rs @@ -22,7 +22,7 @@ use super::{ type_def_tags::{analyze_alias, analyze_class, analyze_enum, analyze_func_generic}, type_ref_tags::{ analyze_as, analyze_cast, analyze_module, analyze_other, analyze_overload, analyze_param, - analyze_return, analyze_return_cast, analyze_see, analyze_type, + analyze_return, analyze_return_cast, analyze_return_overload, analyze_see, analyze_type, }, }; @@ -55,6 +55,9 @@ pub fn analyze_tag(analyzer: &mut DocAnalyzer, tag: LuaDocTag) -> Option<()> { LuaDocTag::Return(return_tag) => { analyze_return(analyzer, return_tag)?; } + LuaDocTag::ReturnOverload(return_overload_tag) => { + analyze_return_overload(analyzer, return_overload_tag)?; + } LuaDocTag::ReturnCast(return_cast) => { analyze_return_cast(analyzer, return_cast)?; } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index cb75dc834..64904d287 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -1,8 +1,8 @@ use emmylua_parser::{ LuaAst, LuaAstNode, LuaAstToken, LuaBlock, LuaDocDescriptionOwner, LuaDocTagAs, LuaDocTagCast, LuaDocTagModule, LuaDocTagOther, LuaDocTagOverload, LuaDocTagParam, LuaDocTagReturn, - LuaDocTagReturnCast, LuaDocTagSchema, LuaDocTagSee, LuaDocTagType, LuaExpr, LuaLocalName, - LuaTokenKind, LuaVarExpr, + LuaDocTagReturnCast, LuaDocTagReturnOverload, LuaDocTagSchema, LuaDocTagSee, LuaDocTagType, + LuaExpr, LuaLocalName, LuaTokenKind, LuaVarExpr, }; use super::{ @@ -16,8 +16,8 @@ use crate::{ SignatureReturnStatus, TypeOps, compilation::analyzer::common::bind_type, db_index::{ - LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaMemberId, LuaOperator, LuaSemanticDeclId, - LuaSignatureId, LuaType, + LuaDeclId, LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaMemberId, + LuaOperator, LuaSemanticDeclId, LuaSignatureId, LuaType, }, }; use crate::{ @@ -248,29 +248,59 @@ pub fn analyze_return(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturn) -> Optio let description = tag .get_description() .map(|des| preprocess_description(&des.get_description_text(), None)); + let return_infos = tag + .get_info_list() + .into_iter() + .map(|(doc_type, name_token)| LuaDocReturnInfo { + name: name_token.map(|name| name.get_name_text().to_string()), + type_ref: infer_type(analyzer, doc_type), + description: description.clone(), + attributes: None, + }) + .collect::>(); + + bind_signature_return_docs(analyzer, &tag, |signature| { + signature.return_docs.extend(return_infos); + }) +} - if let Some(closure) = find_owner_closure_or_report(analyzer, &tag) { - let signature_id = LuaSignatureId::from_closure(analyzer.file_id, &closure); - let returns = tag.get_info_list(); - for (doc_type, name_token) in returns { - let name = name_token.map(|name| name.get_name_text().to_string()); - - let type_ref = infer_type(analyzer, doc_type); - let return_info = LuaDocReturnInfo { - name, - type_ref, - description: description.clone(), - attributes: None, - }; - - let signature = analyzer - .db - .get_signature_index_mut() - .get_or_create(signature_id); - signature.return_docs.push(return_info); - signature.resolve_return = SignatureReturnStatus::DocResolve; - } +pub fn analyze_return_overload( + analyzer: &mut DocAnalyzer, + tag: LuaDocTagReturnOverload, +) -> Option<()> { + let description = tag + .get_description() + .map(|des| preprocess_description(&des.get_description_text(), None)) + .filter(|des| !des.is_empty()); + let overload_info = LuaDocReturnOverloadInfo { + type_refs: tag + .get_types() + .map(|doc_type| infer_type(analyzer, doc_type)) + .collect(), + description, + }; + if overload_info.type_refs.is_empty() { + return Some(()); } + + bind_signature_return_docs(analyzer, &tag, |signature| { + signature.return_overloads.push(overload_info); + }) +} + +fn bind_signature_return_docs( + analyzer: &mut DocAnalyzer, + tag: &impl LuaAstNode, + bind: impl FnOnce(&mut crate::LuaSignature), +) -> Option<()> { + let closure = find_owner_closure_or_report(analyzer, tag)?; + let signature_id = LuaSignatureId::from_closure(analyzer.file_id, &closure); + let signature = analyzer + .db + .get_signature_index_mut() + .get_or_create(signature_id); + bind(signature); + signature.resolve_return = SignatureReturnStatus::DocResolve; Some(()) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs index 871658f3a..0d2a096ec 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/bind_analyze/stats.rs @@ -1,11 +1,13 @@ use emmylua_parser::{ BinaryOperator, LuaAssignStat, LuaAst, LuaAstNode, LuaBlock, LuaBreakStat, LuaCallArgList, LuaCallExprStat, LuaDoStat, LuaExpr, LuaForRangeStat, LuaForStat, LuaFuncStat, LuaGotoStat, - LuaIfStat, LuaLabelStat, LuaLocalStat, LuaRepeatStat, LuaReturnStat, LuaVarExpr, LuaWhileStat, + LuaIfStat, LuaLabelStat, LuaLocalName, LuaLocalStat, LuaRepeatStat, LuaReturnStat, LuaVarExpr, + LuaWhileStat, }; use crate::{ - AnalyzeError, DiagnosticCode, FlowId, FlowNodeKind, LuaClosureId, LuaDeclId, + AnalyzeError, DeclMultiReturnRef, DeclMultiReturnRefAt, DiagnosticCode, FlowId, FlowNodeKind, + LuaClosureId, LuaDeclId, compilation::analyzer::flow::{ bind_analyze::{ bind_block, bind_each_child, bind_node, @@ -33,13 +35,20 @@ pub fn bind_local_stat( } } - for value in values { + for value in &values { // If there are more values than names, we still need to bind the values bind_expr(binder, value.clone(), current); } let local_flow_id = binder.create_decl(local_stat.get_position()); binder.add_antecedent(local_flow_id, current); + bind_multi_return_refs( + binder, + &get_local_decl_ids(binder, &local_names), + &values, + local_stat.get_position(), + local_flow_id, + ); local_flow_id } @@ -69,6 +78,27 @@ fn check_value_expr_is_check_expr(value_expr: LuaExpr) -> bool { } } +fn get_local_decl_ids( + binder: &FlowBinder<'_>, + local_names: &[LuaLocalName], +) -> Vec> { + local_names + .iter() + .map(|name| Some(LuaDeclId::new(binder.file_id, name.get_position()))) + .collect() +} + +fn get_var_decl_ids(binder: &FlowBinder<'_>, vars: &[LuaVarExpr]) -> Vec> { + vars.iter() + .map(|var| { + binder + .db + .get_reference_index() + .get_var_reference_decl(&binder.file_id, var.get_range()) + }) + .collect() +} + pub fn bind_assign_stat( binder: &mut FlowBinder, assign_stat: LuaAssignStat, @@ -91,10 +121,57 @@ pub fn bind_assign_stat( let assignment_kind = FlowNodeKind::Assignment(assign_stat.to_ptr()); let flow_id = binder.create_node(assignment_kind); binder.add_antecedent(flow_id, current); + bind_multi_return_refs( + binder, + &get_var_decl_ids(binder, &vars), + &values, + assign_stat.get_position(), + flow_id, + ); flow_id } +fn bind_multi_return_refs( + binder: &mut FlowBinder, + decl_ids: &[Option], + values: &[LuaExpr], + position: rowan::TextSize, + flow_id: FlowId, +) { + let tail_call = values.last().and_then(|value| match value { + LuaExpr::CallExpr(call_expr) => Some((values.len() - 1, call_expr.to_ptr())), + _ => None, + }); + + for (i, decl_id) in decl_ids.iter().enumerate() { + let Some(decl_id) = decl_id else { + continue; + }; + + let reference = tail_call.as_ref().and_then(|(last_value_idx, call_expr)| { + if i < *last_value_idx { + return None; + } + + Some(DeclMultiReturnRef { + call_expr: call_expr.clone(), + return_index: i - *last_value_idx, + }) + }); + + binder + .decl_multi_return_ref + .entry(*decl_id) + .or_default() + .push(DeclMultiReturnRefAt { + position, + flow_id, + reference, + }); + } +} + pub fn bind_call_expr_stat( binder: &mut FlowBinder, call_expr_stat: LuaCallExprStat, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs index e58ed48e4..73e6e0119 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/flow/binder.rs @@ -6,8 +6,8 @@ use rowan::TextSize; use smol_str::SmolStr; use crate::{ - AnalyzeError, DbIndex, FileId, FlowAntecedent, FlowId, FlowNode, FlowNodeKind, FlowTree, - LuaClosureId, LuaDeclId, + AnalyzeError, DbIndex, DeclMultiReturnRefAt, FileId, FlowAntecedent, FlowId, FlowNode, + FlowNodeKind, FlowTree, LuaClosureId, LuaDeclId, }; #[derive(Debug)] @@ -15,6 +15,7 @@ pub struct FlowBinder<'a> { pub db: &'a mut DbIndex, pub file_id: FileId, pub decl_bind_expr_ref: HashMap>, + pub decl_multi_return_ref: HashMap>, pub start: FlowId, pub unreachable: FlowId, pub loop_label: FlowId, @@ -36,6 +37,7 @@ impl<'a> FlowBinder<'a> { flow_nodes: Vec::new(), multiple_antecedents: Vec::new(), decl_bind_expr_ref: HashMap::new(), + decl_multi_return_ref: HashMap::new(), labels: HashMap::new(), start: FlowId::default(), unreachable: FlowId::default(), @@ -189,6 +191,7 @@ impl<'a> FlowBinder<'a> { pub fn finish(self) -> FlowTree { FlowTree::new( self.decl_bind_expr_ref, + self.decl_multi_return_ref, self.flow_nodes, self.multiple_antecedents, // self.labels, diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs index 3f11d714a..96b332018 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs @@ -31,15 +31,7 @@ pub fn analyze_chunk_return(analyzer: &mut LuaAnalyzer, chunk: LuaChunk) -> Opti .db .get_module_index_mut() .get_module_mut(analyzer.file_id)?; - match expr_type { - LuaType::Variadic(multi) => { - let ty = multi.get_type(0)?; - module_info.export_type = Some(ty.clone()); - } - _ => { - module_info.export_type = Some(expr_type); - } - } + module_info.export_type = Some(expr_type.get_result_slot_type(0).unwrap_or(expr_type)); module_info.semantic_id = semantic_id; break; } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 9da10521c..5893099e1 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -47,10 +47,8 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) }; match analyzer.infer_expr(&expr) { - Ok(mut expr_type) => { - if let LuaType::Variadic(multi) = expr_type { - expr_type = multi.get_type(0)?.clone(); - } + Ok(expr_type) => { + let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); let decl_id = LuaDeclId::new(analyzer.file_id, position); // 当`call`参数包含表时, 表可能未被分析, 需要延迟 if let LuaType::Instance(instance) = &expr_type @@ -106,12 +104,12 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) if let Some(last_expr) = last_expr { match analyzer.infer_expr(last_expr) { Ok(last_expr_type) => { - if let LuaType::Variadic(variadic) = last_expr_type { + if last_expr_type.contain_multi_return() { for i in expr_count..name_count { let name = name_list.get(i)?; let position = name.get_position(); let decl_id = LuaDeclId::new(analyzer.file_id, position); - let ret_type = variadic.get_type(i - expr_count + 1); + let ret_type = last_expr_type.get_result_slot_type(i - expr_count + 1); if let Some(ret_type) = ret_type { bind_type( analyzer.db, @@ -311,10 +309,7 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta } let expr_type = match analyzer.infer_expr(expr) { - Ok(expr_type) => match expr_type { - LuaType::Variadic(multi) => multi.get_type(0)?.clone(), - _ => expr_type, - }, + Ok(expr_type) => expr_type.get_result_slot_type(0).unwrap_or(expr_type), Err(InferFailReason::None) => LuaType::Unknown, Err(reason) => { match type_owner { @@ -367,7 +362,7 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta { match analyzer.infer_expr(last_expr) { Ok(last_expr_type) => { - if last_expr_type.is_multi_return() { + if last_expr_type.contain_multi_return() { for i in expr_count..var_count { let var = var_list.get(i)?; let type_owner = get_var_owner(analyzer, var.clone()); @@ -408,10 +403,7 @@ fn assign_merge_type_owner_and_expr_type( expr_type: &LuaType, idx: usize, ) -> Option<()> { - let mut expr_type = expr_type.clone(); - if let LuaType::Variadic(multi) = expr_type { - expr_type = multi.get_type(idx).unwrap_or(&LuaType::Nil).clone(); - } + let expr_type = expr_type.get_result_slot_type(idx).unwrap_or(LuaType::Nil); bind_type(analyzer.db, type_owner, LuaTypeCache::InferType(expr_type)); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs index 76601f498..321a20617 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs @@ -31,13 +31,9 @@ pub fn try_resolve_decl( let expr = decl.expr.clone(); let expr_type = infer_expr(db, cache, expr)?; let decl_id = decl.decl_id; - let expr_type = match &expr_type { - LuaType::Variadic(multi) => multi - .get_type(decl.ret_idx) - .cloned() - .unwrap_or(LuaType::Unknown), - _ => expr_type, - }; + let expr_type = expr_type + .get_result_slot_type(decl.ret_idx) + .unwrap_or(LuaType::Unknown); bind_type(db, decl_id.into(), LuaTypeCache::InferType(expr_type)); Ok(()) @@ -76,13 +72,9 @@ pub fn try_resolve_member( if let Some(expr) = unresolve_member.expr.clone() { let expr_type = infer_expr(db, cache, expr)?; - let expr_type = match &expr_type { - LuaType::Variadic(multi) => multi - .get_type(unresolve_member.ret_idx) - .cloned() - .unwrap_or(LuaType::Unknown), - _ => expr_type, - }; + let expr_type = expr_type + .get_result_slot_type(unresolve_member.ret_idx) + .unwrap_or(LuaType::Unknown); let member_id = unresolve_member.member_id; bind_type(db, member_id.into(), LuaTypeCache::InferType(expr_type)); @@ -174,10 +166,7 @@ pub fn try_resolve_module( ) -> ResolveResult { let expr = module.expr.clone(); let expr_type = infer_expr(db, cache, expr)?; - let expr_type = match &expr_type { - LuaType::Variadic(multi) => multi.get_type(0).cloned().unwrap_or(LuaType::Unknown), - _ => expr_type, - }; + let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); let module_info = db .get_module_index_mut() .get_module_mut(module.file_id) diff --git a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs new file mode 100644 index 000000000..5771d9442 --- /dev/null +++ b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs @@ -0,0 +1,88 @@ +#[cfg(test)] +mod test { + use crate::VirtualWorkspace; + + #[test] + fn test_higher_order_generic_return_infer() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) + return true, f(...) + end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = wrap(wrap, produce) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("status"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer")); + } + + #[test] + fn test_higher_order_return_infer_keeps_concrete_callable_result() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) + return true, f(...) + end + + ---@param x integer + ---@return integer + local function take_int(x) + return x + end + + ---@class Box + ---@field value integer + local box + + ok, payload = wrap(take_int, box.missing) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer")); + } + + #[test] + fn test_higher_order_return_infer_uses_callable_constraint() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R + ---@param ... T... + ---@return R + local function call_once(f, ...) + return f(...) + end + + ---@generic U: string + ---@param n integer + ---@return U + local function constrained_return(n) + end + + result = call_once(constrained_return, 1) + "#, + ); + + assert_eq!(ws.expr_ty("result"), ws.ty("string")); + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/test/mod.rs b/crates/emmylua_code_analysis/src/compilation/test/mod.rs index e4d506739..799d245c5 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/mod.rs @@ -2,6 +2,7 @@ mod and_or_test; mod annotation_test; mod array_test; mod attribute_test; +mod callable_return_infer_test; mod closure_generic; mod closure_param_infer_test; mod closure_return_test; @@ -23,6 +24,8 @@ mod out_of_order; mod overload_field; mod overload_test; mod pcall_test; +mod return_overload_flow_test; +mod return_overload_generic_test; mod return_unwrap_test; mod static_cal_cmp; mod syntax_error_test; diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 74d0e485a..5cbb83f2b 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -45,4 +45,100 @@ mod test { "# )); } + + #[test] + fn test_nested_pcall_higher_order_return_shape() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@return integer + local function f() + return 1 + end + + ok, status, payload = pcall(pcall, f) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("true|false|string")); + assert_eq!(ws.expr_ty("payload"), ws.ty("string|integer|nil")); + } + + #[test] + fn test_pcall_return_overload_narrow_after_error_guard() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@return integer + local function foo() + return 2 + end + + local ok, result = pcall(foo) + + if not ok then + error(result) + end + + a = result + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer")); + } + + #[test] + fn test_nested_pcall_like_without_return_overload() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function safe_call(f, ...) + return true, f(...) + end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = safe_call(safe_call, produce) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer")); + } + + #[test] + fn test_nested_pcall_like_without_return_overload2() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, R, R1 + ---@param f sync fun(...: T...): R1, R... + ---@param ... T... + ---@return boolean, R1|string, R... + local function pcall_like(f, ...) end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = pcall_like(pcall_like, produce) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs new file mode 100644 index 000000000..f0a74b505 --- /dev/null +++ b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs @@ -0,0 +1,544 @@ +#[cfg(test)] +mod test { + use crate::{DiagnosticCode, VirtualWorkspace}; + + #[test] + fn test_return_overload_narrow_after_not() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + if not ok then + error(result) + end + + a = result + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_reassign_clears_multi_return_mapping() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local random ---@type boolean + local ok, result = pick(cond, 1, "error") + result = random and 1 or "override" + + if not ok then + error(result) + end + + f = result + "#, + ); + + assert_eq!(ws.expr_ty("f"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_narrow_with_swapped_operand_eq() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return "ok"|"err" + ---@return T|E + ---@return_overload "ok", T + ---@return_overload "err", E + local function pick(ok, success, failure) + if ok then + return "ok", success + end + return "err", failure + end + + local cond ---@type boolean + local tag, result = pick(cond, 1, "error") + + if "err" == tag then + error(result) + end + + d = result + "#, + ); + + assert_eq!(ws.expr_ty("d"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_narrow_with_type_guard_broad_discriminant() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return string|integer + ---@return string|boolean + ---@return_overload string, string + ---@return_overload integer, boolean + local function pick(ok) + if ok then + return "ok", "value" + end + return 1, false + end + + local cond ---@type boolean + local tag, result = pick(cond) + + if type(tag) == "string" then + string_branch = result + else + integer_branch = result + end + "#, + ); + + assert_eq!(ws.expr_ty("string_branch"), ws.ty("string")); + assert_eq!(ws.expr_ty("integer_branch"), ws.ty("boolean")); + } + + #[test] + fn test_return_overload_narrow_with_swapped_type_guard_alias() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return string|integer + ---@return string|boolean + ---@return_overload string, string + ---@return_overload integer, boolean + local function pick(ok) + if ok then + return "ok", "value" + end + return 1, false + end + + local cond ---@type boolean + local tag, result = pick(cond) + local kind = type(tag) + + if "string" == kind then + string_branch = result + else + integer_branch = result + end + "#, + ); + + assert_eq!(ws.expr_ty("string_branch"), ws.ty("string")); + assert_eq!(ws.expr_ty("integer_branch"), ws.ty("boolean")); + } + + #[test] + fn test_return_overload_narrow_with_type_guard_number_matches_integer_row() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return integer|string + ---@return integer|boolean + ---@return_overload integer, boolean + ---@return_overload string, integer + local function pick(ok) + if ok then + return 1, false + end + return "err", 2 + end + + local cond ---@type boolean + local tag, result = pick(cond) + + if type(tag) == "number" then + number_branch = result + else + string_branch = result + end + "#, + ); + + assert_eq!(ws.expr_ty("number_branch"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("string_branch"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_narrow_with_mixed_rhs_calls() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local left_ok, right_ok, right_result = pick(cond, "left-ok", "left-err"), pick(cond, 1, "right-err") + + if not left_ok then + error("left failed") + end + a = right_result + + if not right_ok then + error(right_result) + end + b = right_result + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + assert_eq!(ws.expr_ty("b"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_late_discriminant_rebind_does_not_affect_prior_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + if not ok then + error(result) + end + + a = result + ok = cond + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_branch_reassign_should_not_override_join_mapping() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + + local ok, result = pick(cond, 1, "left-err") + if branch then + ok, result = pick(cond, "branch-ok", false) + end + + if not ok then + error(result) + end + + a = result + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_join_with_noncorrelated_origin_keeps_extra_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return_overload true, integer + ---@return_overload false, string + local function pick(ok) + if ok then + return true, 1 + end + return false, "err" + end + + ---@return false + local function as_false() + return false + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond) + + if branch then + ok, result = true, as_false() + end + + at_join = result + + if not ok then + in_error_path = result + error(result) + end + + after_guard = result + "#, + ); + + let in_error_path_ty = ws.expr_ty("in_error_path"); + assert!(ws.humanize_type(in_error_path_ty).contains("string")); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("false|integer")); + } + + #[test] + fn test_return_overload_branch_noncall_reassign_keeps_noncorrelated_origin() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond, 1, "err") + + if branch then + result = false + end + + if not ok then + error(result) + end + + after_guard = result + "#, + ); + + let after_guard_ty = ws.expr_ty("after_guard"); + let after_guard = ws.humanize_type(after_guard_ty); + assert!(after_guard.contains("false")); + assert!(after_guard.contains("integer")); + assert!(!after_guard.contains("string")); + } + + #[test] + fn test_return_overload_direct_discriminant_rebind_after_join_breaks_correlation() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond, 1, "err") + + if branch then + local noop = 1 + end + + ok = true + + if not ok then + error(result) + end + + after_guard = result + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("integer|string")); + } + + #[test] + fn test_swapped_literal_eq_narrow_without_return_overload() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@return "x" + local function test() + local a ---@type "x"|nil + if "x" == a then + return a + end + return "x" + end + "#, + )); + } + + #[test] + fn test_var_eq_var_narrow_right_operand_without_return_overload() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@return "x" + local function test() + local a ---@type "x" + local b ---@type "x"|nil + if a == b then + return b + end + return "x" + end + "#, + )); + } + + #[test] + fn test_return_overload_nested_clear_keeps_noncorrelated_origin() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + local inner ---@type boolean + local ok, result = pick(cond, 1, "err") + + if branch then + if inner then + result = false + end + end + + if not ok then + error(result) + end + + after_guard = result + "#, + ); + + let after_guard_ty = ws.expr_ty("after_guard"); + let after_guard = ws.humanize_type(after_guard_ty); + assert!(after_guard.contains("false")); + assert!(after_guard.contains("integer")); + assert!(!after_guard.contains("string")); + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/test/return_overload_generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/return_overload_generic_test.rs new file mode 100644 index 000000000..2b508128f --- /dev/null +++ b/crates/emmylua_code_analysis/src/compilation/test/return_overload_generic_test.rs @@ -0,0 +1,213 @@ +#[cfg(test)] +mod test { + use crate::VirtualWorkspace; + + #[test] + fn test_higher_order_generic_return_infer_with_return_overload() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return_overload true, R... + ---@return_overload false, string + local function wrap(f, ...) + return true, f(...) + end + + ---@return integer + local function produce() + return 1 + end + + ok, status, payload = wrap(wrap, produce) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("status"), ws.ty("false|string|true")); + assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string|nil")); + } + + #[test] + fn test_return_overload_variadic_tail_keeps_deep_slots() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return_overload true, R... + ---@return_overload false, string + local function wrap(f, ...) + return true, f(...) + end + + ---@param n integer + ---@return integer, string, boolean + local function produce(n) + return n, tostring(n), n > 0 + end + + ok, first, second, third = wrap(produce, 1) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("first"), ws.ty("integer|string")); + assert_eq!(ws.expr_ty("second"), ws.ty("string|nil")); + assert_eq!(ws.expr_ty("third"), ws.ty("boolean|nil")); + } + + #[test] + fn test_return_overload_variadic_tpl_tail_pads_missing_slots_with_nil() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return_overload true, R... + ---@return_overload false, string + local function wrap(f, ...) + return true, f(...) + end + + ---@param n integer + ---@return integer, string, boolean + local function produce(n) + return n, tostring(n), n > 0 + end + + ok, first, second, third = wrap(produce, 1) + "#, + ); + + assert_eq!(ws.expr_ty("third"), ws.ty("boolean|nil")); + } + + #[test] + fn test_return_overload_short_row_keeps_nil_in_missing_slots() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@param ok boolean + ---@return_overload true, integer + ---@return_overload false + local function maybe(ok) + if ok then + return true, 1 + end + return false + end + + local cond ---@type boolean + status, value = maybe(cond) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("value"), ws.ty("integer|nil")); + } + + #[test] + fn test_return_overload_concrete_variadic_tail_keeps_unbounded_slots() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@param ok boolean + ---@return_overload true, integer... + ---@return_overload false, string + local function wrap(ok) + if ok then + return true, 1, 2, 3, 4 + end + return false, "err" + end + + local cond ---@type boolean + status, first, second, third, fourth = wrap(cond) + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("false|true")); + assert_eq!(ws.expr_ty("first"), ws.ty("integer|string")); + assert_eq!(ws.expr_ty("second"), ws.ty("integer|nil")); + assert_eq!(ws.expr_ty("third"), ws.ty("integer|nil")); + assert_eq!(ws.expr_ty("fourth"), ws.ty("integer|nil")); + } + + #[test] + fn test_return_overload_partial_rows_preserve_return_docs_fallback() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@param ok boolean + ---@return boolean + ---@return integer|string + ---@return_overload true, integer + local function partial(ok) + if ok then + return true, 1 + end + return false, "err" + end + + local cond ---@type boolean + status, value = partial(cond) + "#, + ); + + let status_ty = ws.expr_ty("status"); + let boolean_ty = ws.ty("boolean"); + assert!(ws.check_type(&status_ty, &boolean_ty)); + assert!(!status_ty.is_always_truthy()); + assert!(!status_ty.is_always_falsy()); + + let value_ty = ws.expr_ty("value"); + let integer_or_string_ty = ws.ty("integer|string"); + let integer_ty = ws.ty("integer"); + let string_ty = ws.ty("string"); + assert!(ws.check_type(&value_ty, &integer_or_string_ty)); + assert!(ws.check_type(&value_ty, &integer_ty)); + assert!(ws.check_type(&value_ty, &string_ty)); + } + + #[test] + fn test_return_overload_docs_merge_mixed_variadic_shapes() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@return string... + ---@return_overload true, integer + local function maybe() + end + + status, value, extra = maybe() + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("string|true")); + assert_eq!(ws.expr_ty("value"), ws.ty("integer|string")); + assert_eq!(ws.expr_ty("extra"), ws.ty("nil|string")); + } + + #[test] + fn test_return_overload_fallback_docs_keep_trailing_nil_slots() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@return boolean + ---@return_overload true, integer + local function maybe_value() + end + + status, value = maybe_value() + "#, + ); + + assert_eq!(ws.expr_ty("status"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("value"), ws.ty("integer|nil")); + } +} diff --git a/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs b/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs index 0e006f038..0658c6473 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/flow_tree.rs @@ -1,12 +1,14 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; -use emmylua_parser::{LuaAstPtr, LuaExpr, LuaSyntaxId}; +use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaExpr, LuaSyntaxId}; +use rowan::TextSize; -use crate::{FlowId, FlowNode, LuaDeclId}; +use crate::{FlowAntecedent, FlowId, FlowNode, LuaDeclId}; #[derive(Debug)] pub struct FlowTree { decl_bind_expr_ref: HashMap>, + decl_multi_return_ref: HashMap>, flow_nodes: Vec, multiple_antecedents: Vec>, // labels: HashMap>, @@ -16,6 +18,7 @@ pub struct FlowTree { impl FlowTree { pub fn new( decl_bind_expr_ref: HashMap>, + decl_multi_return_ref: HashMap>, flow_nodes: Vec, multiple_antecedents: Vec>, // labels: HashMap>, @@ -23,6 +26,7 @@ impl FlowTree { ) -> Self { Self { decl_bind_expr_ref, + decl_multi_return_ref, flow_nodes, multiple_antecedents, bindings, @@ -46,4 +50,210 @@ impl FlowTree { pub fn get_decl_ref_expr(&self, decl_id: &LuaDeclId) -> Option> { self.decl_bind_expr_ref.get(decl_id).cloned() } + + pub fn has_decl_multi_return_refs(&self, decl_id: &LuaDeclId) -> bool { + self.decl_multi_return_ref.contains_key(decl_id) + } + + /// Chooses the search roots used to resolve correlated multi-return refs. + /// + /// If either declaration already has a multi-return ref reachable on the current + /// straight-line history, the caller can analyze the current flow directly and we + /// return `current_flow_id` as the only search root. + /// + /// Otherwise the current flow sits after a branch merge, so we walk backward to the + /// nearest multi-antecedent join and return each incoming branch flow separately. + /// This lets downstream correlation logic analyze branch-local histories without + /// mixing refs from different branches together. + pub fn get_decl_multi_return_search_roots( + &self, + discriminant_decl_id: &LuaDeclId, + target_decl_id: &LuaDeclId, + position: TextSize, + current_flow_id: FlowId, + ) -> Vec { + if self.has_decl_multi_return_ref_on_linear_history( + discriminant_decl_id, + position, + current_flow_id, + ) || self.has_decl_multi_return_ref_on_linear_history( + target_decl_id, + position, + current_flow_id, + ) { + vec![current_flow_id] + } else { + self.get_nearest_branch_antecedents(current_flow_id) + } + } + + pub fn get_decl_multi_return_ref_summary_at( + &self, + decl_id: &LuaDeclId, + position: TextSize, + flow_id: FlowId, + ) -> (Vec, bool) { + let mut refs = Vec::new(); + let mut has_non_reference_origin = false; + let mut visited = HashSet::new(); + self.collect_decl_multi_return_refs_at( + decl_id, + position, + flow_id, + &mut visited, + &mut refs, + &mut has_non_reference_origin, + ); + (refs, has_non_reference_origin) + } + + fn collect_decl_multi_return_refs_at( + &self, + decl_id: &LuaDeclId, + position: TextSize, + flow_id: FlowId, + visited: &mut HashSet, + refs: &mut Vec, + has_non_reference_origin: &mut bool, + ) { + if !visited.insert(flow_id) { + return; + } + + if let Some(at) = self.get_decl_multi_return_ref_on_flow(decl_id, position, flow_id) { + if let Some(reference) = &at.reference { + refs.push(reference.clone()); + } else { + *has_non_reference_origin = true; + } + return; + } + + let Some(flow_node) = self.get_flow_node(flow_id) else { + *has_non_reference_origin = true; + return; + }; + let Some(antecedent) = flow_node.antecedent.as_ref() else { + *has_non_reference_origin = true; + return; + }; + match antecedent { + FlowAntecedent::Single(next_flow_id) => { + self.collect_decl_multi_return_refs_at( + decl_id, + position, + *next_flow_id, + visited, + refs, + has_non_reference_origin, + ); + } + FlowAntecedent::Multiple(multi_id) => { + if let Some(multi_antecedents) = self.get_multi_antecedents(*multi_id) { + for &next_flow_id in multi_antecedents { + self.collect_decl_multi_return_refs_at( + decl_id, + position, + next_flow_id, + visited, + refs, + has_non_reference_origin, + ); + } + } else { + *has_non_reference_origin = true; + } + } + } + } + + fn get_decl_multi_return_ref_on_flow( + &self, + decl_id: &LuaDeclId, + position: TextSize, + flow_id: FlowId, + ) -> Option<&DeclMultiReturnRefAt> { + self.decl_multi_return_ref + .get(decl_id)? + .iter() + .rev() + .find(|entry| entry.position <= position && entry.flow_id == flow_id) + } + + /// Returns whether `decl_id` has a recorded multi-return ref on the linear backward history. + /// + /// "Linear history" means repeatedly following only `FlowAntecedent::Single` links from + /// `start_flow_id`. The search stops as soon as it reaches a merge (`Multiple`) or the start + /// of flow. In other words, this checks only the current straight-line history and does + /// not inspect alternate branch predecessors. + fn has_decl_multi_return_ref_on_linear_history( + &self, + decl_id: &LuaDeclId, + position: TextSize, + start_flow_id: FlowId, + ) -> bool { + let mut current_flow_id = start_flow_id; + let mut visited = HashSet::new(); + loop { + if !visited.insert(current_flow_id) { + return false; + } + + if self + .get_decl_multi_return_ref_on_flow(decl_id, position, current_flow_id) + .is_some() + { + return true; + } + + let Some(flow_node) = self.get_flow_node(current_flow_id) else { + return false; + }; + match flow_node.antecedent.as_ref() { + Some(FlowAntecedent::Single(next_flow_id)) => { + current_flow_id = *next_flow_id; + } + Some(FlowAntecedent::Multiple(_)) | None => return false, + } + } + } + + fn get_nearest_branch_antecedents(&self, start_flow_id: FlowId) -> Vec { + let mut current_flow_id = start_flow_id; + let mut visited = HashSet::new(); + loop { + if !visited.insert(current_flow_id) { + return vec![start_flow_id]; + } + + let Some(flow_node) = self.get_flow_node(current_flow_id) else { + return vec![start_flow_id]; + }; + match flow_node.antecedent.as_ref() { + Some(FlowAntecedent::Multiple(multi_id)) => { + return self + .get_multi_antecedents(*multi_id) + .map(|flows| flows.to_vec()) + .unwrap_or_else(|| vec![start_flow_id]); + } + Some(FlowAntecedent::Single(next_flow_id)) => { + current_flow_id = *next_flow_id; + } + None => return vec![start_flow_id], + } + } + } +} + +#[derive(Debug, Clone)] +pub struct DeclMultiReturnRef { + pub call_expr: LuaAstPtr, + pub return_index: usize, +} + +#[derive(Debug, Clone)] +pub struct DeclMultiReturnRefAt { + pub position: TextSize, + pub flow_id: FlowId, + pub reference: Option, } diff --git a/crates/emmylua_code_analysis/src/db_index/flow/mod.rs b/crates/emmylua_code_analysis/src/db_index/flow/mod.rs index 95ef92cd4..2e54f0c81 100644 --- a/crates/emmylua_code_analysis/src/db_index/flow/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/flow/mod.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use crate::{FileId, LuaSignatureId}; use emmylua_parser::{LuaAstPtr, LuaDocOpType}; pub use flow_node::*; -pub use flow_tree::FlowTree; +pub use flow_tree::{DeclMultiReturnRef, DeclMultiReturnRefAt, FlowTree}; pub use signature_cast::LuaSignatureCast; use super::traits::LuaIndex; diff --git a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs index 7c496430b..3c7fd68be 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs @@ -1,4 +1,5 @@ mod async_state; +mod return_rows; #[allow(clippy::module_inception)] mod signature; @@ -6,8 +7,8 @@ use std::collections::{HashMap, HashSet}; pub use async_state::AsyncState; pub use signature::{ - LuaDocParamInfo, LuaDocReturnInfo, LuaGenericParamInfo, LuaNoDiscard, LuaSignature, - LuaSignatureId, SignatureReturnStatus, + LuaDocParamInfo, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaGenericParamInfo, LuaNoDiscard, + LuaSignature, LuaSignatureId, SignatureReturnStatus, }; use crate::FileId; diff --git a/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs b/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs new file mode 100644 index 000000000..18e2217d6 --- /dev/null +++ b/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs @@ -0,0 +1,244 @@ +use std::sync::Arc; + +use crate::{ + LuaAliasCallKind, LuaAliasCallType, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaType, + VariadicType, db_index::union_type_shallow, +}; + +pub(super) fn get_return_type( + return_docs: &[LuaDocReturnInfo], + return_overloads: &[LuaDocReturnOverloadInfo], +) -> LuaType { + let return_docs_type = row_to_return_type( + return_docs + .iter() + .map(|info| info.type_ref.clone()) + .collect(), + ); + if return_overloads.is_empty() { + return return_docs_type; + } + + let overload_return_type = rows_to_return_type( + &return_overloads + .iter() + .map(|overload| overload.type_refs.as_slice()) + .collect::>(), + ); + if return_docs.is_empty() { + overload_return_type + } else { + merge_return_type(overload_return_type, return_docs_type) + } +} + +pub(crate) fn get_overload_row_slot(row: &[LuaType], idx: usize) -> LuaType { + get_overload_row_slot_if_present(row, idx).unwrap_or(LuaType::Nil) +} + +pub(crate) fn row_to_return_type(mut row: Vec) -> LuaType { + match row.len() { + 0 => LuaType::Nil, + 1 => row.pop().unwrap_or(LuaType::Nil), + _ => LuaType::Variadic(VariadicType::Multi(row).into()), + } +} + +pub(crate) fn return_type_to_row(return_type: LuaType) -> Vec { + match return_type { + LuaType::Variadic(variadic) => match variadic.as_ref() { + VariadicType::Multi(types) => types.clone(), + VariadicType::Base(_) => vec![LuaType::Variadic(variadic)], + }, + typ => vec![typ], + } +} + +fn rows_to_return_type(rows: &[&[LuaType]]) -> LuaType { + let Some(base_max_len) = rows.iter().map(|row| row.len()).max() else { + return LuaType::Nil; + }; + if base_max_len == 0 { + return LuaType::Nil; + } + + let (has_variadic_tail, has_unbounded_variadic_tail, has_tpl_unbounded_variadic_tail) = + rows.iter().fold( + (false, false, false), + |(has_var, has_unbounded, has_tpl_unbounded), row| { + let Some(last) = row.last() else { + return (has_var, has_unbounded, has_tpl_unbounded); + }; + let LuaType::Variadic(variadic) = last else { + return (has_var, has_unbounded, has_tpl_unbounded); + }; + + let has_unbounded_row = variadic.get_max_len().is_none(); + ( + true, + has_unbounded || has_unbounded_row, + has_tpl_unbounded || (has_unbounded_row && variadic.contain_tpl()), + ) + }, + ); + let max_len = if has_variadic_tail { + base_max_len + 1 + } else { + base_max_len + }; + let fill_missing_with_nil = |idx: usize| idx < base_max_len || has_unbounded_variadic_tail; + + let mut types = Vec::with_capacity(max_len); + for idx in 0..max_len { + let slot_types = rows + .iter() + .filter_map(|row| { + get_overload_row_slot_if_present(row, idx) + .or(fill_missing_with_nil(idx).then_some(LuaType::Nil)) + }) + .collect(); + types.push(LuaType::from_vec(slot_types)); + } + if has_unbounded_variadic_tail + && !has_tpl_unbounded_variadic_tail + && let Some(last) = types.last_mut() + && !matches!(last, LuaType::Variadic(_)) + { + *last = LuaType::Variadic(VariadicType::Base(last.clone()).into()); + } + + row_to_return_type(types) +} + +fn merge_return_rows(left_row: &[LuaType], right_row: &[LuaType]) -> LuaType { + let base_max_len = left_row.len().max(right_row.len()); + let (has_variadic_tail, has_unbounded_variadic_tail, has_tpl_unbounded_variadic_tail) = + [left_row, right_row].iter().fold( + (false, false, false), + |(has_var, has_unbounded, has_tpl_unbounded), row| { + let Some(last) = row.last() else { + return (has_var, has_unbounded, has_tpl_unbounded); + }; + let LuaType::Variadic(variadic) = last else { + return (has_var, has_unbounded, has_tpl_unbounded); + }; + + let has_unbounded_row = variadic.get_max_len().is_none(); + ( + true, + has_unbounded || has_unbounded_row, + has_tpl_unbounded || (has_unbounded_row && variadic.contain_tpl()), + ) + }, + ); + let max_len = if has_variadic_tail { + base_max_len + 1 + } else { + base_max_len + }; + let fill_missing_with_nil = |idx: usize| idx < base_max_len || has_unbounded_variadic_tail; + + let mut types = Vec::with_capacity(max_len); + for idx in 0..max_len { + let left_type = get_overload_row_slot_if_present(left_row, idx) + .or(fill_missing_with_nil(idx).then_some(LuaType::Nil)); + let right_type = get_overload_row_slot_if_present(right_row, idx) + .or(fill_missing_with_nil(idx).then_some(LuaType::Nil)); + + let merged_type = match (left_type, right_type) { + (Some(left), Some(right)) => union_type_shallow(left, right), + (Some(left), None) | (None, Some(left)) => left, + (None, None) => continue, + }; + types.push(merged_type); + } + if has_unbounded_variadic_tail + && !has_tpl_unbounded_variadic_tail + && let Some(last) = types.last_mut() + && !matches!(last, LuaType::Variadic(_)) + { + *last = LuaType::Variadic(VariadicType::Base(last.clone()).into()); + } + + row_to_return_type(types) +} + +fn merge_return_type(left: LuaType, right: LuaType) -> LuaType { + if left == LuaType::Unknown { + return right; + } + if right == LuaType::Unknown { + return left; + } + + match (&left, &right) { + (LuaType::Variadic(_), _) | (_, LuaType::Variadic(_)) => { + let left_row = return_type_to_row(left); + let right_row = return_type_to_row(right); + merge_return_rows(&left_row, &right_row) + } + _ => union_type_shallow(left, right), + } +} + +fn overload_row_tpl_slot( + call_kind: LuaAliasCallKind, + variadic: &Arc, + index: i64, +) -> LuaType { + LuaType::Call( + LuaAliasCallType::new( + call_kind, + vec![ + LuaType::Variadic(variadic.clone()), + LuaType::IntegerConst(index), + ], + ) + .into(), + ) +} + +fn get_overload_row_slot_if_present(row: &[LuaType], idx: usize) -> Option { + let row_len = row.len(); + if row_len == 0 { + return None; + } + + if idx + 1 < row_len { + return Some(row[idx].clone()); + } + + let last_idx = row_len - 1; + let last_ty = &row[last_idx]; + let offset = idx - last_idx; + if let LuaType::Variadic(variadic) = last_ty { + if let Some(slot) = variadic.get_type(offset).cloned() { + if slot.contain_tpl() { + if offset > 0 && matches!(variadic.as_ref(), VariadicType::Base(_)) { + return Some(overload_row_tpl_slot( + LuaAliasCallKind::Select, + variadic, + (offset + 1) as i64, + )); + } + + return Some(overload_row_tpl_slot( + LuaAliasCallKind::Index, + variadic, + offset as i64, + )); + } + return Some(slot); + } + + Some(overload_row_tpl_slot( + LuaAliasCallKind::Select, + variadic, + (offset + 1) as i64, + )) + } else if offset == 0 { + Some(last_ty.clone()) + } else { + None + } +} diff --git a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs index 36aa471f2..e673caa9f 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs @@ -6,12 +6,13 @@ use std::{collections::HashMap, sync::Arc}; use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaDocFuncType}; use rowan::TextSize; +use super::return_rows; use crate::db_index::signature::async_state::AsyncState; use crate::{ FileId, db_index::{LuaFunctionType, LuaType}, }; -use crate::{LuaAttributeUse, SemanticModel, VariadicType, first_param_may_not_self}; +use crate::{LuaAttributeUse, SemanticModel, first_param_may_not_self}; #[derive(Debug)] pub struct LuaSignature { @@ -20,6 +21,7 @@ pub struct LuaSignature { pub param_docs: HashMap, pub params: Vec, pub return_docs: Vec, + pub return_overloads: Vec, pub resolve_return: SignatureReturnStatus, pub is_colon_define: bool, pub async_state: AsyncState, @@ -47,6 +49,7 @@ impl LuaSignature { param_docs: HashMap::new(), params: Vec::new(), return_docs: Vec::new(), + return_overloads: Vec::new(), resolve_return: SignatureReturnStatus::UnResolve, is_colon_define: false, async_state: AsyncState::None, @@ -111,19 +114,19 @@ impl LuaSignature { } pub fn get_return_type(&self) -> LuaType { - match self.return_docs.len() { - 0 => LuaType::Nil, - 1 => self.return_docs[0].type_ref.clone(), - _ => LuaType::Variadic( - VariadicType::Multi( - self.return_docs - .iter() - .map(|info| info.type_ref.clone()) - .collect(), - ) - .into(), - ), - } + return_rows::get_return_type(&self.return_docs, &self.return_overloads) + } + + pub(crate) fn get_overload_row_slot(row: &[LuaType], idx: usize) -> LuaType { + return_rows::get_overload_row_slot(row, idx) + } + + pub(crate) fn row_to_return_type(row: Vec) -> LuaType { + return_rows::row_to_return_type(row) + } + + pub(crate) fn return_type_to_row(return_type: LuaType) -> Vec { + return_rows::return_type_to_row(return_type) } pub fn is_method(&self, semantic_model: &SemanticModel, owner_type: Option<&LuaType>) -> bool { @@ -209,6 +212,12 @@ pub struct LuaDocReturnInfo { pub attributes: Option>, } +#[derive(Debug, Clone)] +pub struct LuaDocReturnOverloadInfo { + pub type_refs: Vec, + pub description: Option, +} + #[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)] pub struct LuaSignatureId { file_id: FileId, diff --git a/crates/emmylua_code_analysis/src/db_index/type/mod.rs b/crates/emmylua_code_analysis/src/db_index/type/mod.rs index cf088895e..84ee0507b 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/mod.rs @@ -14,6 +14,7 @@ pub use humanize_type::{RenderLevel, TypeHumanizer, format_union_type, humanize_ use std::collections::{HashMap, HashSet}; pub use type_decl::{LuaDeclLocation, LuaDeclTypeKind, LuaTypeDecl, LuaTypeDeclId, LuaTypeFlag}; pub use type_ops::TypeOps; +pub(crate) use type_ops::union_type_shallow; pub use type_owner::{LuaTypeCache, LuaTypeOwner}; pub use type_visit_trait::TypeVisitTrait; pub use types::*; diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_ops/mod.rs b/crates/emmylua_code_analysis/src/db_index/type/type_ops/mod.rs index 812c9e96f..955f4e8a4 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_ops/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_ops/mod.rs @@ -5,6 +5,7 @@ mod union_type; use super::LuaType; use crate::DbIndex; +pub(crate) use union_type::union_type_shallow; #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum TypeOps { diff --git a/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs b/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs index ca370d94e..b1a2f6458 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/type_ops/union_type.rs @@ -3,9 +3,19 @@ use std::ops::Deref; use crate::{DbIndex, LuaType, LuaUnionType, get_real_type}; pub fn union_type(db: &DbIndex, source: LuaType, target: LuaType) -> LuaType { - let real_type = get_real_type(db, &source).unwrap_or(&source); + let match_source = get_real_type(db, &source) + .cloned() + .unwrap_or_else(|| source.clone()); + union_type_impl(&match_source, source, target) +} + +pub(crate) fn union_type_shallow(source: LuaType, target: LuaType) -> LuaType { + let match_source = source.clone(); + union_type_impl(&match_source, source, target) +} - match (&real_type, &target) { +fn union_type_impl(match_source: &LuaType, source: LuaType, target: LuaType) -> LuaType { + match (match_source, &target) { // ANY | T = ANY (LuaType::Any, _) => LuaType::Any, (_, LuaType::Any) => LuaType::Any, @@ -27,11 +37,14 @@ pub fn union_type(db: &DbIndex, source: LuaType, target: LuaType) -> LuaType { (LuaType::String, LuaType::StringConst(_) | LuaType::DocStringConst(_)) => LuaType::String, (LuaType::StringConst(_) | LuaType::DocStringConst(_), LuaType::String) => LuaType::String, // boolean | boolean const - (LuaType::Boolean, LuaType::BooleanConst(_)) => LuaType::Boolean, - (LuaType::BooleanConst(_), LuaType::Boolean) => LuaType::Boolean, - (LuaType::BooleanConst(left), LuaType::BooleanConst(right)) => { + (LuaType::Boolean, right) if right.is_boolean() => LuaType::Boolean, + (left, LuaType::Boolean) if left.is_boolean() => LuaType::Boolean, + ( + LuaType::BooleanConst(left) | LuaType::DocBooleanConst(left), + LuaType::BooleanConst(right) | LuaType::DocBooleanConst(right), + ) => { if left == right { - LuaType::BooleanConst(*left) + source.clone() } else { LuaType::Boolean } @@ -103,7 +116,7 @@ pub fn union_type(db: &DbIndex, source: LuaType, target: LuaType) -> LuaType { } // same type - (left, right) if *left == right => source.clone(), + (left, right) if *left == *right => source.clone(), _ => LuaType::from_vec(vec![source, target]), } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types.rs b/crates/emmylua_code_analysis/src/db_index/type/types.rs index 43af62efa..7487ea7d8 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types.rs @@ -421,6 +421,51 @@ impl LuaType { matches!(self, LuaType::Variadic(_)) } + pub fn contain_multi_return(&self) -> bool { + match self { + LuaType::Variadic(_) => true, + LuaType::Union(union) => union.into_vec().iter().any(LuaType::contain_multi_return), + _ => false, + } + } + + pub fn get_result_slot_type(&self, idx: usize) -> Option { + match self { + LuaType::Variadic(variadic) => match variadic.as_ref() { + VariadicType::Base(base) => Some(base.clone()), + VariadicType::Multi(types) => { + let last_idx = types.len().checked_sub(1)?; + if idx < last_idx { + return types[idx].get_result_slot_type(0); + } + + let last_type = types.get(last_idx)?; + let offset = idx - last_idx; + last_type.get_result_slot_type(offset) + } + }, + LuaType::Union(union) => { + let slot_types = union + .into_vec() + .into_iter() + .map(|ty| ty.get_result_slot_type(idx)) + .collect::>(); + if !slot_types.iter().any(|ty| ty.is_some()) { + return None; + } + + Some(LuaType::from_vec( + slot_types + .into_iter() + .map(|ty| ty.unwrap_or(LuaType::Nil)) + .collect(), + )) + } + _ if idx == 0 => Some(self.clone()), + _ => None, + } + } + pub fn is_global(&self) -> bool { matches!(self, LuaType::Global) } @@ -508,6 +553,25 @@ impl LuaType { } } +#[cfg(test)] +mod tests { + use super::{LuaType, VariadicType}; + + #[test] + fn test_union_with_variadic_uses_result_slot_extraction() { + let variadic = LuaType::Variadic(VariadicType::Multi(vec![LuaType::String]).into()); + let optional_variadic = LuaType::from_vec(vec![variadic.clone(), LuaType::Nil]); + + assert_eq!(variadic.get_result_slot_type(0), Some(LuaType::String)); + assert!(!optional_variadic.is_multi_return()); + assert!(optional_variadic.contain_multi_return()); + assert_eq!( + optional_variadic.get_result_slot_type(0), + Some(LuaType::from_vec(vec![LuaType::String, LuaType::Nil])) + ); + } +} + impl TypeVisitTrait for LuaType { fn visit_type(&self, f: &mut F) where diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs index 53b038b08..d20a1d248 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs @@ -1,8 +1,13 @@ use std::collections::HashSet; -use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaDocTagParam, LuaDocTagReturn, LuaStat}; +use emmylua_parser::{ + LuaAstNode, LuaClosureExpr, LuaDocTagParam, LuaDocTagReturn, LuaDocTagReturnOverload, LuaStat, +}; -use crate::{DiagnosticCode, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel}; +use crate::{ + DiagnosticCode, LuaSemanticDeclId, LuaSignatureId, LuaType, SemanticDeclLevel, SemanticModel, + SignatureReturnStatus, +}; use super::{Checker, DiagnosticContext, get_closure_expr_comment, get_return_stats}; @@ -81,10 +86,20 @@ fn check_doc( }) .collect(); - let doc_return_len: usize = comment - .children::() - .map(|return_doc| return_doc.get_types().count()) - .sum(); + let doc_return_len = + get_doc_return_max_len(semantic_model, closure_expr).unwrap_or_else(|| { + let doc_return_len: usize = comment + .children::() + .map(|return_doc| return_doc.get_types().count()) + .sum(); + let doc_return_overload_max_len = comment + .children::() + .map(|return_doc| return_doc.get_types().count()) + .max() + .unwrap_or(0); + + Some(doc_return_len.max(doc_return_overload_max_len)) + }); check_params( context, @@ -149,7 +164,7 @@ fn check_returns( context: &mut DiagnosticContext, semantic_model: &SemanticModel, closure_expr: &LuaClosureExpr, - doc_return_len: usize, + doc_return_len: Option, code: DiagnosticCode, is_global: bool, function_name: &str, @@ -169,7 +184,9 @@ fn check_returns( return_stat_len += expr_return_count; - if return_stat_len > doc_return_len { + if let Some(doc_return_len) = doc_return_len + && return_stat_len > doc_return_len + { let message = if is_global { t!( "Missing @return annotation at index `%{index}` in global function `%{function_name}`.", @@ -190,3 +207,25 @@ fn check_returns( Some(()) } + +fn get_doc_return_max_len( + semantic_model: &SemanticModel, + closure_expr: &LuaClosureExpr, +) -> Option> { + let signature_id = LuaSignatureId::from_closure(semantic_model.get_file_id(), closure_expr); + let signature = semantic_model + .get_db() + .get_signature_index() + .get(&signature_id)?; + if signature.resolve_return != SignatureReturnStatus::DocResolve { + return None; + } + let return_type = signature.get_return_type(); + + Some(match return_type { + LuaType::Variadic(variadic) => variadic.get_max_len(), + LuaType::Any | LuaType::Unknown => Some(1), + LuaType::Nil => Some(0), + _ => Some(1), + }) +} diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs index b640c2b8c..c9b8388a8 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs @@ -94,6 +94,51 @@ mod tests { )); } + #[test] + fn test_return_overload() { + let mut ws = VirtualWorkspace::new(); + ws.enable_full_diagnostic(); + + assert!(ws.check_code_for( + DiagnosticCode::IncompleteSignatureDoc, + r#" + ---@return_overload true, integer + ---@return_overload false, string + local function f() + return true, 1 + end + "# + )); + + assert!(!ws.check_code_for( + DiagnosticCode::IncompleteSignatureDoc, + r#" + ---@return_overload true, integer + ---@return_overload false, string + local function f() + return true, 1, "extra" + end + "# + )); + } + + #[test] + fn test_variadic_return_overload_does_not_trigger_incomplete_signature_doc() { + let mut ws = VirtualWorkspace::new(); + ws.enable_full_diagnostic(); + + assert!(ws.check_code_for( + DiagnosticCode::IncompleteSignatureDoc, + r#" + ---@return_overload true, integer... + ---@return_overload false, string + local function f() + return true, 1, 2, 3, 4 + end + "# + )); + } + #[test] fn test_global() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 5c2c2889c..d0a6b24c9 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -4,6 +4,7 @@ use emmylua_parser::{LuaAstNode, LuaDocTypeList}; use emmylua_parser::{LuaCallExpr, LuaExpr}; use internment::ArcIntern; +use crate::semantic::infer::infer_expr_list_types; use crate::{ DocTypeInferContext, FileId, GenericTpl, GenericTplId, LuaFunctionType, LuaGenericType, TypeVisitTrait, @@ -15,15 +16,18 @@ use crate::{ instantiate_type::instantiate_doc_function, tpl_context::TplContext, tpl_pattern::{ - multi_param_tpl_pattern_match_multi_return, tpl_pattern_match, - variadic_tpl_pattern_match, + multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, + tpl_pattern_match, variadic_tpl_pattern_match, }, }, infer::InferFailReason, infer_expr, }, }; -use crate::{LuaMemberOwner, LuaSemanticDeclId, SemanticDeclLevel, infer_node_semantic_decl}; +use crate::{ + LuaMemberOwner, LuaSemanticDeclId, SemanticDeclLevel, infer_node_semantic_decl, + tpl_pattern_match_args, +}; use super::TypeSubstitutor; @@ -115,6 +119,99 @@ fn apply_call_generic_type_list( } } +pub fn as_doc_function_type( + db: &DbIndex, + callable_type: &LuaType, +) -> Result>, InferFailReason> { + Ok(match callable_type { + LuaType::DocFunction(doc_func) => Some(doc_func.clone()), + LuaType::Signature(sig_id) => Some( + db.get_signature_index() + .get(sig_id) + .ok_or(InferFailReason::None)? + .to_doc_func_type(), + ), + _ => None, + }) +} + +fn infer_return_from_callable( + db: &DbIndex, + callable: &Arc, + substitutor: &TypeSubstitutor, +) -> LuaType { + let instantiated = instantiate_doc_function(db, callable, substitutor); + match instantiated { + LuaType::DocFunction(func) => func.get_ret().clone(), + _ => callable.get_ret().clone(), + } +} + +pub fn infer_callable_return_from_remaining_args( + context: &mut TplContext, + callable_type: &LuaType, + arg_exprs: &[LuaExpr], +) -> Result, InferFailReason> { + if arg_exprs.is_empty() { + return Ok(None); + } + + let Some(callable) = as_doc_function_type(context.db, callable_type)? else { + return Ok(None); + }; + + let mut callable_tpls = HashSet::new(); + callable.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + }); + if callable_tpls.is_empty() { + return Ok(Some(callable.get_ret().clone())); + } + + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); + let fallback_return = infer_return_from_callable(context.db, &callable, &callable_substitutor); + + let call_arg_types = + match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { + Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), + Err(_) => return Ok(Some(fallback_return)), + }; + if call_arg_types.is_empty() { + return Ok(None); + } + + let callable_param_types = callable + .get_params() + .iter() + .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) + .collect::>(); + + let mut callable_context = TplContext { + db: context.db, + cache: context.cache, + substitutor: &mut callable_substitutor, + call_expr: context.call_expr.clone(), + }; + if tpl_pattern_match_args( + &mut callable_context, + &callable_param_types, + &call_arg_types, + ) + .is_err() + { + return Ok(Some(fallback_return)); + } + + Ok(Some(infer_return_from_callable( + context.db, + &callable, + &callable_substitutor, + ))) +} + fn infer_generic_types_from_call( db: &DbIndex, context: &mut TplContext, @@ -166,6 +263,15 @@ fn infer_generic_types_from_call( Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断 Err(e) => return Err(e), }; + + if let Some(return_pattern) = + as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) + && let Some(inferred_return_type) = + infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? + { + return_type_pattern_match_target_type(context, &return_pattern, &inferred_return_type)?; + } + match (func_param_type, &arg_type) { (LuaType::Variadic(variadic), _) => { let mut arg_types = vec![]; @@ -268,18 +374,8 @@ fn check_expr_can_later_infer( func_param_type: &LuaType, call_arg_expr: &LuaExpr, ) -> Result { - let doc_function = match func_param_type { - LuaType::DocFunction(doc_func) => doc_func.clone(), - LuaType::Signature(sig_id) => { - let sig = context - .db - .get_signature_index() - .get(sig_id) - .ok_or(InferFailReason::None)?; - - sig.to_doc_func_type() - } - _ => return Ok(false), + let Some(doc_function) = as_doc_function_type(context.db, func_param_type)? else { + return Ok(false); }; if let LuaExpr::ClosureExpr(_) = call_arg_expr { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 754974211..88e695057 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -688,7 +688,7 @@ fn param_type_list_pattern_match_type_list( Ok(()) } -fn return_type_pattern_match_target_type( +pub(crate) fn return_type_pattern_match_target_type( context: &mut TplContext, source: &LuaType, target: &LuaType, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index 00c4996e4..3fd266d29 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -641,15 +641,12 @@ pub(crate) fn unwrapp_return_type( return Ok(return_type); } - LuaType::Variadic(variadic) => { + ty if ty.contain_multi_return() => { if is_last_call_expr(&call_expr) { - return Ok(return_type); + return Ok(ty.clone()); } - return match variadic.get_type(0) { - Some(ty) => Ok(ty.clone()), - None => Ok(LuaType::Nil), - }; + return Ok(ty.get_result_slot_type(0).unwrap_or(LuaType::Nil)); } LuaType::SelfInfer => { if let Some(self_type) = infer_self_type(db, cache, &call_expr) { @@ -815,4 +812,33 @@ mod tests { assert!(!matches!(second, Err(InferFailReason::RecursiveInfer))); } + + #[test] + fn test_higher_order_call_with_unresolved_remaining_arg_should_not_hard_fail() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) end + + ---@generic U: string + ---@param x U + ---@return U + local function id(x) end + + ---@class Box + ---@field value integer + ---@type Box + local box + + ok, payload = wrap(id, box.missing) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("string")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index 06354b8b5..e30b8477c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -221,27 +221,31 @@ where } let expr_type = infer(db, cache, expr.clone())?; + if let Some(var_count) = var_count + && expr_type.contain_multi_return() + { + if idx < var_count { + for i in idx..var_count { + if let Some(typ) = expr_type.get_result_slot_type(i - idx) { + value_types.push((typ, expr.get_range())); + } else { + break; + } + } + } + + break; + } + match expr_type { LuaType::Variadic(variadic) => { - if let Some(var_count) = var_count { - if idx < var_count { - for i in idx..var_count { - if let Some(typ) = variadic.get_type(i - idx) { - value_types.push((typ.clone(), expr.get_range())); - } else { - break; - } - } + match variadic.deref() { + VariadicType::Base(base) => { + value_types.push((base.clone(), expr.get_range())); } - } else { - match variadic.deref() { - VariadicType::Base(base) => { - value_types.push((base.clone(), expr.get_range())); - } - VariadicType::Multi(vecs) => { - for typ in vecs { - value_types.push((typ.clone(), expr.get_range())); - } + VariadicType::Multi(vecs) => { + for typ in vecs { + value_types.push((typ.clone(), expr.get_range())); } } } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs index 1653f5112..f8052cadf 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs @@ -1,5 +1,5 @@ use emmylua_parser::{ - BinaryOperator, LuaAstNode, LuaBinaryExpr, LuaCallExpr, LuaChunk, LuaExpr, LuaIndexMemberExpr, + BinaryOperator, LuaAstNode, LuaBinaryExpr, LuaChunk, LuaExpr, LuaIndexMemberExpr, LuaLiteralToken, UnaryOperator, }; @@ -11,7 +11,10 @@ use crate::{ infer_index::infer_member_by_member_key, narrow::{ ResultTypeOrContinue, - condition_flow::{InferConditionFlow, call_flow::get_type_at_call_expr}, + condition_flow::{ + InferConditionFlow, always_literal_equal, call_flow::get_type_at_call_expr, + correlated_flow::narrow_var_from_return_overload_condition, + }, get_single_antecedent, get_type_at_flow::get_type_at_flow, get_var_ref_type, narrow_down_type, @@ -102,7 +105,7 @@ fn try_get_at_eq_or_neq_expr( right_expr: LuaExpr, condition_flow: InferConditionFlow, ) -> Result { - let mut result_type = maybe_type_guard_binary( + if let ResultTypeOrContinue::Result(result_type) = maybe_type_guard_binary( db, tree, cache, @@ -112,12 +115,11 @@ fn try_get_at_eq_or_neq_expr( left_expr.clone(), right_expr.clone(), condition_flow, - )?; - if let ResultTypeOrContinue::Result(result_type) = result_type { + )? { return Ok(ResultTypeOrContinue::Result(result_type)); } - result_type = maybe_field_literal_eq_narrow( + if let ResultTypeOrContinue::Result(result_type) = maybe_field_literal_eq_narrow( db, tree, cache, @@ -127,12 +129,22 @@ fn try_get_at_eq_or_neq_expr( left_expr.clone(), right_expr.clone(), condition_flow, - )?; - - if let ResultTypeOrContinue::Result(result_type) = result_type { + )? { return Ok(ResultTypeOrContinue::Result(result_type)); } + let (left_expr, right_expr) = if !matches!( + left_expr, + LuaExpr::NameExpr(_) | LuaExpr::CallExpr(_) | LuaExpr::IndexExpr(_) | LuaExpr::UnaryExpr(_) + ) && matches!( + right_expr, + LuaExpr::NameExpr(_) | LuaExpr::CallExpr(_) | LuaExpr::IndexExpr(_) | LuaExpr::UnaryExpr(_) + ) { + (right_expr, left_expr) + } else { + (left_expr, right_expr) + }; + maybe_var_eq_narrow( db, tree, @@ -224,78 +236,47 @@ fn maybe_type_guard_binary( right_expr: LuaExpr, condition_flow: InferConditionFlow, ) -> Result { - let mut type_guard_expr: Option = None; - let mut literal_string = String::new(); - if let LuaExpr::CallExpr(call_expr) = left_expr { - if call_expr.is_type() { - type_guard_expr = Some(call_expr); - if let LuaExpr::LiteralExpr(literal_expr) = right_expr { - match literal_expr.get_literal() { - Some(LuaLiteralToken::String(s)) => { - literal_string = s.get_value(); - } - _ => return Ok(ResultTypeOrContinue::Continue), - } - } - } - } else if let LuaExpr::CallExpr(call_expr) = right_expr { - if call_expr.is_type() { - type_guard_expr = Some(call_expr); - if let LuaExpr::LiteralExpr(literal_expr) = left_expr { - match literal_expr.get_literal() { - Some(LuaLiteralToken::String(s)) => { - literal_string = s.get_value(); - } - _ => return Ok(ResultTypeOrContinue::Continue), - } - } + let (candidate_expr, literal_expr) = match (left_expr, right_expr) { + // If either side is a literal expression and the other side is a type guard call expression + // (or ref), we can narrow it + (candidate_expr, LuaExpr::LiteralExpr(literal_expr)) + | (LuaExpr::LiteralExpr(literal_expr), candidate_expr) => { + (Some(candidate_expr), Some(literal_expr)) } + _ => (None, None), + }; + + let (Some(candidate_expr), Some(LuaLiteralToken::String(literal_string))) = + (candidate_expr, literal_expr.and_then(|e| e.get_literal())) + else { + return Ok(ResultTypeOrContinue::Continue); + }; + + let candidate_expr = match candidate_expr { // may ref a type value - } else if let LuaExpr::NameExpr(name_expr) = left_expr - && let LuaExpr::LiteralExpr(literal_expr) = right_expr - { - let Some(decl_id) = db + LuaExpr::NameExpr(name_expr) => db .get_reference_index() .get_var_reference_decl(&cache.get_file_id(), name_expr.get_range()) - else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let Some(expr) = expr_ptr.to_node(root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - - if let LuaExpr::CallExpr(call_expr) = expr { - if call_expr.is_type() { - type_guard_expr = Some(call_expr); - match literal_expr.get_literal() { - Some(LuaLiteralToken::String(s)) => { - literal_string = s.get_value(); - } - _ => return Ok(ResultTypeOrContinue::Continue), - } - } - } else { - return Ok(ResultTypeOrContinue::Continue); - } - } + .and_then(|decl_id| tree.get_decl_ref_expr(&decl_id)) + .and_then(|expr_ptr| expr_ptr.to_node(root)), + expr => Some(expr), + }; - let Some(type_guard_expr) = type_guard_expr else { + let Some(type_guard_expr) = candidate_expr.and_then(|expr| match expr { + LuaExpr::CallExpr(call_expr) if call_expr.is_type() => Some(call_expr), + _ => None, + }) else { return Ok(ResultTypeOrContinue::Continue); }; - if literal_string.is_empty() { - return Ok(ResultTypeOrContinue::Continue); - } - let Some(arg_list) = type_guard_expr.get_args_list() else { + let Some(narrow) = type_call_name_to_type(&literal_string.get_value()) else { return Ok(ResultTypeOrContinue::Continue); }; - let Some(arg) = arg_list.get_args().next() else { + let Some(arg) = type_guard_expr + .get_args_list() + .and_then(|arg_list| arg_list.get_args().next()) + else { return Ok(ResultTypeOrContinue::Continue); }; @@ -304,14 +285,39 @@ fn maybe_type_guard_binary( return Ok(ResultTypeOrContinue::Continue); }; - if maybe_var_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); - } + let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_type = + get_type_at_flow(db, tree, cache, root, &maybe_var_ref_id, antecedent_flow_id)?; + let narrowed_discriminant_type = match condition_flow { + InferConditionFlow::TrueCondition => { + narrow_down_type(db, antecedent_type, narrow.clone(), None).unwrap_or(narrow) + } + InferConditionFlow::FalseCondition => TypeOps::Remove.apply(db, &antecedent_type, &narrow), + }; - let anatecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, anatecedent_flow_id)?; + if maybe_var_ref_id == *var_ref_id { + Ok(ResultTypeOrContinue::Result(narrowed_discriminant_type)) + } else { + let Some(discriminant_decl_id) = maybe_var_ref_id.get_decl_id_ref() else { + return Ok(ResultTypeOrContinue::Continue); + }; + narrow_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + discriminant_decl_id, + type_guard_expr.get_position(), + &narrowed_discriminant_type, + ) + } +} - let narrow = match literal_string.as_str() { +/// Maps the string result of Lua's builtin `type()` call to the corresponding `LuaType`. +fn type_call_name_to_type(literal_string: &str) -> Option { + Some(match literal_string { "number" => LuaType::Number, "string" => LuaType::String, "boolean" => LuaType::Boolean, @@ -320,20 +326,30 @@ fn maybe_type_guard_binary( "thread" => LuaType::Thread, "userdata" => LuaType::Userdata, "nil" => LuaType::Nil, - _ => { - // If the type is not recognized, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); - } - }; + _ => return None, + }) +} - let result_type = match condition_flow { +fn narrow_eq_condition( + db: &DbIndex, + antecedent_type: LuaType, + right_expr_type: LuaType, + condition_flow: InferConditionFlow, +) -> LuaType { + match condition_flow { InferConditionFlow::TrueCondition => { - narrow_down_type(db, antecedent_type.clone(), narrow.clone(), None).unwrap_or(narrow) - } - InferConditionFlow::FalseCondition => TypeOps::Remove.apply(db, &antecedent_type, &narrow), - }; + let left_maybe_type = TypeOps::Intersect.apply(db, &antecedent_type, &right_expr_type); - Ok(ResultTypeOrContinue::Result(result_type)) + if left_maybe_type.is_never() { + antecedent_type + } else { + left_maybe_type + } + } + InferConditionFlow::FalseCondition => { + TypeOps::Remove.apply(db, &antecedent_type, &right_expr_type) + } + } } #[allow(clippy::too_many_arguments)] @@ -357,15 +373,32 @@ fn maybe_var_eq_narrow( return Ok(ResultTypeOrContinue::Continue); }; + let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let right_expr_type = infer_expr(db, cache, right_expr)?; + if maybe_ref_id != *var_ref_id { - // If the reference declaration ID does not match, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + let Some(discriminant_decl_id) = maybe_ref_id.get_decl_id_ref() else { + return Ok(ResultTypeOrContinue::Continue); + }; + let antecedent_type = + get_type_at_flow(db, tree, cache, root, &maybe_ref_id, antecedent_flow_id)?; + let narrowed_discriminant_type = + narrow_eq_condition(db, antecedent_type, right_expr_type, condition_flow); + return narrow_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + discriminant_decl_id, + left_name_expr.get_position(), + &narrowed_discriminant_type, + ); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; let left_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - let right_expr_type = infer_expr(db, cache, right_expr)?; let result_type = match condition_flow { InferConditionFlow::TrueCondition => { @@ -373,14 +406,7 @@ fn maybe_var_eq_narrow( if var_ref_id.is_self_ref() && !right_expr_type.is_nil() { TypeOps::Remove.apply(db, &right_expr_type, &LuaType::Nil) } else { - let left_maybe_type = - TypeOps::Intersect.apply(db, &left_type, &right_expr_type); - - if left_maybe_type.is_never() { - left_type - } else { - left_maybe_type - } + narrow_eq_condition(db, left_type, right_expr_type, condition_flow) } } InferConditionFlow::FalseCondition => { @@ -564,7 +590,7 @@ fn maybe_field_literal_eq_narrow( Ok(member_type) => member_type, Err(_) => continue, // If we cannot infer the member type, skip this type }; - if const_type_eq(&member_type, &right_type) { + if always_literal_equal(&member_type, &right_type) { // If the right type matches the member type, we can narrow it opt_result = Some(i); } @@ -586,23 +612,3 @@ fn maybe_field_literal_eq_narrow( Ok(ResultTypeOrContinue::Continue) } - -fn const_type_eq(left_type: &LuaType, right_type: &LuaType) -> bool { - if left_type == right_type { - return true; - } - - match (left_type, right_type) { - ( - LuaType::StringConst(l) | LuaType::DocStringConst(l), - LuaType::StringConst(r) | LuaType::DocStringConst(r), - ) => l == r, - (LuaType::FloatConst(l), LuaType::FloatConst(r)) => l == r, - (LuaType::BooleanConst(l), LuaType::BooleanConst(r)) => l == r, - ( - LuaType::IntegerConst(l) | LuaType::DocIntegerConst(l), - LuaType::IntegerConst(r) | LuaType::DocIntegerConst(r), - ) => l == r, - _ => false, - } -} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs new file mode 100644 index 000000000..8d0ddaac1 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -0,0 +1,313 @@ +use std::collections::HashSet; + +use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaChunk}; + +use crate::{ + DbIndex, FlowId, FlowNode, FlowTree, InferFailReason, LuaDeclId, LuaFunctionType, + LuaInferCache, LuaType, TypeOps, infer_expr, instantiate_func_generic, + semantic::infer::{ + VarRefId, + narrow::{ResultTypeOrContinue, get_single_antecedent, get_type_at_flow::get_type_at_flow}, + }, +}; + +#[allow(clippy::too_many_arguments)] +pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return_overload_condition( + db: &DbIndex, + tree: &FlowTree, + cache: &mut LuaInferCache, + root: &LuaChunk, + var_ref_id: &VarRefId, + flow_node: &FlowNode, + discriminant_decl_id: LuaDeclId, + condition_position: rowan::TextSize, + narrowed_discriminant_type: &LuaType, +) -> Result { + let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { + return Ok(ResultTypeOrContinue::Continue); + }; + if !tree.has_decl_multi_return_refs(&discriminant_decl_id) + || !tree.has_decl_multi_return_refs(&target_decl_id) + { + return Ok(ResultTypeOrContinue::Continue); + } + + let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let search_root_flow_ids = tree.get_decl_multi_return_search_roots( + &discriminant_decl_id, + &target_decl_id, + condition_position, + antecedent_flow_id, + ); + let mut matching_target_types = Vec::new(); + let mut uncorrelated_target_types = Vec::new(); + for search_root_flow_id in search_root_flow_ids { + let (root_matching_target_types, root_uncorrelated_target_type) = + collect_correlated_types_from_search_root( + db, + tree, + cache, + root, + var_ref_id, + discriminant_decl_id, + target_decl_id, + condition_position, + search_root_flow_id, + narrowed_discriminant_type, + )?; + matching_target_types.extend(root_matching_target_types); + if let Some(root_uncorrelated_target_type) = root_uncorrelated_target_type { + uncorrelated_target_types.push(root_uncorrelated_target_type); + } + } + + if matching_target_types.is_empty() { + return Ok(ResultTypeOrContinue::Continue); + } + + let matching_target_type = LuaType::from_vec(matching_target_types); + let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; + let narrowed_correlated_type = + TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type); + if narrowed_correlated_type.is_never() { + return Ok(ResultTypeOrContinue::Continue); + } + + if uncorrelated_target_types.is_empty() { + return Ok(if narrowed_correlated_type == antecedent_type { + ResultTypeOrContinue::Continue + } else { + ResultTypeOrContinue::Result(narrowed_correlated_type) + }); + } + + let uncorrelated_target_type = LuaType::from_vec(uncorrelated_target_types); + let merged_type = if uncorrelated_target_type.is_never() { + narrowed_correlated_type + } else { + LuaType::from_vec(vec![narrowed_correlated_type, uncorrelated_target_type]) + }; + + Ok(if merged_type == antecedent_type { + ResultTypeOrContinue::Continue + } else { + ResultTypeOrContinue::Result(merged_type) + }) +} + +#[allow(clippy::too_many_arguments)] +fn collect_correlated_types_from_search_root( + db: &DbIndex, + tree: &FlowTree, + cache: &mut LuaInferCache, + root: &LuaChunk, + var_ref_id: &VarRefId, + discriminant_decl_id: LuaDeclId, + target_decl_id: LuaDeclId, + condition_position: rowan::TextSize, + search_root_flow_id: FlowId, + narrowed_discriminant_type: &LuaType, +) -> Result<(Vec, Option), InferFailReason> { + let (discriminant_refs, discriminant_has_non_reference_origin) = tree + .get_decl_multi_return_ref_summary_at( + &discriminant_decl_id, + condition_position, + search_root_flow_id, + ); + let (target_refs, target_has_non_reference_origin) = tree.get_decl_multi_return_ref_summary_at( + &target_decl_id, + condition_position, + search_root_flow_id, + ); + if discriminant_refs.is_empty() || target_refs.is_empty() { + return Ok(( + Vec::new(), + get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id).ok(), + )); + } + + let ( + root_matching_target_types, + root_correlated_candidate_types, + has_unmatched_correlated_origin, + ) = collect_matching_correlated_types( + db, + cache, + root, + &discriminant_refs, + &target_refs, + narrowed_discriminant_type, + )?; + if root_matching_target_types.is_empty() { + return Ok(( + Vec::new(), + get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id).ok(), + )); + } + + let root_uncorrelated_target_type = if discriminant_has_non_reference_origin + || target_has_non_reference_origin + || has_unmatched_correlated_origin + { + get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) + .ok() + .and_then(|root_type| { + subtract_correlated_candidate_types(db, root_type, &root_correlated_candidate_types) + }) + } else { + None + }; + + Ok((root_matching_target_types, root_uncorrelated_target_type)) +} + +fn subtract_correlated_candidate_types( + db: &DbIndex, + source_type: LuaType, + correlated_candidate_types: &[LuaType], +) -> Option { + let remaining_types = match source_type { + LuaType::Union(union) => union + .into_vec() + .into_iter() + .filter(|member| { + !correlated_candidate_types.iter().any(|correlated_type| { + TypeOps::Union.apply(db, correlated_type, member) == *correlated_type + }) + }) + .collect::>(), + source_type => (!correlated_candidate_types.iter().any(|correlated_type| { + TypeOps::Union.apply(db, correlated_type, &source_type) == *correlated_type + })) + .then_some(source_type) + .into_iter() + .collect(), + }; + + (!remaining_types.is_empty()).then_some(LuaType::from_vec(remaining_types)) +} + +#[allow(clippy::too_many_arguments)] +fn collect_matching_correlated_types( + db: &DbIndex, + cache: &mut LuaInferCache, + root: &LuaChunk, + discriminant_refs: &[crate::DeclMultiReturnRef], + target_refs: &[crate::DeclMultiReturnRef], + narrowed_discriminant_type: &LuaType, +) -> Result<(Vec, Vec, bool), InferFailReason> { + let mut matching_target_types = Vec::new(); + let mut correlated_candidate_types = Vec::new(); + let mut correlated_discriminant_call_expr_ids = HashSet::new(); + let mut correlated_target_call_expr_ids = HashSet::new(); + + for discriminant_ref in discriminant_refs { + let Some((call_expr, signature)) = + infer_signature_for_call_ptr(db, cache, root, &discriminant_ref.call_expr)? + else { + continue; + }; + if signature.return_overloads.is_empty() { + continue; + } + + let overload_rows = instantiate_return_overload_rows(db, cache, call_expr, signature); + let discriminant_call_expr_id = discriminant_ref.call_expr.get_syntax_id(); + + for target_ref in target_refs { + if target_ref.call_expr.get_syntax_id() != discriminant_call_expr_id { + continue; + } + correlated_discriminant_call_expr_ids.insert(discriminant_call_expr_id); + correlated_target_call_expr_ids.insert(target_ref.call_expr.get_syntax_id()); + correlated_candidate_types.extend(overload_rows.iter().map(|overload| { + crate::LuaSignature::get_overload_row_slot(overload, target_ref.return_index) + })); + matching_target_types.extend(overload_rows.iter().filter_map(|overload| { + let discriminant_type = crate::LuaSignature::get_overload_row_slot( + overload, + discriminant_ref.return_index, + ); + if !TypeOps::Intersect + .apply(db, &discriminant_type, narrowed_discriminant_type) + .is_never() + { + return Some(crate::LuaSignature::get_overload_row_slot( + overload, + target_ref.return_index, + )); + } + + None + })); + } + } + + let has_unmatched_correlated_origin = discriminant_refs.iter().any(|discriminant_ref| { + !correlated_discriminant_call_expr_ids.contains(&discriminant_ref.call_expr.get_syntax_id()) + }) || target_refs.iter().any(|target_ref| { + !correlated_target_call_expr_ids.contains(&target_ref.call_expr.get_syntax_id()) + }); + Ok(( + matching_target_types, + correlated_candidate_types, + has_unmatched_correlated_origin, + )) +} + +fn infer_signature_for_call_ptr<'a>( + db: &'a DbIndex, + cache: &mut LuaInferCache, + root: &LuaChunk, + call_expr_ptr: &LuaAstPtr, +) -> Result, InferFailReason> { + let Some(call_expr) = call_expr_ptr.to_node(root) else { + return Ok(None); + }; + let Some(prefix_expr) = call_expr.get_prefix_expr() else { + return Ok(None); + }; + let signature_id = match infer_expr(db, cache, prefix_expr)? { + LuaType::Signature(signature_id) => signature_id, + _ => return Ok(None), + }; + let Some(signature) = db.get_signature_index().get(&signature_id) else { + return Ok(None); + }; + + Ok(Some((call_expr, signature))) +} + +fn instantiate_return_overload_rows( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: LuaCallExpr, + signature: &crate::LuaSignature, +) -> Vec> { + let mut rows = Vec::with_capacity(signature.return_overloads.len()); + for overload in &signature.return_overloads { + let type_refs = &overload.type_refs; + let overload_return_type = crate::LuaSignature::row_to_return_type(type_refs.to_vec()); + let instantiated_return_type = if overload_return_type.contain_tpl() { + let overload_func = LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + signature.get_type_params(), + overload_return_type.clone(), + ); + match instantiate_func_generic(db, cache, &overload_func, call_expr.clone()) { + Ok(instantiated) => instantiated.get_ret().clone(), + Err(_) => overload_return_type, + } + } else { + overload_return_type + }; + + rows.push(crate::LuaSignature::return_type_to_row( + instantiated_return_type, + )); + } + + rows +} diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index b922d827a..1a28ecb92 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -1,11 +1,13 @@ mod binary_flow; mod call_flow; +mod correlated_flow; mod index_flow; +use self::correlated_flow::narrow_var_from_return_overload_condition; use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, + DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, LuaType, semantic::infer::{ VarRefId, narrow::{ @@ -188,6 +190,33 @@ fn get_type_at_name_ref( else { return Ok(ResultTypeOrContinue::Continue); }; + let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_discriminant_type = get_type_at_flow( + db, + tree, + cache, + root, + &VarRefId::VarRef(decl_id), + antecedent_flow_id, + )?; + let narrowed_discriminant_type = match condition_flow { + InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_discriminant_type), + InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_discriminant_type), + }; + + if let ResultTypeOrContinue::Result(result_type) = narrow_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + decl_id, + name_expr.get_position(), + &narrowed_discriminant_type, + )? { + return Ok(ResultTypeOrContinue::Result(result_type)); + } let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else { return Ok(ResultTypeOrContinue::Continue); @@ -209,6 +238,32 @@ fn get_type_at_name_ref( ) } +pub(super) fn always_literal_equal(left: &LuaType, right: &LuaType) -> bool { + match (left, right) { + (LuaType::Union(union), other) => union + .into_vec() + .into_iter() + .all(|candidate| always_literal_equal(&candidate, other)), + (other, LuaType::Union(union)) => union + .into_vec() + .into_iter() + .all(|candidate| always_literal_equal(other, &candidate)), + ( + LuaType::StringConst(l) | LuaType::DocStringConst(l), + LuaType::StringConst(r) | LuaType::DocStringConst(r), + ) => l == r, + ( + LuaType::BooleanConst(l) | LuaType::DocBooleanConst(l), + LuaType::BooleanConst(r) | LuaType::DocBooleanConst(r), + ) => l == r, + ( + LuaType::IntegerConst(l) | LuaType::DocIntegerConst(l), + LuaType::IntegerConst(r) | LuaType::DocIntegerConst(r), + ) => l == r, + _ => left == right, + } +} + #[allow(clippy::too_many_arguments)] fn get_type_at_unary_flow( db: &DbIndex, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index da37a0207..66ccda945 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -288,10 +288,7 @@ fn try_infer_decl_initializer_type( }; let expr_type = infer_expr(db, cache, expr.clone())?; - let init_type = match expr_type { - LuaType::Variadic(variadic) => variadic.get_type(0).cloned(), - ty => Some(ty), - }; + let init_type = expr_type.get_result_slot_type(0); Ok(init_type) } diff --git a/crates/emmylua_ls/src/handlers/hover/build_hover.rs b/crates/emmylua_ls/src/handlers/hover/build_hover.rs index 2a3c68e61..accfed646 100644 --- a/crates/emmylua_ls/src/handlers/hover/build_hover.rs +++ b/crates/emmylua_ls/src/handlers/hover/build_hover.rs @@ -335,6 +335,15 @@ pub fn add_signature_ret_description( )); } } + for (i, ret_overload) in signature.return_overloads.iter().enumerate() { + if let Some(description) = ret_overload.description.clone() { + s.push_str(&format!( + "@*return_overload* #{} — {}\n\n", + i + 1, + description + )); + } + } if !s.is_empty() { marked_strings.push(MarkedString::from_markdown(s)); } diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index 79871abc3..2402e078e 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -1,9 +1,10 @@ use std::{collections::HashSet, sync::Arc, vec}; use emmylua_code_analysis::{ - AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaFunctionType, LuaMember, LuaMemberOwner, - LuaSemanticDeclId, LuaType, RenderLevel, TypeSubstitutor, VariadicType, humanize_type, - infer_call_expr_func, instantiate_doc_function, try_extract_signature_id_from_field, + AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, + LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, + TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, instantiate_doc_function, + instantiate_func_generic, try_extract_signature_id_from_field, }; use crate::handlers::hover::{ @@ -73,15 +74,17 @@ fn build_function_call_hover( .ok()?; // 根据推断出来的类型确定哪个 semantic_decl 是匹配的 - let mut match_semantic_decl = &semantic_decls.last()?.0; - for (semantic_decl, typ) in semantic_decls.iter() { + let mut matched_decl = semantic_decls.last()?; + for semantic_decl in semantic_decls.iter() { + let (_, typ) = semantic_decl; if let LuaType::DocFunction(f) = typ { if f == &final_type { - match_semantic_decl = semantic_decl; + matched_decl = semantic_decl; break; } } } + let (match_semantic_decl, match_typ) = matched_decl; let function_member = match match_semantic_decl { LuaSemanticDeclId::Member(id) => { @@ -92,15 +95,65 @@ fn build_function_call_hover( }; let is_field = function_member_is_field(db, semantic_decls); - let contents = process_function_type( - builder, - db, - &LuaType::DocFunction(final_type), - function_member, - function_name, - is_local, - is_field, - )?; + let contents = if let LuaType::Signature(signature_id) = match_typ { + let signature = db.get_signature_index().get(signature_id)?; + let base_function = LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + signature.get_type_params(), + signature.get_return_type(), + ); + let instantiated_signature = instantiate_func_generic( + db, + &mut builder.semantic_model.get_cache().borrow_mut(), + &base_function, + call_expr.clone(), + ) + .ok()?; + + if !signature.return_overloads.is_empty() + && final_type.get_async_state() == instantiated_signature.get_async_state() + && final_type.is_colon_define() == instantiated_signature.is_colon_define() + && final_type.is_variadic() == instantiated_signature.is_variadic() + && final_type.get_params() == instantiated_signature.get_params() + { + let return_overloads = + instantiate_call_return_overloads(builder, db, call_expr, signature); + let ret_detail = build_function_return_overload_rows(builder, &return_overloads); + vec![hover_doc_function_type( + builder, + db, + final_type.as_ref(), + function_member, + function_name, + is_local, + is_field, + Vec::new(), + Some(ret_detail), + )] + } else { + process_function_type( + builder, + db, + &LuaType::DocFunction(final_type), + function_member, + function_name, + is_local, + is_field, + )? + } + } else { + process_function_type( + builder, + db, + &LuaType::DocFunction(final_type), + function_member, + function_name, + is_local, + is_field, + )? + }; let description = get_function_description(builder, db, &match_semantic_decl); builder.set_type_description(contents.first()?.clone()); builder.add_description_from_info(description); @@ -218,6 +271,7 @@ fn process_function_type( is_local, is_field, convert_function_return_to_docs(lua_func), + None, ); Some(vec![content]) } @@ -231,23 +285,45 @@ fn process_function_type( signature.get_type_params(), signature.get_return_type(), )); - new_overloads.insert(0, fake_doc_function); + new_overloads.insert(0, fake_doc_function.clone()); let mut contents = Vec::with_capacity(new_overloads.len()); for (i, overload) in new_overloads.iter().enumerate() { - contents.push(hover_doc_function_type( - builder, - db, - overload, - function_member, - function_name, - is_local, - is_field, - if i == 0 { - signature.return_docs.clone() - } else { - convert_function_return_to_docs(overload) - }, - )); + let content = if i == 0 && !signature.return_overloads.is_empty() { + let ret_detail = + build_function_return_overload_rows(builder, &signature.return_overloads); + hover_doc_function_type( + builder, + db, + overload, + function_member, + function_name, + is_local, + is_field, + Vec::new(), + Some(ret_detail), + ) + } else { + hover_doc_function_type( + builder, + db, + overload, + function_member, + function_name, + is_local, + is_field, + if i == 0 { + if signature.return_docs.is_empty() { + convert_function_return_to_docs(fake_doc_function.as_ref()) + } else { + signature.return_docs.clone() + } + } else { + convert_function_return_to_docs(overload) + }, + None, + ) + }; + contents.push(content); } Some(contents) } @@ -281,6 +357,7 @@ fn hover_doc_function_type( is_local: bool, is_field: bool, /* 是否为类字段 */ return_docs: Vec, /* 返回值以此为准 */ + ret_detail: Option, ) -> String { let async_label = match func.get_async_state() { AsyncState::Async => "async ", @@ -372,11 +449,60 @@ fn hover_doc_function_type( } }) .filter(|s| !s.is_empty()) - .collect::>() - .join(", "); + .collect::>(); + + let ret_detail = ret_detail.unwrap_or_else(|| build_function_returns(builder, return_docs)); + format_function_type( + type_label, + async_label, + full_name, + params.join(", "), + ret_detail, + ) +} + +fn instantiate_call_return_overloads( + builder: &HoverBuilder, + db: &DbIndex, + call_expr: &emmylua_parser::LuaCallExpr, + signature: &LuaSignature, +) -> Vec { + let mut cache = builder.semantic_model.get_cache().borrow_mut(); - let ret_detail = build_function_returns(builder, return_docs); - format_function_type(type_label, async_label, full_name, params, ret_detail) + signature + .return_overloads + .iter() + .map(|row| { + let row_return_type = match row.type_refs.len() { + 0 => LuaType::Nil, + 1 => row.type_refs[0].clone(), + _ => LuaType::Variadic(VariadicType::Multi(row.type_refs.clone()).into()), + }; + let row_function = LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + signature.get_type_params(), + row_return_type, + ); + let instantiated_row = + instantiate_func_generic(db, &mut cache, &row_function, call_expr.clone()) + .ok() + .map(|func| match func.get_ret() { + LuaType::Variadic(variadic) => match variadic.as_ref() { + VariadicType::Multi(types) => types.clone(), + VariadicType::Base(_) => vec![LuaType::Variadic(variadic.clone())], + }, + typ => vec![typ.clone()], + }) + .unwrap_or_else(|| row.type_refs.clone()); + + LuaDocReturnOverloadInfo { + type_refs: instantiated_row, + description: row.description.clone(), + } + }) + .collect() } fn convert_function_return_to_docs(func: &LuaFunctionType) -> Vec { @@ -466,7 +592,6 @@ fn build_function_returns( let type_text = build_function_return_type(builder, return_info, i); if has_multiline { - // 存在返回值名称时使用多行模式 let prefix = if i == 0 { result.push('\n'); "-> ".to_string() @@ -485,14 +610,39 @@ fn build_function_returns( }, type_text, )); + } else if i == 0 { + result.push_str(&format!(" -> {}", type_text)); } else { - // 不存在返回值名称时使用单行模式 - if i == 0 { - result.push_str(&format!(" -> {}", type_text)); - } else { - result.push_str(&format!(", {}", type_text)); - } + result.push_str(&format!(", {}", type_text)); + } + } + + result +} + +fn build_function_return_overload_rows( + builder: &mut HoverBuilder, + return_overloads: &[LuaDocReturnOverloadInfo], +) -> String { + let mut result = String::new(); + + for (row_idx, row) in return_overloads.iter().enumerate() { + if row.type_refs.is_empty() { + continue; + } + + let row_text = row + .type_refs + .iter() + .enumerate() + .map(|(i, typ)| build_return_type_text(builder, typ, i)) + .collect::>() + .join(", "); + + if row_idx == 0 { + result.push('\n'); } + result.push_str(&format!(" -> {}\n", row_text)); } result @@ -503,9 +653,13 @@ fn build_function_return_type( ret_info: &LuaDocReturnInfo, i: usize, ) -> String { + build_return_type_text(builder, &ret_info.type_ref, i) +} + +fn build_return_type_text(builder: &mut HoverBuilder, typ: &LuaType, i: usize) -> String { let type_expansion_count = builder.get_type_expansion_count(); // 在这个过程中可能会设置`type_expansion` - let type_text = hover_humanize_type(builder, &ret_info.type_ref, Some(RenderLevel::Simple)); + let type_text = hover_humanize_type(builder, typ, Some(RenderLevel::Simple)); if builder.get_type_expansion_count() > type_expansion_count { // 重新设置`type_expansion` if let Some(pop_type_expansion) = diff --git a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs index 264767dee..2529771b1 100644 --- a/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs +++ b/crates/emmylua_ls/src/handlers/semantic_token/build_semantic_tokens.rs @@ -164,11 +164,15 @@ fn build_tokens_semantic_token( } } LuaTokenKind::TkTrue | LuaTokenKind::TkFalse | LuaTokenKind::TkNil => { - builder.push_with_modifier( - token, - SemanticTokenType::KEYWORD, - SemanticTokenModifier::READONLY, - ); + if is_doc_type_literal_token(token) { + builder.push(token, SemanticTokenType::TYPE); + } else { + builder.push_with_modifier( + token, + SemanticTokenType::KEYWORD, + SemanticTokenModifier::READONLY, + ); + } } LuaTokenKind::TkComplex | LuaTokenKind::TkInt | LuaTokenKind::TkFloat => { builder.push(token, SemanticTokenType::NUMBER); @@ -201,6 +205,7 @@ fn build_tokens_semantic_token( | LuaTokenKind::TkTagUsing | LuaTokenKind::TkTagSource | LuaTokenKind::TkTagReturnCast + | LuaTokenKind::TkTagReturnOverload | LuaTokenKind::TkTagExport | LuaTokenKind::TkLanguage | LuaTokenKind::TkTagAttribute @@ -294,6 +299,12 @@ fn build_tokens_semantic_token( } } +fn is_doc_type_literal_token(token: &LuaSyntaxToken) -> bool { + token + .parent() + .is_some_and(|parent| parent.kind() == LuaSyntaxKind::TypeLiteral.into()) +} + fn build_node_semantic_token( semantic_model: &SemanticModel, builder: &mut SemanticBuilder, @@ -376,6 +387,7 @@ fn build_node_semantic_token( } } } + LuaAst::LuaDocTagReturnOverload(_) => {} LuaAst::LuaDocTagCast(doc_cast) => { if let Some(target_expr) = doc_cast.get_key_expr() { match target_expr { @@ -740,15 +752,6 @@ fn build_node_semantic_token( } } } - LuaAst::LuaDocLiteralType(literal) => { - if let LuaLiteralToken::Bool(bool_token) = &literal.get_literal()? { - builder.push_with_modifier( - bool_token.syntax(), - SemanticTokenType::KEYWORD, - SemanticTokenModifier::DOCUMENTATION, - ); - } - } LuaAst::LuaDocDescription(description) => { if !emmyrc.semantic_tokens.render_documentation_markup { for token in description.tokens::() { diff --git a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs index ccf172029..ad7e69eb0 100644 --- a/crates/emmylua_ls/src/handlers/test/hover_function_test.rs +++ b/crates/emmylua_ls/src/handlers/test/hover_function_test.rs @@ -178,6 +178,92 @@ mod tests { Ok(()) } + #[gtest] + fn test_return_overload_hover() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!( + ws.check_hover( + r#" + ---@return_overload true, integer + ---@return_overload false, string + local function parse() + end + + local alias = parse + "#, + VirtualHoverResult { + value: "```lua\nlocal function parse()\n -> true, integer\n -> false, string\n\n```" + .to_string(), + }, + ) + ); + Ok(()) + } + + #[gtest] + fn test_return_overload_description_hover() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@return_overload true, integer success + ---@return_overload false, string failed + local function parse() + end + + local alias = parse + "#, + VirtualHoverResult { + value: "```lua\nlocal function parse()\n -> true, integer\n -> false, string\n\n```\n\n---\n\n@*return_overload* #1 — success\n\n@*return_overload* #2 — failed".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_return_overload_call_hover() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + check!(ws.check_hover( + r#" + ---@class B + local B + + ---@generic T + ---@param x T + ---@return_overload true, T + ---@return_overload false, string + local function parse(x) + end + + parse(B) + "#, + VirtualHoverResult { + value: "```lua\nlocal function parse(x: B)\n -> true, B\n -> false, string\n\n```".to_string(), + }, + )); + Ok(()) + } + + #[gtest] + fn test_pcall_return_overload_hover() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new_with_init_std_lib(); + check!(ws.check_hover( + r#" + --- @param a string + --- @param b table + --- @return_overload false, [string,string] comment + --- @return_overload true, string comment + local function foo(a, b) + end + + local a, b = pcall(foo) + "#, + VirtualHoverResult { + value: "```lua\nfunction pcall(f: sync fun(a: string, b: table) -> ((false|true),((string,string)|string)), a: string, b: table)\n -> true, (false|true), ((string,string)|string)\n -> false, string\n\n```\n\n---\n\n\nCalls function `f` with the given arguments in *protected mode*. This\nmeans that any error inside `f` is not propagated; instead, `pcall` catches\nthe error and returns a status code. Its first result is the status code (a\nboolean), which is true if the call succeeds without errors. In such case,\n`pcall` also returns all results from the call, after this first result. In\ncase of any error, `pcall` returns **false** plus the error message.".to_string(), + }, + )); + Ok(()) + } + #[gtest] fn test_table_field_function_1() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); diff --git a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs index 81e70deb4..0b60a1e7e 100644 --- a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs +++ b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs @@ -90,4 +90,47 @@ m.foo() Ok(()) } + + #[gtest] + fn test_return_overload_tag_is_documentation_keyword() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let data = ws.get_semantic_token_data( + r#"---@return_overload true, integer +"#, + )?; + let tokens = decode(&data); + let keyword = token_type_index(SemanticTokenType::KEYWORD); + let doc = modifier_bitset(&[SemanticTokenModifier::DOCUMENTATION]); + + verify_that!(&tokens, contains(eq(&(0, 4, 15, keyword, doc))))?; + Ok(()) + } + + #[gtest] + fn test_return_overload_rows_highlight_types() -> Result<()> { + let mut ws = ProviderVirtualWorkspace::new(); + let data = ws.get_semantic_token_data(concat!( + "--- @return_overload false, [string,string]\n", + "--- @return_overload true, string\n", + ))?; + let tokens = decode(&data); + let typ = token_type_index(SemanticTokenType::TYPE); + let variable = token_type_index(SemanticTokenType::VARIABLE); + let default_library = modifier_bitset(&[SemanticTokenModifier::DEFAULT_LIBRARY]); + + verify_that!( + &tokens, + all![ + contains(eq(&(0, 21, 5, typ, 0))), + contains(eq(&(0, 29, 6, typ, default_library))), + contains(eq(&(0, 36, 6, typ, default_library))), + contains(eq(&(1, 21, 4, typ, 0))), + contains(eq(&(1, 27, 6, typ, default_library))), + not(contains(eq(&(0, 29, 6, variable, 0)))), + not(contains(eq(&(0, 36, 6, variable, 0)))), + not(contains(eq(&(1, 27, 6, variable, 0)))), + ] + )?; + Ok(()) + } } diff --git a/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs b/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs index cd25ed071..d046257ad 100644 --- a/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs +++ b/crates/emmylua_ls/src/handlers/test/signature_helper_test.rs @@ -36,7 +36,8 @@ mod tests { pcall(readFile, ) "#, VirtualSignatureHelp { - target_label: "pcall(f: sync fun(path: string), path: string): boolean".to_string(), + target_label: + "pcall(f: sync fun(path: string), path: string): (true|false)".to_string(), active_signature: 0, active_parameter: 1, }, diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index 5f1cad9f0..b7749873b 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -39,6 +39,7 @@ fn parse_tag_detail(p: &mut LuaDocParser) -> DocParseResult { LuaTokenKind::TkTagType => parse_tag_type(p), LuaTokenKind::TkTagParam => parse_tag_param(p), LuaTokenKind::TkTagReturn => parse_tag_return(p), + LuaTokenKind::TkTagReturnOverload => parse_tag_return_overload(p), LuaTokenKind::TkTagReturnCast => parse_tag_return_cast(p), // other tag LuaTokenKind::TkTagModule => parse_tag_module(p), @@ -369,6 +370,25 @@ fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult { Ok(m.complete(p)) } +// ---@return_overload true, integer +// ---@return_overload false, string +fn parse_tag_return_overload(p: &mut LuaDocParser) -> DocParseResult { + p.set_lexer_state(LuaDocLexerState::Normal); + let m = p.mark(LuaSyntaxKind::DocTagReturnOverload); + p.bump(); + + parse_type(p)?; + + while p.current_token() == LuaTokenKind::TkComma { + p.bump(); + parse_type(p)?; + } + + p.set_lexer_state(LuaDocLexerState::Description); + parse_description(p); + Ok(m.complete(p)) +} + // ---@return_cast // ---@return_cast else fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult { diff --git a/crates/emmylua_parser/src/grammar/doc/test.rs b/crates/emmylua_parser/src/grammar/doc/test.rs index 2b2a60edd..68596e5d1 100644 --- a/crates/emmylua_parser/src/grammar/doc/test.rs +++ b/crates/emmylua_parser/src/grammar/doc/test.rs @@ -132,6 +132,61 @@ Syntax(Chunk)@0..163 assert_ast_eq!(code, result); } + #[test] + fn test_return_overload_tag() { + let code = r#" + ---@return_overload true, integer + "#; + let result = r#" +Syntax(Chunk)@0..51 + Syntax(Block)@0..51 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(Comment)@9..42 + Token(TkDocStart)@9..13 "---@" + Syntax(DocTagReturnOverload)@13..42 + Token(TkTagReturnOverload)@13..28 "return_overload" + Token(TkWhitespace)@28..29 " " + Syntax(TypeLiteral)@29..33 + Token(TkTrue)@29..33 "true" + Token(TkComma)@33..34 "," + Token(TkWhitespace)@34..35 " " + Syntax(TypeName)@35..42 + Token(TkName)@35..42 "integer" + Token(TkEndOfLine)@42..43 "\n" + Token(TkWhitespace)@43..51 " " + "#; + + assert_ast_eq!(code, result); + } + + #[test] + fn test_return_overload_tag_does_not_parse_named_returns() { + let code = r#" + ---@return_overload true ok, integer + "#; + let result = r#" +Syntax(Chunk)@0..54 + Syntax(Block)@0..54 + Token(TkEndOfLine)@0..1 "\n" + Token(TkWhitespace)@1..9 " " + Syntax(Comment)@9..45 + Token(TkDocStart)@9..13 "---@" + Syntax(DocTagReturnOverload)@13..33 + Token(TkTagReturnOverload)@13..28 "return_overload" + Token(TkWhitespace)@28..29 " " + Syntax(TypeLiteral)@29..33 + Token(TkTrue)@29..33 "true" + Token(TkWhitespace)@33..34 " " + Syntax(DocDescription)@34..45 + Token(TkDocDetail)@34..45 "ok, integer" + Token(TkEndOfLine)@45..46 "\n" + Token(TkWhitespace)@46..54 " " + "#; + + assert_ast_eq!(code, result); + } + #[test] fn test_class_doc() { let code = r#" diff --git a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs index 8ef520370..d48718cca 100644 --- a/crates/emmylua_parser/src/kind/lua_syntax_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_syntax_kind.rs @@ -69,6 +69,7 @@ pub enum LuaSyntaxKind { DocTagType, DocTagParam, DocTagReturn, + DocTagReturnOverload, DocTagGeneric, DocTagSee, DocTagDeprecated, diff --git a/crates/emmylua_parser/src/kind/lua_token_kind.rs b/crates/emmylua_parser/src/kind/lua_token_kind.rs index f3a6a09c1..0f5e3ce1f 100644 --- a/crates/emmylua_parser/src/kind/lua_token_kind.rs +++ b/crates/emmylua_parser/src/kind/lua_token_kind.rs @@ -109,35 +109,36 @@ pub enum LuaTokenKind { TkTagAlias, // alias TkTagModule, // module - TkTagField, // field - TkTagType, // type - TkTagParam, // param - TkTagReturn, // return - TkTagOverload, // overload - TkTagGeneric, // generic - TkTagSee, // see - TkTagDeprecated, // deprecated - TkTagAsync, // async - TkTagCast, // cast - TkTagOther, // other - TkTagVisibility, // public private protected package - TkTagReadonly, // readonly - TkTagDiagnostic, // diagnostic - TkTagMeta, // meta - TkTagVersion, // version - TkTagAs, // as - TkTagNodiscard, // nodiscard - TkTagOperator, // operator - TkTagMapping, // mapping - TkTagNamespace, // namespace - TkTagUsing, // using - TkTagSource, // source - TkTagReturnCast, // return cast - TkTagExport, // export - TkLanguage, // language - TKTagSchema, // schema - TkTagAttribute, // attribute - TkCallGeneric, // call generic. function_name--[[@]](...) + TkTagField, // field + TkTagType, // type + TkTagParam, // param + TkTagReturn, // return + TkTagOverload, // overload + TkTagGeneric, // generic + TkTagSee, // see + TkTagDeprecated, // deprecated + TkTagAsync, // async + TkTagCast, // cast + TkTagOther, // other + TkTagVisibility, // public private protected package + TkTagReadonly, // readonly + TkTagDiagnostic, // diagnostic + TkTagMeta, // meta + TkTagVersion, // version + TkTagAs, // as + TkTagNodiscard, // nodiscard + TkTagOperator, // operator + TkTagMapping, // mapping + TkTagNamespace, // namespace + TkTagUsing, // using + TkTagSource, // source + TkTagReturnCast, // return cast + TkTagReturnOverload, // return overload + TkTagExport, // export + TkLanguage, // language + TKTagSchema, // schema + TkTagAttribute, // attribute + TkCallGeneric, // call generic. function_name--[[@]](...) TkDocOr, // | TkDocAnd, // & diff --git a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs index 02a59688e..14c8d2083 100644 --- a/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs +++ b/crates/emmylua_parser/src/lexer/lua_doc_lexer.rs @@ -712,6 +712,7 @@ fn to_tag(text: &str) -> LuaTokenKind { "param" => LuaTokenKind::TkTagParam, "return" => LuaTokenKind::TkTagReturn, "return_cast" => LuaTokenKind::TkTagReturnCast, + "return_overload" => LuaTokenKind::TkTagReturnOverload, "generic" => LuaTokenKind::TkTagGeneric, "see" => LuaTokenKind::TkTagSee, "overload" => LuaTokenKind::TkTagOverload, diff --git a/crates/emmylua_parser/src/syntax/node/doc/tag.rs b/crates/emmylua_parser/src/syntax/node/doc/tag.rs index 9ebfbc5cd..d5b4ac9ba 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/tag.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/tag.rs @@ -22,6 +22,7 @@ pub enum LuaDocTag { Type(LuaDocTagType), Param(LuaDocTagParam), Return(LuaDocTagReturn), + ReturnOverload(LuaDocTagReturnOverload), Overload(LuaDocTagOverload), Field(LuaDocTagField), Module(LuaDocTagModule), @@ -58,6 +59,7 @@ impl LuaAstNode for LuaDocTag { LuaDocTag::Type(it) => it.syntax(), LuaDocTag::Param(it) => it.syntax(), LuaDocTag::Return(it) => it.syntax(), + LuaDocTag::ReturnOverload(it) => it.syntax(), LuaDocTag::Overload(it) => it.syntax(), LuaDocTag::Field(it) => it.syntax(), LuaDocTag::Module(it) => it.syntax(), @@ -97,6 +99,7 @@ impl LuaAstNode for LuaDocTag { || kind == LuaSyntaxKind::DocTagAttribute || kind == LuaSyntaxKind::DocTagParam || kind == LuaSyntaxKind::DocTagReturn + || kind == LuaSyntaxKind::DocTagReturnOverload || kind == LuaSyntaxKind::DocTagOverload || kind == LuaSyntaxKind::DocTagField || kind == LuaSyntaxKind::DocTagModule @@ -153,6 +156,9 @@ impl LuaAstNode for LuaDocTag { LuaSyntaxKind::DocTagReturn => { Some(LuaDocTag::Return(LuaDocTagReturn::cast(syntax).unwrap())) } + LuaSyntaxKind::DocTagReturnOverload => Some(LuaDocTag::ReturnOverload( + LuaDocTagReturnOverload::cast(syntax).unwrap(), + )), LuaSyntaxKind::DocTagOverload => Some(LuaDocTag::Overload( LuaDocTagOverload::cast(syntax).unwrap(), )), @@ -595,6 +601,47 @@ impl LuaDocTagReturn { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LuaDocTagReturnOverload { + syntax: LuaSyntaxNode, +} + +impl LuaAstNode for LuaDocTagReturnOverload { + fn syntax(&self) -> &LuaSyntaxNode { + &self.syntax + } + + fn can_cast(kind: LuaSyntaxKind) -> bool + where + Self: Sized, + { + kind == LuaSyntaxKind::DocTagReturnOverload + } + + fn cast(syntax: LuaSyntaxNode) -> Option + where + Self: Sized, + { + if Self::can_cast(syntax.kind().into()) { + Some(Self { syntax }) + } else { + None + } + } +} + +impl LuaDocDescriptionOwner for LuaDocTagReturnOverload {} + +impl LuaDocTagReturnOverload { + pub fn get_first_type(&self) -> Option { + self.child() + } + + pub fn get_types(&self) -> LuaAstChildren { + self.children() + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct LuaDocTagOverload { syntax: LuaSyntaxNode, diff --git a/crates/emmylua_parser/src/syntax/node/mod.rs b/crates/emmylua_parser/src/syntax/node/mod.rs index 13c0daf17..9ae79c24c 100644 --- a/crates/emmylua_parser/src/syntax/node/mod.rs +++ b/crates/emmylua_parser/src/syntax/node/mod.rs @@ -67,6 +67,7 @@ pub enum LuaAst { LuaDocTagType(LuaDocTagType), LuaDocTagParam(LuaDocTagParam), LuaDocTagReturn(LuaDocTagReturn), + LuaDocTagReturnOverload(LuaDocTagReturnOverload), LuaDocTagOverload(LuaDocTagOverload), LuaDocTagField(LuaDocTagField), LuaDocTagModule(LuaDocTagModule), @@ -159,6 +160,7 @@ impl LuaAstNode for LuaAst { LuaAst::LuaDocTagType(node) => node.syntax(), LuaAst::LuaDocTagParam(node) => node.syntax(), LuaAst::LuaDocTagReturn(node) => node.syntax(), + LuaAst::LuaDocTagReturnOverload(node) => node.syntax(), LuaAst::LuaDocTagOverload(node) => node.syntax(), LuaAst::LuaDocTagField(node) => node.syntax(), LuaAst::LuaDocTagModule(node) => node.syntax(), @@ -259,6 +261,7 @@ impl LuaAstNode for LuaAst { | LuaSyntaxKind::DocTagType | LuaSyntaxKind::DocTagParam | LuaSyntaxKind::DocTagReturn + | LuaSyntaxKind::DocTagReturnOverload | LuaSyntaxKind::DocTagOverload | LuaSyntaxKind::DocTagField | LuaSyntaxKind::DocTagModule @@ -378,6 +381,9 @@ impl LuaAstNode for LuaAst { LuaSyntaxKind::DocTagReturn => { LuaDocTagReturn::cast(syntax).map(LuaAst::LuaDocTagReturn) } + LuaSyntaxKind::DocTagReturnOverload => { + LuaDocTagReturnOverload::cast(syntax).map(LuaAst::LuaDocTagReturnOverload) + } LuaSyntaxKind::DocTagOverload => { LuaDocTagOverload::cast(syntax).map(LuaAst::LuaDocTagOverload) } diff --git a/docs/emmylua_doc/annotations_CN/README.md b/docs/emmylua_doc/annotations_CN/README.md index 3db3e0741..07e15ddce 100644 --- a/docs/emmylua_doc/annotations_CN/README.md +++ b/docs/emmylua_doc/annotations_CN/README.md @@ -27,6 +27,7 @@ ### 函数注解 - [`@param`](./param.md) - 参数定义 - [`@return`](./return.md) - 返回值定义 +- `@return_overload`(见 [`@return`](./return.md))- 关联返回元组 - [`@overload`](./overload.md) - 函数重载 - [`@async`](./async.md) - 异步函数标记 - [`@nodiscard`](./nodiscard.md) - 不可忽略返回值 @@ -122,6 +123,7 @@ end | `@field` | 字段定义 | `---@field name string` | | `@param` | 参数定义 | `---@param name string` | | `@return` | 返回值定义 | `---@return boolean` | +| `@return_overload` | 关联返回元组 | `---@return_overload true, T` | | `@type` | 类型声明 | `---@type string` | | `@generic` | 泛型定义 | `---@generic T` | | `@overload` | 函数重载 | `---@overload fun(x: number): number` | diff --git a/docs/emmylua_doc/annotations_CN/return.md b/docs/emmylua_doc/annotations_CN/return.md index a10da1c9a..80b167100 100644 --- a/docs/emmylua_doc/annotations_CN/return.md +++ b/docs/emmylua_doc/annotations_CN/return.md @@ -14,6 +14,9 @@ -- 多返回值 ---@return <类型1> [名称1] [描述1] ---@return <类型2> [名称2] [描述2] + +-- 关联返回行(每一行代表一种返回元组) +---@return_overload <类型1>, <类型2>[, <类型3>...] ``` ## 示例 @@ -174,6 +177,53 @@ for id, userName in iterateUsers() do end ``` +## 返回重载行(`@return_overload`) + +`@return_overload` 用于定义“关联”的多返回值行。每一条注解代表一种可能的返回元组。 +这对状态/结果模式(例如 `pcall` 风格代码)非常有用。 + +当多个局部变量来自同一次函数调用时,对某个返回槽位的条件判断 +(真假判断、字面量相等判断,或 `type()` 守卫)会联动收窄同一返回行中的其他槽位。 + +```lua +---@generic T, E +---@param ok boolean +---@param success T +---@param failure E +---@return boolean +---@return T|E +---@return_overload true, T +---@return_overload false, E +local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure +end + +local cond ---@type boolean +local ok, result = pick(cond, 1, "error") + +if not ok then + error(result) -- result: string +end + +local value = result -- value: integer +``` + +`@return_overload` 同样支持泛型和可变尾部: + +```lua +---@generic T, R +---@param f fun(...: T...): R... +---@param ... T... +---@return_overload true, R... +---@return_overload false, string +local function wrap(f, ...) end +``` + +你可以保留 `@return` 作为宽泛声明,再用 `@return_overload` 提供关联敏感的推断信息。 + ## 特性 1. **多返回值支持** @@ -182,3 +232,4 @@ end 4. **函数返回值** 5. **异步返回值** 6. **条件返回值** +7. **关联返回行收窄(`@return_overload`)** diff --git a/docs/emmylua_doc/annotations_EN/README.md b/docs/emmylua_doc/annotations_EN/README.md index e7f7434dc..355d04b7c 100644 --- a/docs/emmylua_doc/annotations_EN/README.md +++ b/docs/emmylua_doc/annotations_EN/README.md @@ -29,6 +29,7 @@ The following notation symbols are used in annotation syntax descriptions: ### Function Annotations - [`@param`](./param.md) - Parameter definition - [`@return`](./return.md) - Return value definition +- `@return_overload` (see [`@return`](./return.md)) - Correlated return tuples - [`@overload`](./overload.md) - Function overload - [`@async`](./async.md) - Async function marker - [`@nodiscard`](./nodiscard.md) - Non-discardable return value @@ -124,6 +125,7 @@ end | `@field` | Field definition | `---@field name string` | | `@param` | Parameter definition | `---@param name string` | | `@return` | Return value definition | `---@return boolean` | +| `@return_overload` | Correlated return tuples | `---@return_overload true, T` | | `@type` | Type declaration | `---@type string` | | `@generic` | Generic definition | `---@generic T` | | `@overload` | Function overload | `---@overload fun(x: number): number` | diff --git a/docs/emmylua_doc/annotations_EN/return.md b/docs/emmylua_doc/annotations_EN/return.md index 6eaeb6366..de7e2ce3c 100644 --- a/docs/emmylua_doc/annotations_EN/return.md +++ b/docs/emmylua_doc/annotations_EN/return.md @@ -14,6 +14,9 @@ Define return value types and description information for functions. -- Multiple return values ---@return [name1] [description1] ---@return [name2] [description2] + +-- Correlated return rows (one row per possible return tuple) +---@return_overload , [, ...] ``` ## Examples @@ -174,6 +177,54 @@ for id, userName in iterateUsers() do end ``` +## Return Overload Rows (`@return_overload`) + +`@return_overload` defines correlated multi-return rows. Each annotation line represents one possible return tuple. +This is useful for status/result APIs (for example `pcall`-style code). + +When multiple local variables are assigned from the same call, condition checks on one return slot +(truthy/falsy checks, literal equality checks, or `type()` guards) narrow correlated slots from +the same row. + +```lua +---@generic T, E +---@param ok boolean +---@param success T +---@param failure E +---@return boolean +---@return T|E +---@return_overload true, T +---@return_overload false, E +local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure +end + +local cond ---@type boolean +local ok, result = pick(cond, 1, "error") + +if not ok then + error(result) -- result: string +end + +local value = result -- value: integer +``` + +`@return_overload` also supports generic and variadic tails: + +```lua +---@generic T, R +---@param f fun(...: T...): R... +---@param ... T... +---@return_overload true, R... +---@return_overload false, string +local function wrap(f, ...) end +``` + +You can keep `@return` as the broad declaration and add `@return_overload` rows for correlation-sensitive inference. + ## Features 1. **Multiple return value support** @@ -182,3 +233,4 @@ end 4. **Function return values** 5. **Async return values** 6. **Conditional return values** +7. **Correlated return row narrowing (`@return_overload`)**