diff --git a/examples/last_will.inherit.wit b/examples/last_will.inherit.wit index 16752030..88b46ab7 100644 --- a/examples/last_will.inherit.wit +++ b/examples/last_will.inherit.wit @@ -1,6 +1,10 @@ { - "INHERIT_OR_NOT": { - "value": "Left(0x755201bb62b0a8b8d18fd12fc02951ea3998ba42bfc6664daaf8a0d2298cad43cdc21358c7c82f37654275dc2fea8c858adbe97bac92828b498a5a237004db6f)", - "type": "Either>" + "ACTION": { + "value": "1", + "type": "u8" + }, + "INHERITOR_SIG": { + "value": "0x755201bb62b0a8b8d18fd12fc02951ea3998ba42bfc6664daaf8a0d2298cad43cdc21358c7c82f37654275dc2fea8c858adbe97bac92828b498a5a237004db6f", + "type": "Signature" } } diff --git a/examples/last_will.simf b/examples/last_will.simf index 9790a1cf..aab2bf73 100644 --- a/examples/last_will.simf +++ b/examples/last_will.simf @@ -40,12 +40,22 @@ fn refresh_spend(hot_sig: Signature) { recursive_covenant(); } +enum Action { + Inherit=1, + ColdSpend =2, + HotSpend =3, +} + fn main() { - match witness::INHERIT_OR_NOT { - Left(inheritor_sig: Signature) => inherit_spend(inheritor_sig), - Right(cold_or_hot: Either) => match cold_or_hot { - Left(cold_sig: Signature) => cold_spend(cold_sig), - Right(hot_sig: Signature) => refresh_spend(hot_sig), - }, + match witness::ACTION { + Action::Inherit => { + let inheritor_sig: Signature = witness::INHERITOR_SIG; + inherit_spend(inheritor_sig)} , + Action::ColdSpend => { + let cold_sig: Signature = witness::COLD_SIG; + cold_spend(cold_sig) }, + Action::HotSpend => { + let hot_sig: Signature = witness::HOT_SIG; + refresh_spend(hot_sig) }, } } diff --git a/src/ast.rs b/src/ast.rs index fca0947f..7a326e9f 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,5 +1,5 @@ use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::num::NonZeroUsize; use std::sync::Arc; @@ -18,7 +18,7 @@ use crate::str::{AliasName, FunctionName, Identifier, ModuleName, WitnessName}; use crate::types::{ AliasedType, ResolvedType, StructuralType, TypeConstructible, TypeDeconstructible, UIntType, }; -use crate::value::{UIntValue, Value}; +use crate::value::{UIntValue, Value, ValueConstructible}; use crate::witness::{Parameters, WitnessTypes}; use crate::{driver, impl_eq_hash, parse}; @@ -582,6 +582,34 @@ impl JetHinter for ElementsJetHinter { } } +/// A single enum variant after analysis: its name and u8 discriminant, without source span. +#[derive(Clone, Debug, Eq, PartialEq)] +struct ResolvedEnumVariant { + name: Identifier, + discriminant: u8, +} + +/// The resolved definition of an enum as stored in [`Scope`]: +/// a list of [`ResolvedEnumVariant`]s in declaration order. +#[derive(Clone, Debug, Eq, PartialEq)] +struct EnumBinding { + variants: Arc<[ResolvedEnumVariant]>, +} + +impl EnumBinding { + fn new(variants: Arc<[ResolvedEnumVariant]>) -> Self { + Self { variants } + } + + fn variants(&self) -> &[ResolvedEnumVariant] { + &self.variants + } + + fn contains_variant(&self, name: &Identifier) -> bool { + self.variants.iter().any(|v| &v.name == name) + } +} + /// Scope for generating the abstract syntax tree. /// /// The scope is used for: @@ -594,6 +622,7 @@ struct Scope { variables: Vec>, aliases: HashMap, ResolvedType>, aliases_table: SymbolTable, + enums: HashMap, EnumBinding>, parameters: HashMap, witnesses: HashMap, functions: HashMap, CustomFunction>, @@ -625,6 +654,7 @@ impl Scope { variables: Vec::new(), aliases: HashMap::new(), aliases_table, + enums: HashMap::new(), parameters: HashMap::new(), witnesses: HashMap::new(), functions: HashMap::new(), @@ -779,6 +809,24 @@ impl Scope { Ok(()) } + pub fn insert_enum( + &mut self, + name: AliasName, + variants: Arc<[ResolvedEnumVariant]>, + ) -> Result<(), Error> { + let plug = (name.clone(), self.file_id); + if self.enums.contains_key(&plug) { + return Err(Error::RedefinedAlias { name }); + } + self.enums.insert(plug, EnumBinding::new(variants)); + Ok(()) + } + + pub fn get_enum(&self, name: &AliasName) -> Option<&EnumBinding> { + let plug = (name.clone(), self.file_id); + self.enums.get(&plug) + } + /// Insert a parameter into the global map. /// /// ## Errors @@ -984,6 +1032,57 @@ impl AbstractSyntaxTree for Item { Error::UseKeywordIsNotSupported, *use_decl.span(), )), + parse::Item::EnumDeclaration(decl) => { + scope.file_id = decl.file_id(); + let n = decl.variants().len(); + if n < 2 { + return Err(Error::Grammar { + msg: format!("enum '{}' must have at least 2 variants", decl.name()), + }) + .with_span(decl); + } + let mut sorted: Vec<&parse::EnumVariant> = decl.variants().iter().collect(); + sorted.sort_by_key(|v| v.discriminant()); + for w in sorted.windows(2) { + if w[0].discriminant() == w[1].discriminant() { + return Err(Error::Grammar { + msg: format!( + "enum '{}' has duplicate discriminant {}", + decl.name(), + w[0].discriminant() + ), + }) + .with_span(decl); + } + } + let mut seen_names = HashSet::new(); + for v in decl.variants() { + if !seen_names.insert(v.name()) { + return Err(Error::Grammar { + msg: format!( + "enum '{}' has duplicate variant name '{}'", + decl.name(), + v.name() + ), + }) + .with_span(decl); + } + } + let variants: Arc<[ResolvedEnumVariant]> = sorted + .iter() + .map(|v| ResolvedEnumVariant { + name: v.name().clone(), + discriminant: v.discriminant(), + }) + .collect(); + scope + .insert_alias(decl.name().clone(), AliasedType::from(UIntType::U8)) + .with_span(decl)?; + scope + .insert_enum(decl.name().clone(), variants) + .with_span(decl)?; + Ok(Self::TypeAlias) + } parse::Item::Module => Ok(Self::Module), }; @@ -1298,6 +1397,75 @@ impl AbstractSyntaxTree for SingleExpression { parse::SingleExpressionInner::Match(match_) => { Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)? } + parse::SingleExpressionInner::EnumMatch(enum_match) => { + let arms = enum_match.arms(); + let span = *enum_match.span(); + if arms.is_empty() { + return Err(Error::Grammar { + msg: "enum match has no arms".to_string(), + }) + .with_span(span); + } + let enum_name = match arms[0].pattern() { + MatchPattern::EnumVariant(name, _) => name.clone(), + _ => unreachable!("EnumMatch arms have EnumVariant patterns"), + }; + let binding = scope + .get_enum(&enum_name) + .ok_or_else(|| Error::UndefinedAlias { + name: enum_name.clone(), + }) + .with_span(span)?; + let mut arm_map: HashMap<&Identifier, &parse::Expression> = HashMap::new(); + for arm in arms { + let MatchPattern::EnumVariant(arm_enum_name, variant) = arm.pattern() else { + unreachable!("EnumMatch arms have EnumVariant patterns") + }; + if arm_enum_name != &enum_name { + return Err(Error::Grammar { + msg: format!( + "all match arms must use the same enum; expected '{}', found '{}'", + enum_name, arm_enum_name + ), + }) + .with_span(span); + } + if !binding.contains_variant(variant) { + return Err(Error::Grammar { + msg: format!( + "variant '{}' is not defined in enum '{}'", + variant, enum_name + ), + }) + .with_span(span); + } + if arm_map.insert(variant, arm.expression()).is_some() { + return Err(Error::Grammar { + msg: format!("duplicate arm for variant '{}'", variant), + }) + .with_span(span); + } + } + if arm_map.len() != binding.variants().len() { + return Err(Error::Grammar { + msg: format!( + "enum match on '{}' must cover all {} variants", + enum_name, + binding.variants().len() + ), + }) + .with_span(span); + } + let ordered_arms: Vec<(&parse::Expression, u8)> = binding + .variants() + .iter() + .map(|v| (arm_map[&v.name], v.discriminant)) + .collect(); + let u8_ty = ResolvedType::from(UIntType::U8); + let scrutinee = + Expression::analyze(enum_match.scrutinee(), &u8_ty, scope).map(Arc::new)?; + desugar_enum_arms_u8(&ordered_arms, scrutinee, ty, scope, span)? + } }; Ok(Self { @@ -1705,6 +1873,152 @@ impl AbstractSyntaxTree for ModuleAssignment { } } +/// Desugar an N-arm enum match (u8 discriminant) into a `jet::eq_8` comparison chain. +fn desugar_enum_arms_u8( + arms: &[(&parse::Expression, u8)], + scrutinee: Arc, + expected_ty: &ResolvedType, + scope: &mut Scope, + span: Span, +) -> Result { + debug_assert!(arms.len() >= 2); + + let u8_ty = ResolvedType::from(UIntType::U8); + + // Bind the scrutinee to a fresh variable to avoid witness-reuse errors. + let disc_ident = Identifier::from_str_unchecked("__disc_"); + + scope.push_scope(); + scope.insert_variable(disc_ident.clone(), u8_ty.clone()); + + let analyzed_arms: Vec<(Arc, u8)> = arms + .iter() + .map(|(e, disc)| { + scope.push_scope(); + let result = + Expression::analyze(e, expected_ty, scope).map(|expr| (Arc::new(expr), *disc)); + scope.pop_scope(); + result + }) + .collect::, _>>()?; + + let chain = build_u8_chain(&disc_ident, &analyzed_arms, expected_ty, &u8_ty, span); + scope.pop_scope(); + + // Wrap in block: { let __disc_N: u8 = scrutinee; } + let chain_expr = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: chain, + ty: expected_ty.clone(), + span, + }), + ty: expected_ty.clone(), + span, + }); + let assign_stmt = Statement::Assignment(Assignment { + pattern: Pattern::Identifier(disc_ident), + expression: (*scrutinee).clone(), + span, + }); + Ok(SingleExpressionInner::Expression(Arc::new(Expression { + inner: ExpressionInner::Block(Arc::from([assign_stmt]), Some(chain_expr)), + ty: expected_ty.clone(), + span, + }))) +} + +/// Build a nested bool-`Match` chain for u8 discriminant dispatch. +/// +/// Every variant, including the last, is guarded by an `eq8` comparison. +/// A `panic!()` on the final false branch ensures that any undeclared +/// discriminant value causes the script to fail rather than silently +/// executing the last arm. +/// +/// `if eq8(disc, d[0]) { arms[0] } else if eq8(disc, d[1]) { arms[1] } ... else if eq8(disc, d[N-1]) { arms[N-1] } else { panic!() }` +fn build_u8_chain( + disc_ident: &Identifier, + arms: &[(Arc, u8)], + expected_ty: &ResolvedType, + u8_ty: &ResolvedType, + span: Span, +) -> SingleExpressionInner { + debug_assert!(!arms.is_empty()); + + let (arm_expr, discriminant) = &arms[0]; + let disc_var = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Variable(disc_ident.clone()), + ty: u8_ty.clone(), + span, + }), + ty: u8_ty.clone(), + span, + }); + let const_expr = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Constant(Value::u8(*discriminant)), + ty: u8_ty.clone(), + span, + }), + ty: u8_ty.clone(), + span, + }); + let eq8_expr = Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Call(Call { + name: CallName::Jet(Box::new(Elements::Eq8)), + args: Arc::from([(*disc_var).clone(), (*const_expr).clone()]), + span, + }), + ty: ResolvedType::boolean(), + span, + }), + ty: ResolvedType::boolean(), + span, + }); + + let false_branch = if arms.len() == 1 { + // Last arm: an undeclared discriminant must not silently execute any arm. + Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Call(Call { + name: CallName::Panic, + args: Arc::from([]), + span, + }), + ty: expected_ty.clone(), + span, + }), + ty: expected_ty.clone(), + span, + }) + } else { + let rest_inner = build_u8_chain(disc_ident, &arms[1..], expected_ty, u8_ty, span); + Arc::new(Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: rest_inner, + ty: expected_ty.clone(), + span, + }), + ty: expected_ty.clone(), + span, + }) + }; + + SingleExpressionInner::Match(Match { + scrutinee: eq8_expr, + left: MatchArm { + pattern: MatchPattern::False, + expression: false_branch, + }, + right: MatchArm { + pattern: MatchPattern::True, + expression: arm_expr.clone(), + }, + span, + }) +} + impl AsRef for Assignment { fn as_ref(&self) -> &Span { &self.span @@ -1747,6 +2061,242 @@ impl AsRef for ModuleAssignment { } } +#[cfg(test)] +mod enum_tests { + use super::{ElementsJetHinter, Program}; + use crate::driver::tests::setup_graph; + use crate::error::ErrorCollector; + + fn analyze(src: &str) -> Result<(), String> { + let (graph, _ids, _dir) = setup_graph(vec![("main.simf", src)]); + let mut handler = ErrorCollector::new(); + let driver_prog = graph + .linearize_and_build(&mut handler) + .unwrap() + .expect("driver build should succeed"); + Program::analyze(&driver_prog, Box::new(ElementsJetHinter::new())) + .map(|_| ()) + .map_err(|e| e.to_string()) + } + + #[test] + fn enum_declaration_registers_type_alias() { + let result = analyze( + "enum Color { Red = 1, Green = 2 } + fn main() { let _x: Color = witness::C; }", + ); + assert!( + result.is_ok(), + "enum should register as type alias: {result:?}" + ); + } + + #[test] + fn enum_match_on_function_return() { + let result = analyze( + "enum Dir { Left = 1, Right = 2 } + fn wrap(d: Dir) -> Dir { d } + fn main() { + match wrap(witness::D) { + Dir::Left => assert!(jet::eq_32(0, 0)), + Dir::Right => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "enum match on function return should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_2_variants_desugars() { + let result = analyze( + "enum Coin { Heads = 1, Tails = 2 } + fn main() { + match witness::C { + Coin::Heads => assert!(jet::eq_32(0, 0)), + Coin::Tails => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "2-variant enum match should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_3_variants_desugars() { + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "3-variant enum match should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_arms_sorted_by_discriminant() { + // Arms in reverse discriminant order should still compile correctly. + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::C => assert!(jet::eq_32(0, 0)), + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "arms in any order should compile: {result:?}" + ); + } + + #[test] + fn enum_too_few_variants_is_error() { + let result = analyze("enum Bad { Only = 1 } fn main() {}"); + assert!(result.is_err(), "single-variant enum should error"); + assert!( + result.unwrap_err().contains("at least 2 variants"), + "expected 'at least 2 variants' in error" + ); + } + + #[test] + fn enum_duplicate_discriminant_is_error() { + let result = analyze("enum Bad { A = 1, B = 1 } fn main() {}"); + assert!(result.is_err(), "duplicate discriminant should error"); + assert!( + result.unwrap_err().contains("duplicate discriminant"), + "expected 'duplicate discriminant' in error" + ); + } + + #[test] + fn enum_duplicate_variant_name_is_error() { + let result = analyze("enum Bad { A = 1, A = 2 } fn main() {}"); + assert!(result.is_err(), "duplicate variant name should error"); + assert!( + result.unwrap_err().contains("duplicate variant name"), + "expected 'duplicate variant name' in error" + ); + } + + #[test] + fn enum_duplicate_name_is_error() { + use crate::error::ErrorCollector; + let (graph, _ids, _dir) = setup_graph(vec![( + "main.simf", + "enum Color { Red = 1, Green = 2 } + enum Color { Blue = 1, Yellow = 2 } + fn main() {}", + )]); + let mut handler = ErrorCollector::new(); + let program_option = graph.linearize_and_build(&mut handler).unwrap(); + assert!( + program_option.is_none(), + "duplicate enum name should cause build failure" + ); + } + + #[test] + fn enum_match_missing_arm_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "missing arm should error"); + assert!( + result.unwrap_err().contains("must cover all"), + "expected 'must cover all' in error" + ); + } + + #[test] + fn enum_match_unknown_variant_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::X => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "unknown variant should error"); + assert!( + result.unwrap_err().contains("not defined in enum"), + "expected 'not defined in enum' in error" + ); + } + + #[test] + fn enum_match_duplicate_arm_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::A => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "duplicate arm should error"); + assert!( + result.unwrap_err().contains("duplicate arm"), + "expected 'duplicate arm' in error" + ); + } + + #[test] + fn enum_match_mixed_enum_names_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + enum Other { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Other::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "mixed enum names should error"); + assert!( + result.unwrap_err().contains("same enum"), + "expected 'same enum' in error" + ); + } + + #[test] + fn enum_match_undefined_enum_is_error() { + let result = analyze( + "fn main() { + match witness::P { + Unknown::A => assert!(jet::eq_32(0, 0)), + Unknown::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "undefined enum should error"); + } +} + #[cfg(test)] mod alias_scope_regression_tests { use super::{ElementsJetHinter, Program}; diff --git a/src/driver/resolve_order.rs b/src/driver/resolve_order.rs index ad5463d5..93ac73f3 100644 --- a/src/driver/resolve_order.rs +++ b/src/driver/resolve_order.rs @@ -67,6 +67,12 @@ impl Program { } } + parse::Item::EnumDeclaration(decl) => { + if let Err(err) = register_enum_alias(decl, &mut aliases, MAIN_MODULE) { + handler.push(err.with_content(content.clone())); + continue; + } + } // Safe to skip: `Use` items are handled earlier in the loop, and `Module` currently has no functionality. parse::Item::Module | parse::Item::Use(_) => continue, } @@ -237,6 +243,12 @@ impl DependencyGraph { } } + parse::Item::EnumDeclaration(decl) => { + if let Err(err) = register_enum_alias(decl, &mut aliases, source_id) { + handler.push(err.with_source(source.clone())); + continue; + } + } // Safe to skip: `Use` items are handled earlier in the loop, and `Module` currently has no functionality. parse::Item::Module | parse::Item::Use(_) => continue, } @@ -432,6 +444,28 @@ fn register_type_alias( Ok(()) } +fn register_enum_alias( + item: &mut parse::EnumDeclaration, + tracker: &mut NamespaceTracker, + source_id: usize, +) -> Result<(), RichError> { + item.set_file_id(source_id); + + let name = item.name().clone(); + let local_id = (name.clone(), source_id); + + if tracker.memo.contains(&local_id) { + return Err(RichError::new( + Error::RedefinedAlias { name }, + *item.as_ref(), + )); + } + + tracker.memo.insert(local_id); + tracker.resolutions[source_id].insert(name, item.visibility().clone()); + Ok(()) +} + fn register_function( item: &mut Function, tracker: &mut NamespaceTracker, diff --git a/src/lexer.rs b/src/lexer.rs index 06d63adc..fb1b729a 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -21,6 +21,7 @@ pub enum Token<'src> { Mod, Const, Match, + Enum, Crate, // Control symbols @@ -80,6 +81,7 @@ impl<'src> fmt::Display for Token<'src> { Token::Mod => write!(f, "mod"), Token::Const => write!(f, "const"), Token::Match => write!(f, "match"), + Token::Enum => write!(f, "enum"), Token::Crate => write!(f, "{}", CRATE_STR), Token::Arrow => write!(f, "->"), @@ -156,6 +158,7 @@ pub fn lexer<'src>( "mod" => Token::Mod, "const" => Token::Const, "match" => Token::Match, + "enum" => Token::Enum, CRATE_STR => Token::Crate, "true" => Token::Bool(true), "false" => Token::Bool(false), @@ -259,7 +262,8 @@ pub fn lex<'src>(input: &'src str) -> (Option>, Vec assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Select variant A via its u8 discriminant. + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(1)); + TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .assert_run_success(); + } + + #[test] + fn enum_match_3_variants() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Path { A = 0, B = 2, C = 5 } + fn main() { + match witness::PATH { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Select variant C via its u8 discriminant. + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(5)); + TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .assert_run_success(); + } + + #[test] + fn enum_match_function_return() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Dir { Left = 1, Right = 2 } + fn wrap(d: Dir) -> Dir { d } + fn main() { + match wrap(witness::D) { + Dir::Left => assert!(jet::eq_32(0, 0)), + Dir::Right => assert!(jet::eq_32(0, 0)), + } + } + "#; + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("D"), Value::u8(1)); + TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .assert_run_success(); + } + + #[test] + fn enum_match_invalid_discriminant_fails() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::PATH { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Discriminant 0 is not declared in the enum; the script must fail. + for bad in [0u8, 4, 99, 255] { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(bad)); + let result = TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .run(); + assert!( + result.is_err(), + "discriminant {bad} is not declared; execution should fail but succeeded" + ); + } + } + + #[test] + fn missing_witness_on_live_branch_errors() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" +enum Branch { A = 1, B = 2 } +fn main() { + match witness::SELECTOR { + Branch::A => assert!(jet::is_zero_32(witness::A)), + Branch::B => assert!(jet::is_zero_32(witness::B)), + } +} +"#; + let env = crate::dummy_env::dummy(); + + // SELECTOR = 1 (Branch::A) → branch A taken; B is missing but pruned → satisfy OK + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(1)); + map.insert(WitnessName::from_str_unchecked("A"), Value::u32(0)); + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + compiled + .satisfy_with_env(WitnessValues::from(map), Some(&env)) + .expect("B is on a pruned branch; satisfy should succeed"); + } + + // SELECTOR = 2 (Branch::B) → branch B taken; A is missing but pruned → satisfy OK + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(2)); + map.insert(WitnessName::from_str_unchecked("B"), Value::u32(0)); + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + compiled + .satisfy_with_env(WitnessValues::from(map), Some(&env)) + .expect("A is on a pruned branch; satisfy should succeed"); + } + + // SELECTOR = 2 (Branch::B) → branch B taken; B is missing and live → satisfy errors + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(2)); + // B is intentionally not provided + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + let err = compiled + .satisfy_with_env(WitnessValues::from(map), Some(&env)) + .expect_err("B is on the executed branch and missing; satisfy should fail"); + assert!( + err.contains('B'), + "error message should mention witness B, got: {err}" + ); + } + } + #[test] #[cfg(feature = "serde")] fn hodl_vault() { diff --git a/src/named.rs b/src/named.rs index 9de1b6e7..c4ad36d3 100644 --- a/src/named.rs +++ b/src/named.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::sync::Arc; use simplicity::dag::{InternalSharing, PostOrderIterItem}; @@ -243,6 +244,67 @@ pub fn populate_witnesses( node.convert::(&mut populator) } +/// Walk the `commit` tree and the `pruned` redeem tree in parallel, checking that +/// no zero-filled witness (tracked in `zero_filled`) appears on a non-pruned branch. +/// +/// Pruned branches are indicated by `Fail` nodes in the pruned tree. When a `Case` +/// node is pruned to `AssertL` or `AssertR`, only the surviving child is recursed into. +pub fn check_surviving_witnesses( + commit: &CommitNode, + pruned: &Arc, + zero_filled: &HashSet, +) -> Result<(), String> { + match (commit.inner(), pruned.inner()) { + // Pruned branch or unreachable fail node — no witnesses to check + (_, Inner::Fail(_)) | (Inner::Fail(_), _) => Ok(()), + // Witness node on a live branch — error if it was zero-filled + (Inner::Witness(name), Inner::Witness(_)) => { + if zero_filled.contains(name) { + Err(format!( + "Witness `{name}` is used on the executed branch but has no assigned value" + )) + } else { + Ok(()) + } + } + // Leaf nodes with no witness children + (Inner::Iden, _) | (Inner::Unit, _) | (Inner::Jet(_), _) | (Inner::Word(_), _) => Ok(()), + // Single-child nodes — recurse into the child + (Inner::InjL(cc), Inner::InjL(cp)) + | (Inner::InjR(cc), Inner::InjR(cp)) + | (Inner::Take(cc), Inner::Take(cp)) + | (Inner::Drop(cc), Inner::Drop(cp)) => check_surviving_witnesses(cc, cp, zero_filled), + // Assert nodes — one live child, one CMR; recurse into the live child + (Inner::AssertL(cc, _), Inner::AssertL(cp, _)) + | (Inner::AssertR(_, cc), Inner::AssertR(_, cp)) => { + check_surviving_witnesses(cc, cp, zero_filled) + } + // Two-child nodes — recurse into both + (Inner::Comp(cl, cr), Inner::Comp(pl, pr)) | (Inner::Pair(cl, cr), Inner::Pair(pl, pr)) => { + check_surviving_witnesses(cl, pl, zero_filled)?; + check_surviving_witnesses(cr, pr, zero_filled) + } + // Case: both branches live + (Inner::Case(cl, cr), Inner::Case(pl, pr)) => { + check_surviving_witnesses(cl, pl, zero_filled)?; + check_surviving_witnesses(cr, pr, zero_filled) + } + // Case pruned to AssertL: only left branch survived + (Inner::Case(cl, _), Inner::AssertL(pl, _)) => { + check_surviving_witnesses(cl, pl, zero_filled) + } + // Case pruned to AssertR: only right branch survived + (Inner::Case(_, cr), Inner::AssertR(_, pr)) => { + check_surviving_witnesses(cr, pr, zero_filled) + } + // Disconnect — not used in SimplicityHL; handle defensively + (Inner::Disconnect(cc, _), Inner::Disconnect(cp, _)) => { + check_surviving_witnesses(cc, cp, zero_filled) + } + _ => unreachable!("unexpected structural mismatch between commit and pruned trees"), + } +} + // This awkward construction is required by rust-simplicity to implement WitnessConstructible // for Node>. See // https://docs.rs/simplicity-lang/latest/simplicity/node/trait.WitnessConstructible.html#foreign-impls diff --git a/src/parse.rs b/src/parse.rs index 42b5dac2..9a8b25cb 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -58,6 +58,8 @@ pub enum Item { /// An import declaration (e.g., `use math::add`) that brings another /// [`Item`] into the current scope. Use(UseDecl), + /// An enum declaration. + EnumDeclaration(EnumDeclaration), /// A module, which is ignored. Module, } @@ -403,6 +405,91 @@ impl TypeAlias { impl_eq_hash!(TypeAlias; name, ty); +/// A single variant in an enum declaration. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct EnumVariant { + name: Identifier, + discriminant: u8, + span: Span, +} + +impl EnumVariant { + pub fn name(&self) -> &Identifier { + &self.name + } + + pub fn discriminant(&self) -> u8 { + self.discriminant + } +} + +impl AsRef for EnumVariant { + fn as_ref(&self) -> &Span { + &self.span + } +} + +/// An enum declaration. +#[derive(Clone, Debug)] +pub struct EnumDeclaration { + file_id: usize, + visibility: Visibility, + name: AliasName, + variants: Arc<[EnumVariant]>, + span: Span, +} + +impl EnumDeclaration { + pub fn file_id(&self) -> usize { + self.file_id + } + + pub fn set_file_id(&mut self, file_id: usize) { + self.file_id = file_id; + } + + pub fn visibility(&self) -> &Visibility { + &self.visibility + } + + pub fn name(&self) -> &AliasName { + &self.name + } + + pub fn variants(&self) -> &[EnumVariant] { + &self.variants + } +} + +impl_eq_hash!(EnumDeclaration; name, variants); + +impl AsRef for EnumDeclaration { + fn as_ref(&self) -> &Span { + &self.span + } +} + +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for EnumDeclaration { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let file_id = u.int_in_range(0..=3)?; + let visibility = Visibility::arbitrary(u)?; + let name = AliasName::arbitrary(u)?; + let len = u.int_in_range(2..=8)?; + let variants = (0..len) + .map(|_| EnumVariant::arbitrary(u)) + .collect::>>()?; + Ok(Self { + file_id, + visibility, + name, + variants, + span: Span::DUMMY, + }) + } +} + /// An expression is something that returns a value. #[derive(Clone, Debug)] pub struct Expression { @@ -505,6 +592,8 @@ pub enum SingleExpressionInner { Expression(Arc), /// Match expression over a sum type Match(Match), + /// Match expression over a named enum type + EnumMatch(EnumMatch), /// Tuple wrapper expression Tuple(Arc<[Expression]>), /// Array wrapper expression @@ -560,6 +649,30 @@ impl Match { impl_eq_hash!(Match; scrutinee, left, right); +/// Match expression over a named enum type (N arms, N ≥ 2). +#[derive(Clone, Debug)] +pub struct EnumMatch { + scrutinee: Arc, + arms: Arc<[MatchArm]>, + span: Span, +} + +impl EnumMatch { + pub fn scrutinee(&self) -> &Expression { + &self.scrutinee + } + + pub fn arms(&self) -> &[MatchArm] { + &self.arms + } + + pub fn span(&self) -> &Span { + &self.span + } +} + +impl_eq_hash!(EnumMatch; scrutinee, arms); + /// Arm of a match expression. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct MatchArm { @@ -595,6 +708,8 @@ pub enum MatchPattern { False, /// Match true value (no binding). True, + /// Match a named enum variant (no payload binding). + EnumVariant(AliasName, Identifier), } impl MatchPattern { @@ -604,7 +719,10 @@ impl MatchPattern { MatchPattern::Left(i, _) | MatchPattern::Right(i, _) | MatchPattern::Some(i, _) => { Some(i) } - MatchPattern::None | MatchPattern::False | MatchPattern::True => None, + MatchPattern::None + | MatchPattern::False + | MatchPattern::True + | MatchPattern::EnumVariant(..) => None, } } @@ -614,7 +732,10 @@ impl MatchPattern { MatchPattern::Left(i, ty) | MatchPattern::Right(i, ty) | MatchPattern::Some(i, ty) => { Some((i, ty)) } - MatchPattern::None | MatchPattern::False | MatchPattern::True => None, + MatchPattern::None + | MatchPattern::False + | MatchPattern::True + | MatchPattern::EnumVariant(..) => None, } } } @@ -683,6 +804,7 @@ impl fmt::Display for Item { Self::TypeAlias(alias) => write!(f, "{alias}"), Self::Function(function) => write!(f, "{function}"), Self::Use(use_declaration) => write!(f, "{use_declaration}"), + Self::EnumDeclaration(decl) => write!(f, "{decl}"), // The parse tree contains no information about the contents of modules. // We print a random empty module `mod witness {}` here // so that `from_string(to_string(x)) = x` holds for all trees `x`. @@ -793,6 +915,7 @@ pub enum ExprTree<'a> { Single(&'a SingleExpression), Call(&'a Call), Match(&'a Match), + EnumMatch(&'a EnumMatch), } impl TreeLike for ExprTree<'_> { @@ -833,6 +956,7 @@ impl TreeLike for ExprTree<'_> { | S::Expression(l) => Tree::Unary(Self::Expression(l)), S::Call(call) => Tree::Unary(Self::Call(call)), S::Match(match_) => Tree::Unary(Self::Match(match_)), + S::EnumMatch(enum_match) => Tree::Unary(Self::EnumMatch(enum_match)), S::Tuple(elements) | S::Array(elements) | S::List(elements) => { Tree::Nary(elements.iter().map(Self::Expression).collect()) } @@ -843,6 +967,16 @@ impl TreeLike for ExprTree<'_> { Self::Expression(match_.left().expression()), Self::Expression(match_.right().expression()), ])), + Self::EnumMatch(enum_match) => Tree::Nary( + std::iter::once(Self::Expression(enum_match.scrutinee())) + .chain( + enum_match + .arms() + .iter() + .map(|arm| Self::Expression(arm.expression())), + ) + .collect(), + ), } } } @@ -906,7 +1040,7 @@ impl fmt::Display for ExprTree<'_> { write!(f, ")")?; } }, - S::Call(..) | S::Match(..) => {} + S::Call(..) | S::Match(..) | S::EnumMatch(..) => {} S::Tuple(tuple) => { if data.n_children_yielded == 0 { write!(f, "(")?; @@ -957,6 +1091,18 @@ impl fmt::Display for ExprTree<'_> { write!(f, ",\n}}")?; } }, + Self::EnumMatch(enum_match) => { + let n = data.n_children_yielded; + if n == 0 { + write!(f, "match ")?; + } else if n == 1 { + write!(f, "{{\n{} => ", enum_match.arms()[0].pattern())?; + } else if n <= enum_match.arms().len() { + write!(f, ",\n{} => ", enum_match.arms()[n - 1].pattern())?; + } else { + write!(f, ",\n}}")?; + } + } } } @@ -1029,7 +1175,24 @@ impl fmt::Display for MatchPattern { MatchPattern::Some(i, ty) => write!(f, "Some({i}: {ty})"), MatchPattern::False => write!(f, "false"), MatchPattern::True => write!(f, "true"), + MatchPattern::EnumVariant(enum_name, variant) => write!(f, "{enum_name}::{variant}"), + } + } +} + +impl fmt::Display for EnumDeclaration { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}enum {} {{", self.visibility(), self.name())?; + for variant in self.variants() { + write!(f, " {} = {},", variant.name(), variant.discriminant())?; } + write!(f, " }}") + } +} + +impl fmt::Display for EnumMatch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", ExprTree::EnumMatch(self)) } } @@ -1366,9 +1529,16 @@ impl ChumskyParse for Item { let func_parser = Function::parser().map(Item::Function); let type_parser = TypeAlias::parser().map(Item::TypeAlias); let use_parser = UseDecl::parser().map(Item::Use); + let enum_parser = EnumDeclaration::parser().map(Item::EnumDeclaration); let mod_parser = Module::parser().map(|_| Item::Module); - choice((func_parser, use_parser, type_parser, mod_parser)) + choice(( + func_parser, + use_parser, + type_parser, + enum_parser, + mod_parser, + )) } } @@ -1781,6 +1951,62 @@ impl ChumskyParse for TypeAlias { } } +impl ChumskyParse for EnumDeclaration { + fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone + where + I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, + { + let visibility = just(Token::Pub) + .to(Visibility::Public) + .or_not() + .map(Option::unwrap_or_default); + + let discriminant = just(Token::Eq) + .ignore_then(select! { Token::DecLiteral(d) => d }) + .try_map(|d, span| { + d.as_inner().parse::().map_err(|_| { + RichError::new( + Error::Grammar { + msg: format!( + "enum discriminant '{}' is out of range (must be 0-255)", + d.as_inner() + ), + }, + span, + ) + }) + }); + + let variant = + Identifier::parser() + .then(discriminant) + .map_with(|(name, discriminant), e| EnumVariant { + name, + discriminant, + span: e.span(), + }); + + let variants = variant + .separated_by(just(Token::Comma)) + .allow_trailing() + .collect::>() + .delimited_by(just(Token::LBrace), just(Token::RBrace)) + .map(Arc::from); + + visibility + .then_ignore(just(Token::Enum)) + .then(AliasName::parser()) + .then(variants) + .map_with(|((visibility, name), variants), e| Self { + file_id: MAIN_MODULE, + visibility, + name, + variants, + span: e.span(), + }) + } +} + impl ChumskyParse for Expression { fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone where @@ -1902,7 +2128,7 @@ impl SingleExpression { let call = Call::parser(expr.clone()).map(SingleExpressionInner::Call); - let match_expr = Match::parser(expr.clone()).map(SingleExpressionInner::Match); + let match_expr = match_expr_parser(expr.clone()); let variable = Identifier::parser().map(SingleExpressionInner::Variable); @@ -1957,129 +2183,120 @@ impl ChumskyParse for MatchPattern { } } -impl MatchArm { - fn parser<'tokens, 'src: 'tokens, I, E>( - expr: E, - ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone - where - I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, - E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, - { - MatchPattern::parser() - .then_ignore(just(Token::FatArrow)) - .then(expr.map(Arc::new)) - .then(just(Token::Comma).or_not()) - .validate(|((pattern, expression), comma), e, emitter| { - let is_block = matches!(expression.as_ref().inner, ExpressionInner::Block(_, _)); - - if !is_block && comma.is_none() { - emitter.emit( - Error::Grammar { - msg: "Missing ',' after a match arm that isn't block expression" - .to_string(), - } - .with_span(e.span()), - ); - } - - Self { - pattern, - expression, - } - }) - } -} - -impl Match { - fn parser<'tokens, 'src: 'tokens, I, E>( - expr: E, - ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone - where - I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, - E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, - { - let scrutinee = expr.clone().map(Arc::new); - - let arm_recovery = any() - .filter(|t| !matches!(t, Token::Comma | Token::RBrace)) - .ignored() - .or(nested_delimiters( - Token::LBrace, - Token::RBrace, - [ - (Token::LParen, Token::RParen), - (Token::LBracket, Token::RBracket), - ], - |_| (), - ) - .ignored()) - .repeated() - .map_with(|(), _| None); - - let arm_parser = MatchArm::parser(expr.clone()) - .map(Some) - .recover_with(via_parser(arm_recovery.clone())); - - let arms = delimited_with_recovery( - arm_parser.clone().then(arm_parser.clone()), - Token::LBrace, - Token::RBrace, - |_| (None, None), - ); +/// Parser for `match` expressions. +/// +/// Handles both binary match (exactly 2 arms: Left/Right, None/Some, false/true) and enum match +/// (2+ arms using `EnumName::Variant` patterns). Dispatches to [`Match`] or [`EnumMatch`] based +/// on the patterns found. +fn match_expr_parser<'tokens, 'src: 'tokens, I, E>( + expr: E, +) -> impl Parser<'tokens, I, SingleExpressionInner, ParseError<'src>> + Clone +where + I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, + E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, +{ + let scrutinee = expr.clone().map(Arc::new); - just(Token::Match) - .ignore_then(scrutinee) - .then(arms) - .validate(|(scrutinee, arms), e, emit| match arms { - (Some(first), Some(second)) => { - let (left, right) = match (&first.pattern, &second.pattern) { - (MatchPattern::Left(..), MatchPattern::Right(..)) => (first, second), - (MatchPattern::Right(..), MatchPattern::Left(..)) => (second, first), - - (MatchPattern::None, MatchPattern::Some(..)) => (first, second), - (MatchPattern::Some(..), MatchPattern::None) => (second, first), - - (MatchPattern::False, MatchPattern::True) => (first, second), - (MatchPattern::True, MatchPattern::False) => (second, first), - - (p1, p2) => { - emit.emit( - Error::IncompatibleMatchArms { - first: p1.clone(), - second: p2.clone(), - } - .with_span(e.span()), - ); - (first, second) - } - }; + // Enum variant pattern: `EnumName::VariantName`. + // Binary keywords are excluded so choice() works without backtracking: + // when the ident is Left/Right/Some/None the select! guard fails without consuming the token. + let enum_variant_pattern = + select! { Token::Ident(name) if name != "Left" && name != "Right" && name != "Some" && name != "None" => AliasName::from_str_unchecked(name) } + .then_ignore(just(Token::DoubleColon)) + .then(select! { Token::Ident(v) => Identifier::from_str_unchecked(v) }) + .map(|(enum_name, variant)| MatchPattern::EnumVariant(enum_name, variant)); + + let combined_pattern = choice((enum_variant_pattern, MatchPattern::parser())); + + // No recover_with here: repeated() stops naturally when arm_parser fails. + // Outer delimited_with_recovery handles the block-level recovery. + let arm_parser = combined_pattern + .then_ignore(just(Token::FatArrow)) + .then(expr.clone().map(Arc::new)) + .then(just(Token::Comma).or_not()) + .validate(|((pattern, expression), comma), e, emitter| { + let is_block = matches!(expression.as_ref().inner, ExpressionInner::Block(_, _)); + if !is_block && comma.is_none() { + emitter.emit( + Error::Grammar { + msg: "Missing ',' after a match arm that isn't block expression" + .to_string(), + } + .with_span(e.span()), + ); + } + MatchArm { + pattern, + expression, + } + }); + + let arms = delimited_with_recovery( + arm_parser.repeated().collect::>(), + Token::LBrace, + Token::RBrace, + |_| vec![], + ); + + just(Token::Match) + .ignore_then(scrutinee) + .then(arms) + .validate(|(scrutinee, arms), e, emit| { + let all_enum = arms + .iter() + .all(|a| matches!(a.pattern, MatchPattern::EnumVariant(..))); + + if all_enum && arms.len() >= 2 { + return SingleExpressionInner::EnumMatch(EnumMatch { + scrutinee, + arms: Arc::from(arms), + span: e.span(), + }); + } - Self { - scrutinee, - left, - right, - span: e.span(), + // Binary match: exactly 2 non-enum arms. + let fallback_arm = MatchArm { + expression: Arc::new(Expression::empty(Span::new(0, 0))), + pattern: MatchPattern::False, + }; + let (first, second) = if arms.len() == 2 { + let mut it = arms.into_iter(); + (it.next().unwrap(), it.next().unwrap()) + } else { + emit.emit( + Error::Grammar { + msg: "binary match requires exactly 2 arms".to_string(), } - } - _ => { - let match_arm_fallback = MatchArm { - expression: Arc::new(Expression::empty(Span::new(0, 0))), - pattern: MatchPattern::False, - }; + .with_span(e.span()), + ); + (fallback_arm.clone(), fallback_arm) + }; - let (left, right) = ( - arms.0.unwrap_or(match_arm_fallback.clone()), - arms.1.unwrap_or(match_arm_fallback.clone()), + let (left, right) = match (&first.pattern, &second.pattern) { + (MatchPattern::Left(..), MatchPattern::Right(..)) => (first, second), + (MatchPattern::Right(..), MatchPattern::Left(..)) => (second, first), + (MatchPattern::None, MatchPattern::Some(..)) => (first, second), + (MatchPattern::Some(..), MatchPattern::None) => (second, first), + (MatchPattern::False, MatchPattern::True) => (first, second), + (MatchPattern::True, MatchPattern::False) => (second, first), + (p1, p2) => { + emit.emit( + Error::IncompatibleMatchArms { + first: p1.clone(), + second: p2.clone(), + } + .with_span(e.span()), ); - Self { - scrutinee, - left, - right, - span: e.span(), - } + (first, second) } + }; + SingleExpressionInner::Match(Match { + scrutinee, + left, + right, + span: e.span(), }) - } + }) } impl ChumskyParse for Module { @@ -2529,4 +2746,92 @@ mod test { assert_eq!(program.to_string(), format!("{input}\n")); } } + + fn parse_item(input: &str) -> Item { + let program = parse::Program::parse_from_str(input).expect("parsing should succeed"); + program.items().first().expect("expected one item").clone() + } + + #[test] + fn test_enum_declaration_basic() { + let item = parse_item("enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }"); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration, got {item:?}"); + }; + assert_eq!(decl.name().as_inner(), "Path"); + assert_eq!(decl.variants().len(), 3); + assert_eq!(decl.variants()[0].name().as_inner(), "Inherit"); + assert_eq!(decl.variants()[0].discriminant(), 1); + assert_eq!(decl.variants()[1].name().as_inner(), "ColdSpend"); + assert_eq!(decl.variants()[1].discriminant(), 2); + assert_eq!(decl.variants()[2].name().as_inner(), "RefreshSpend"); + assert_eq!(decl.variants()[2].discriminant(), 3); + } + + #[test] + fn test_enum_declaration_pub() { + let item = parse_item("pub enum Color { Red = 0, Green = 1, Blue = 2, }"); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration"); + }; + assert_eq!(decl.visibility(), &Visibility::Public); + assert_eq!(decl.name().as_inner(), "Color"); + } + + #[test] + fn test_enum_declaration_display_round_trip() { + let input = "enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }"; + let item = parse_item(input); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration"); + }; + assert_eq!( + decl.to_string(), + "enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }" + ); + } + + #[test] + fn test_enum_match_parses() { + let input = "fn main() { match witness::PATH { Path::Inherit => 0, Path::ColdSpend => 1, Path::RefreshSpend => 2, } }"; + let source = SourceFile::anonymous(Arc::from(input)); + let mut errors = ErrorCollector::new(); + let program = Program::parse_from_str_with_errors(source, &mut errors); + assert!(program.is_some(), "should parse without errors"); + assert!( + !errors.has_errors(), + "unexpected errors: {}", + ErrorCollector::to_string(&errors) + ); + } + + #[test] + fn test_enum_match_produces_enum_match_node() { + let input = + "fn main() { match witness::PATH { Path::Inherit => 0, Path::ColdSpend => 1, } }"; + let program = parse::Program::parse_from_str(input).expect("parsing should succeed"); + // Walk the tree looking for an EnumMatch node + let has_enum_match = program.items().iter().any(|item| { + if let Item::Function(f) = item { + format!("{f}").contains("Path::Inherit") + } else { + false + } + }); + assert!(has_enum_match, "expected EnumMatch in the parse tree"); + } + + #[test] + fn test_binary_match_still_works_after_enum_parser_change() { + let input = "fn main() { let x: bool = true; match x { true => 1, false => 0, } }"; + let source = SourceFile::anonymous(Arc::from(input)); + let mut errors = ErrorCollector::new(); + let program = Program::parse_from_str_with_errors(source, &mut errors); + assert!(program.is_some(), "binary match should still parse"); + assert!( + !errors.has_errors(), + "unexpected errors: {}", + ErrorCollector::to_string(&errors) + ); + } } diff --git a/src/value.rs b/src/value.rs index 1ccb38bd..47df6a4a 100644 --- a/src/value.rs +++ b/src/value.rs @@ -648,6 +648,34 @@ impl Value { }; Ok(ret) } + + /// Create a zero value of the given type. + /// + /// For integers, this is 0. For sum types, this is `Left(zero)`. For options, this is `None`. + /// For tuples and arrays, each element is zero. For lists, this is the empty list. + pub fn zero(ty: &ResolvedType) -> Self { + match ty.as_inner() { + TypeInner::Boolean => Self::from(false), + TypeInner::UInt(uint_ty) => match uint_ty { + UIntType::U1 => Self::u1(0), + UIntType::U2 => Self::u2(0), + UIntType::U4 => Self::u4(0), + UIntType::U8 => Self::u8(0), + UIntType::U16 => Self::u16(0), + UIntType::U32 => Self::u32(0), + UIntType::U64 => Self::u64(0), + UIntType::U128 => Self::u128(0), + UIntType::U256 => Self::u256(U256::from_byte_array([0u8; 32])), + }, + TypeInner::Either(left, right) => Self::left(Self::zero(left), (**right).clone()), + TypeInner::Option(inner) => Self::none((**inner).clone()), + TypeInner::Tuple(elements) => Self::tuple(elements.iter().map(|e| Self::zero(e))), + TypeInner::Array(el_ty, size) => { + Self::array((0..*size).map(|_| Self::zero(el_ty)), (**el_ty).clone()) + } + TypeInner::List(el_ty, bound) => Self::list([], (**el_ty).clone(), *bound), + } + } } impl Value { diff --git a/src/witness.rs b/src/witness.rs index 3fd878d4..8b7532f2 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::sync::Arc; @@ -128,6 +128,23 @@ impl WitnessValues { Ok(()) } + + /// Return a copy of these witness values with zero values inserted for any witness declared + /// in `types` that has no assigned value. Witnesses already present are unchanged. + /// + /// This is used before populating Simplicity witness nodes: all nodes must be filled, even + /// those on branches that will be pruned away and never executed. + pub fn fill_missing(&self, types: &WitnessTypes) -> (Self, HashSet) { + let mut map: HashMap = (*self.0).clone(); + let mut zero_filled = HashSet::new(); + for (name, ty) in types.iter() { + if !map.contains_key(name) { + map.insert(name.shallow_clone(), Value::zero(ty)); + zero_filled.insert(name.shallow_clone()); + } + } + (Self::from(map), zero_filled) + } } impl ParseFromStr for ResolvedType { @@ -217,7 +234,7 @@ mod tests { use crate::error::ErrorCollector; use crate::parse::ParseFromStr; use crate::value::ValueConstructible; - use crate::{ast, driver, parse, CompiledProgram, SatisfiedProgram}; + use crate::{ast, driver, parse, CompiledProgram, ResolvedType, SatisfiedProgram}; #[test] fn witness_reuse() { @@ -291,6 +308,45 @@ fn main() { } } + #[test] + fn fill_missing_zero_fills_and_tracks_missing_witnesses() { + let ty = ResolvedType::parse_from_str("u32").unwrap(); + let witness_types = WitnessTypes::from(HashMap::from([ + (WitnessName::from_str_unchecked("A"), ty.clone()), + (WitnessName::from_str_unchecked("B"), ty.clone()), + (WitnessName::from_str_unchecked("C"), ty.clone()), + ])); + + // A is explicitly provided with value zero (same value fill_missing would insert). + // B and C are not provided at all. + let provided = WitnessValues::from(HashMap::from([( + WitnessName::from_str_unchecked("A"), + Value::u32(0), + )])); + + let (filled, zero_filled) = provided.fill_missing(&witness_types); + + // Explicitly-provided witnesses must NOT be tracked as zero-filled, + // even when their value happens to be zero. + assert!( + !zero_filled.contains(&WitnessName::from_str_unchecked("A")), + "A was explicitly provided; must not appear in zero_filled" + ); + // Missing witnesses must be tracked so check_surviving_witnesses can error. + assert!( + zero_filled.contains(&WitnessName::from_str_unchecked("B")), + "B was not provided; must appear in zero_filled" + ); + assert!( + zero_filled.contains(&WitnessName::from_str_unchecked("C")), + "C was not provided; must appear in zero_filled" + ); + // All three must now have values in the filled map. + assert!(filled.get(&WitnessName::from_str_unchecked("A")).is_some()); + assert!(filled.get(&WitnessName::from_str_unchecked("B")).is_some()); + assert!(filled.get(&WitnessName::from_str_unchecked("C")).is_some()); + } + #[test] fn witness_to_string() { let witness = WitnessValues::from(HashMap::from([ diff --git a/test-data/last_will.json b/test-data/last_will.json index 3dace1a8..e54d5914 100644 --- a/test-data/last_will.json +++ b/test-data/last_will.json @@ -1,4 +1,4 @@ { - "program": "5wnQKEGJsWVABAmKSEGCrynMGLpUF69BbvwQFoAuY+y1ngQJfqSPabfWRZ9K3F2jdRYYBitLzfMz987l3WKtAxSudDhYOBTf5tlucUbKz5QK2LfAvMA1kChBh+DHCpJAk4cziqISK6EzABFXCwYvClhPYGFQusJfripGQssOAVt34AhgGJAoSQbgJxuBig/FJwqFobGHNddy8HoTqejIHGcv8bcleUZT57KmW1Vp7LXaMUR4qMQ4YBiE3n41BAOBgcOJFAGQOJwuLAGkHHAHHpBiBQbkHacYEf5RB7X1tMEVAbpXAfNhcd45LjO88p6usCblccJ7lByDCchhcRcOA4GJxgBwcGIGlafkwigGSWMMQSTRPhfidUim1MchFg2+ZsIYB8RO84Db5ByMCcj0kCT4YnM/BVazZBMsdgY/lS0WYcYfsNRJVmhtHQmf/PVrNEOe4wYBisgAAAAIGkKhambcTmCIv9QHGkTTAYXN78l9PRKwkaP7L+QgG2hgCZadW734oMAxC4AwcwLvfahR91ofRxdIEhoraXiTCljMruIwlAG9G26fy7ABhgGL4cEObxOhhI4BnviN4uejwVYdGCizvg8pDe+f7r9U2pQklHLAwDhITlckiyAAAAAAUKvhqObQQ6CxT1VyVCCLZUfrJhqhp/qbNkpewATHlLgTDwBZJgwDELSQkUM6IZzkLPP/t8aZ/NfTm5pw5IQ0J/duRiaFMIlp35sJi5uVAYBgObvJWcC0jz5LzKXn0/Nn/OQnPezJTiq+w46I+xAB5sdEwwDWQKEkH1m5aCgWDNug5tpVnanSODCL5dAHG0YZYXL7/GhLKIDD2ZYRxW4obTlfDAMQtRvoQT9zeJ/JSg1zVnOQ+dSDqBncb+M9zPuT55FUoWyEHDxeBVkAAAAAg4AFwB8NhzSoNzoMc1Ae/7Z55CGHj0gOG/ZVBjb5kbDqY2kAJh07WGAYhcGISKEGB4nAwHMDLqyIAl/t8mPC4rRGEfM1lVH7CAxNfptKDrOKJGwAHAYBrC4iNL6lcZ9HTU3CRlRyCoGVZzuEYuSi/jp604WRZrD9L3pYeQ4FCcTgcuhcLAcvAcWAchQuRgDmABzNhOZwDmDFyvAcwoOVQHLcHleBy6B5fgc34A==", + "program": "56XQKEGJsAEECwZtoIQKD6U2AECBYANKBQfQmwAwQLABwFAoSoAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABDoFCDE2n5MIoBkljDEEk0T4X4nVIptTHIRYNvmbCGAfETvOA2+QQJigw+gUJIEnDmcVRCRXQmYAIq4WDF4UsJ7AwqF1hL9cVIyFlhwCtu/AEMAxIFCSD7zcBFB+AigWTYw5rruXg9CdT0ZA4zl/jbkryjKfPZUy2qtPZa7RiiPFRiHDAMQm0/IQEA4CBwwkgSfDE5n4KrWbIJljsDH8qWizDjD9hqJKs0No6Ez/56tZohz3GDAMVkAAAABA0hULUzbicwRF/qA40iaYDC5vfkvp6JWEjR/ZfyEA20MATLTq3e/FBgGIXAGDmBd77UKPutD6OLpAkNFbS8SYUsZldxGEoA3o23T+XYAMMAxfDghzeJ0MJHAM98RvFz0eCrDowUWd8HlIb3z/dfqm1KEko5YGAcJCcpEkWQAAAAAChV8NRzaCHQWKequSoQRbKj9ZMNUNP9TZslL2ACY8pcCYeALJMGAYhaSEihnRDOchZ5/9vjTP5r6c3NOHJCGhP7tyMTQphEtO/NhMXNyMDAMBzd5KzgWkefJeZS8+n5s/5yE572ZKcVX2HHRH2IAPNjomGAayBQkg+s3JgUCwZt0HNtKs7U6RwYRfLoA42jDLC5ff40JZRAYezLCOK3FDacr4YBiFqN9CCfubxP5KUGuas5yHzqQdQM7jfxnuZ9yfPIqlC2Qg4eLwKsgAAAAEHAAuAPhsOaVBudBjmoD3/bPPIQw8ekBw37KoMbfMjYdTG0gBMOnawwDELgxCRQgwPE4GA5gZdWRAEv9vkx4XFaIwj5msqo/YQGJr9NpQdZxRI2AA4DANYXERpfUrjPo6am4SMqOQVAyrOdwjFyUX8dPWnCyLNYfpe9LDyHAoTicDlCLhYDlGDiwDkKFyMAcpwcwATmBA5VC5UAOVgOYoDmPA5kgeZYDmaOJzOm5lrTjAj/KIPa+tpgioDdK4D5sLjvHJcZ3nlPV1gTcrjhPcoOZYJzMi5a8yAHLkFAzA0g7AOa84nNkbmjsWVABzRhOaZJCDBV5TmDF0qC9egt34IC0AXMfZazwIEv1JHtNvrIs+lbi7RuosMAxWl5vmZ++dy7rFWgYpXOhwsHApv82y3OKNlZ8oFbFvgXmAayBQbmvPzbHCoDmVSKAOZoGkLcA5nQcKA4cBxADxGBztgc8w=", "witness": null }