diff --git a/graduated-rebalancer/src/lib.rs b/graduated-rebalancer/src/lib.rs index 1a8b7c6..7eaa802 100644 --- a/graduated-rebalancer/src/lib.rs +++ b/graduated-rebalancer/src/lib.rs @@ -78,6 +78,11 @@ pub trait TrustedWallet: Send + Sync { &self, method: PaymentMethod, amount: Amount, ) -> Pin> + Send + '_>>; + /// Estimate the fee for making a payment using the trusted wallet + fn estimate_fee( + &self, method: PaymentMethod, amount: Amount, + ) -> Pin> + Send + '_>>; + /// Wait for a payment success notification fn await_payment_success( &self, payment_hash: [u8; 32], @@ -272,10 +277,38 @@ where /// Perform a rebalance from trusted to lightning wallet async fn do_trusted_rebalance_locked(&self, params: TriggerParams) { - let transfer_amt = params.amount; + let mut transfer_amt = params.amount; log_info!(self.logger, "Initiating rebalance"); - if let Ok(inv) = self.ln_wallet.get_bolt11_invoice(Some(transfer_amt)).await { + if let Ok(mut inv) = self.ln_wallet.get_bolt11_invoice(Some(transfer_amt)).await { + if let Ok(fee) = self + .trusted + .estimate_fee(PaymentMethod::LightningBolt11(inv.clone()), transfer_amt) + .await + { + if fee >= transfer_amt { + log_error!( + self.logger, + "Rebalance trusted transaction fee {fee:?} exceeds amount {transfer_amt:?}", + ); + return; + } + + if transfer_amt.saturating_add(fee) > params.amount { + transfer_amt = params.amount.saturating_sub(fee); + match self.ln_wallet.get_bolt11_invoice(Some(transfer_amt)).await { + Ok(reduced_inv) => inv = reduced_inv, + Err(e) => { + log_error!( + self.logger, + "Failed to create fee-adjusted rebalance invoice: {e:?}", + ); + return; + }, + } + } + } + log_debug!( self.logger, "Attempting to pay invoice {inv} to rebalance for {transfer_amt:?}", diff --git a/orange-sdk/src/trusted_wallet/cashu/mod.rs b/orange-sdk/src/trusted_wallet/cashu/mod.rs index f522d55..03bb949 100644 --- a/orange-sdk/src/trusted_wallet/cashu/mod.rs +++ b/orange-sdk/src/trusted_wallet/cashu/mod.rs @@ -216,7 +216,10 @@ impl TrustedWalletInterface for Cashu { })?; // The fee is in the quote - convert_amount(quote.fee_reserve, &self.unit) + let quote_fee = convert_amount(quote.fee_reserve, &self.unit)?; + let input_fee = + self.estimate_input_fee(quote.amount + quote.fee_reserve).await?; + Ok(quote_fee.saturating_add(input_fee)) }, PaymentMethod::LightningBolt12(offer) => { let quote = self @@ -230,7 +233,10 @@ impl TrustedWalletInterface for Cashu { })?; // The fee is in the quote - convert_amount(quote.fee_reserve, &self.unit) + let quote_fee = convert_amount(quote.fee_reserve, &self.unit)?; + let input_fee = + self.estimate_input_fee(quote.amount + quote.fee_reserve).await?; + Ok(quote_fee.saturating_add(input_fee)) }, PaymentMethod::OnChain(_) => Err(TrustedError::UnsupportedOperation( "Cashu mint does not support onchain".to_owned(), @@ -850,6 +856,55 @@ impl Cashu { Ok(()) } + async fn estimate_input_fee(&self, input_amount: CdkAmount) -> Result { + let proofs = self.cashu_wallet.get_unspent_proofs().await.map_err(|e| { + TrustedError::WalletOperationFailed(format!("Failed to get unspent proofs: {e}")) + })?; + + let mut counts_by_keyset = HashMap::new(); + for proof in proofs { + *counts_by_keyset.entry(proof.keyset_id).or_insert(0_u64) += 1; + } + + let mut fee = Amount::ZERO; + for (keyset_id, proof_count) in counts_by_keyset { + let keyset_fee = + self.cashu_wallet.calculate_fee(proof_count, keyset_id).await.map_err(|e| { + TrustedError::WalletOperationFailed(format!( + "Failed to calculate input fee: {e}" + )) + })?; + fee = fee.saturating_add(convert_amount(keyset_fee, &self.unit)?); + } + + let active_keyset = self.cashu_wallet.get_active_keyset().await.map_err(|e| { + TrustedError::WalletOperationFailed(format!("Failed to get active keyset: {e}")) + })?; + let fee_and_amounts = + self.cashu_wallet.get_keyset_fees_and_amounts_by_id(active_keyset.id).await.map_err( + |e| { + TrustedError::WalletOperationFailed(format!( + "Failed to get keyset fee amounts: {e}" + )) + }, + )?; + let output_count = input_amount.split(&fee_and_amounts).map_err(|e| { + TrustedError::WalletOperationFailed(format!( + "Failed to calculate melt output count: {e}" + )) + })?; + let output_fee = self + .cashu_wallet + .calculate_fee(output_count.len() as u64, active_keyset.id) + .await + .map_err(|e| { + TrustedError::WalletOperationFailed(format!("Failed to calculate output fee: {e}")) + })?; + fee = fee.saturating_add(convert_amount(output_fee, &self.unit)?); + + Ok(fee) + } + pub(crate) async fn await_payment_success(&self) { let mut flag = self.payment_success_flag.clone(); flag.mark_unchanged(); diff --git a/orange-sdk/src/trusted_wallet/mod.rs b/orange-sdk/src/trusted_wallet/mod.rs index 5f76082..ce2abb7 100644 --- a/orange-sdk/src/trusted_wallet/mod.rs +++ b/orange-sdk/src/trusted_wallet/mod.rs @@ -127,6 +127,12 @@ impl graduated_rebalancer::TrustedWallet for Box::pin(async move { self.0.pay(method, amount).await }) } + fn estimate_fee( + &self, method: PaymentMethod, amount: Amount, + ) -> Pin> + Send + '_>> { + Box::pin(async move { self.0.estimate_fee(method, amount).await }) + } + fn await_payment_success( &self, payment_hash: [u8; 32], ) -> Pin> + Send + '_>> { diff --git a/orange-sdk/tests/integration_tests.rs b/orange-sdk/tests/integration_tests.rs index b432052..552151c 100644 --- a/orange-sdk/tests/integration_tests.rs +++ b/orange-sdk/tests/integration_tests.rs @@ -310,21 +310,24 @@ async fn test_sweep_to_ln() { let expect_amt = intermediate_amt.saturating_add(recv_amt); - let event = wait_next_event(&wallet).await; - match event { + let received_rebalance_amount = match wait_next_event(&wallet).await { Event::PaymentReceived { payment_id, amount_msat, lsp_fee_msats, .. } => { assert!(matches!(payment_id, orange_sdk::PaymentId::SelfCustodial(_))); - assert!(lsp_fee_msats.is_some()); - assert_eq!(amount_msat, expect_amt.milli_sats() - lsp_fee_msats.unwrap()); + let lsp_fee_msats = lsp_fee_msats.expect("rebalance receive should pay LSP fee"); + assert!( + amount_msat + lsp_fee_msats <= expect_amt.milli_sats(), + "rebalance receive should not exceed trusted balance after fees" + ); + amount_msat + lsp_fee_msats }, e => panic!("Expected RebalanceSuccessful event, got {e:?}"), - } + }; let event = wait_next_event(&wallet).await; match event { Event::RebalanceSuccessful { amount_msat, fee_msat, .. } => { assert!(fee_msat > 0); - assert_eq!(amount_msat, expect_amt.milli_sats()); + assert_eq!(amount_msat, received_rebalance_amount); }, e => panic!("Expected RebalanceSuccessful event, got {e:?}"), } @@ -786,9 +789,12 @@ async fn test_receive_to_onchain_with_channel() { // check we received on-chain, should be pending // wait for payment success - test_utils::wait_for_condition("pending balance to update", || async { - // onchain balance is always listed as pending until we splice it into the channel. + test_utils::wait_for_condition("onchain receive to appear", || async { wallet.get_balance().await.unwrap().pending_balance == recv_amt + || wallet.list_transactions().await.unwrap().iter().any(|tx| { + tx.payment_type == PaymentType::IncomingOnChain { txid: Some(sent_txid) } + && tx.amount == Some(recv_amt) + }) }) .await; @@ -886,8 +892,12 @@ async fn test_concurrent_splice_in_and_out_preserve_pending_events() { generate_blocks(&bitcoind, &electrsd, 6).await; wallet.sync_ln_wallet().unwrap(); - test_utils::wait_for_condition("pending balance to update", || async { + test_utils::wait_for_condition("onchain receive to appear", || async { wallet.get_balance().await.unwrap().pending_balance == recv_amt + || wallet.list_transactions().await.unwrap().iter().any(|tx| { + tx.payment_type == PaymentType::IncomingOnChain { txid: Some(sent_txid) } + && tx.amount == Some(recv_amt) + }) }) .await; diff --git a/orange-sdk/tests/test_utils.rs b/orange-sdk/tests/test_utils.rs index 5b299b5..87fd7e5 100644 --- a/orange-sdk/tests/test_utils.rs +++ b/orange-sdk/tests/test_utils.rs @@ -2,7 +2,7 @@ use bitcoin_payment_instructions::amount::Amount; #[cfg(feature = "_cashu-tests")] -use cdk::mint::{MintBuilder, MintMeltLimits}; +use cdk::mint::{MintBuilder, MintMeltLimits, UnitConfig}; #[cfg(feature = "_cashu-tests")] use cdk::types::FeeReserve; #[cfg(feature = "_cashu-tests")] @@ -527,6 +527,12 @@ async fn build_test_nodes() -> TestParams { let mut mint_seed: [u8; 64] = [0; 64]; rand::thread_rng().fill_bytes(&mut mint_seed); let mut builder = MintBuilder::new(mem_db.clone()); + builder + .configure_unit( + orange_sdk::CurrencyUnit::Sat, + UnitConfig { input_fee_ppk: 1, ..Default::default() }, + ) + .unwrap(); builder .add_payment_processor(