diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 96c319dc..7fa68dec 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -27,6 +27,8 @@ pin-project-lite = "0.2" pastey = { version = "0.2.0", optional = true } # oauth2 support oauth2 = { version = "5.0", optional = true, default-features = false } +# JWT signing for client credentials (private_key_jwt) +jsonwebtoken = { version = "9", optional = true } # for auto generate schema schemars = { version = "1.0", optional = true, features = ["chrono04"] } @@ -130,12 +132,14 @@ transport-streamable-http-server-session = [ # transport-ws = ["transport-io", "dep:tokio-tungstenite"] tower = ["dep:tower-service"] auth = ["dep:oauth2", "__reqwest", "dep:url"] +auth-client-credentials-jwt = ["auth", "dep:jsonwebtoken"] schemars = ["dep:schemars"] [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "1.1.0", features = ["chrono04"] } axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] } +url = "2.4" anyhow = "1.0" tracing-subscriber = { version = "0.3", features = [ "env-filter", @@ -251,3 +255,8 @@ path = "tests/test_custom_headers.rs" name = "test_sse_concurrent_streams" required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"] path = "tests/test_sse_concurrent_streams.rs" + +[[test]] +name = "test_client_credentials" +required-features = ["auth"] +path = "tests/test_client_credentials.rs" diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index d7dfa979..683f6880 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -93,10 +93,13 @@ pub use io::stdio; #[cfg(feature = "auth")] pub mod auth; +#[cfg(feature = "auth-client-credentials-jwt")] +pub use auth::JwtSigningAlgorithm; #[cfg(feature = "auth")] pub use auth::{ AuthClient, AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient, - CredentialStore, InMemoryCredentialStore, InMemoryStateStore, ScopeUpgradeConfig, StateStore, + ClientCredentialsConfig, CredentialStore, EXTENSION_OAUTH_CLIENT_CREDENTIALS, + InMemoryCredentialStore, InMemoryStateStore, ScopeUpgradeConfig, StateStore, StoredAuthorizationState, StoredCredentials, WWWAuthenticateParams, }; diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index b8d4f3f4..15da922f 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -288,6 +288,13 @@ pub enum AuthError { required_scope: String, upgrade_url: Option, }, + + #[error("Client credentials error: {0}")] + ClientCredentialsError(String), + + #[cfg(feature = "auth-client-credentials-jwt")] + #[error("JWT signing error: {0}")] + JwtSigningError(String), } /// oauth2 metadata @@ -364,6 +371,105 @@ type OAuthClient = oauth2::Client< >; type Credentials = (String, Option); +/// OAuth 2.0 extension identifier for client credentials flow (SEP-1046) +pub const EXTENSION_OAUTH_CLIENT_CREDENTIALS: &str = + "io.modelcontextprotocol/oauth-client-credentials"; + +/// JWT signing algorithm for private_key_jwt authentication (SEP-1046) +#[cfg(feature = "auth-client-credentials-jwt")] +#[derive(Debug, Clone, Copy)] +pub enum JwtSigningAlgorithm { + RS256, + RS384, + RS512, + ES256, + ES384, +} + +#[cfg(feature = "auth-client-credentials-jwt")] +impl JwtSigningAlgorithm { + fn to_jsonwebtoken_algorithm(self) -> jsonwebtoken::Algorithm { + match self { + JwtSigningAlgorithm::RS256 => jsonwebtoken::Algorithm::RS256, + JwtSigningAlgorithm::RS384 => jsonwebtoken::Algorithm::RS384, + JwtSigningAlgorithm::RS512 => jsonwebtoken::Algorithm::RS512, + JwtSigningAlgorithm::ES256 => jsonwebtoken::Algorithm::ES256, + JwtSigningAlgorithm::ES384 => jsonwebtoken::Algorithm::ES384, + } + } + + fn as_str(self) -> &'static str { + match self { + JwtSigningAlgorithm::RS256 => "RS256", + JwtSigningAlgorithm::RS384 => "RS384", + JwtSigningAlgorithm::RS512 => "RS512", + JwtSigningAlgorithm::ES256 => "ES256", + JwtSigningAlgorithm::ES384 => "ES384", + } + } +} + +/// Configuration for OAuth 2.0 Client Credentials flow (SEP-1046) +/// +/// This supports two authentication methods: +/// - `ClientSecret`: credentials sent in the request body +/// - `PrivateKeyJwt`: RFC 7523 signed JWT assertion (requires `auth-client-credentials-jwt` feature) +#[derive(Debug, Clone)] +pub enum ClientCredentialsConfig { + /// Client secret authentication (credentials in request body) + ClientSecret { + client_id: String, + client_secret: String, + scopes: Vec, + resource: Option, + }, + /// Private key JWT authentication (RFC 7523) + #[cfg(feature = "auth-client-credentials-jwt")] + PrivateKeyJwt { + client_id: String, + signing_key: Vec, + signing_algorithm: JwtSigningAlgorithm, + /// Override the `aud` claim in the JWT assertion; defaults to token_endpoint + token_endpoint_audience: Option, + scopes: Vec, + resource: Option, + }, +} + +impl ClientCredentialsConfig { + fn client_id(&self) -> &str { + match self { + ClientCredentialsConfig::ClientSecret { client_id, .. } => client_id, + #[cfg(feature = "auth-client-credentials-jwt")] + ClientCredentialsConfig::PrivateKeyJwt { client_id, .. } => client_id, + } + } + + fn scopes(&self) -> &[String] { + match self { + ClientCredentialsConfig::ClientSecret { scopes, .. } => scopes, + #[cfg(feature = "auth-client-credentials-jwt")] + ClientCredentialsConfig::PrivateKeyJwt { scopes, .. } => scopes, + } + } + + fn resource(&self) -> Option<&str> { + match self { + ClientCredentialsConfig::ClientSecret { resource, .. } => resource.as_deref(), + #[cfg(feature = "auth-client-credentials-jwt")] + ClientCredentialsConfig::PrivateKeyJwt { resource, .. } => resource.as_deref(), + } + } + + fn auth_method(&self) -> &str { + match self { + ClientCredentialsConfig::ClientSecret { .. } => "client_secret_post", + #[cfg(feature = "auth-client-credentials-jwt")] + ClientCredentialsConfig::PrivateKeyJwt { .. } => "private_key_jwt", + } + } +} + /// Configuration for scope upgrade behavior #[derive(Debug, Clone)] pub struct ScopeUpgradeConfig { @@ -1445,6 +1551,263 @@ impl AuthorizationManager { Some((trimmed[..end].to_string(), leading_ws + end)) } } + + // -- Client Credentials flow (SEP-1046) -- + + /// Validate that the authorization server metadata supports the requested + /// client credentials authentication method. + /// + /// For `client_secret_post`, checks `token_endpoint_auth_methods_supported`. + /// For `private_key_jwt`, additionally checks `token_endpoint_auth_signing_alg_values_supported`. + /// When the metadata field is absent, the method is permissive (assumes support). + pub fn validate_client_credentials_metadata( + &self, + config: &ClientCredentialsConfig, + ) -> Result<(), AuthError> { + let Some(metadata) = self.metadata.as_ref() else { + return Ok(()); + }; + + let required_method = config.auth_method(); + + if let Some(methods) = metadata + .additional_fields + .get("token_endpoint_auth_methods_supported") + .and_then(|v| v.as_array()) + { + if !methods.iter().any(|m| m.as_str() == Some(required_method)) { + let supported: Vec<&str> = methods.iter().filter_map(|m| m.as_str()).collect(); + return Err(AuthError::ClientCredentialsError(format!( + "Authorization server does not support auth method '{}'. Supported: {:?}", + required_method, supported + ))); + } + } + + #[cfg(feature = "auth-client-credentials-jwt")] + if let ClientCredentialsConfig::PrivateKeyJwt { + signing_algorithm, .. + } = config + { + if let Some(algs) = metadata + .additional_fields + .get("token_endpoint_auth_signing_alg_values_supported") + .and_then(|v| v.as_array()) + { + let alg_str = signing_algorithm.as_str(); + if !algs.iter().any(|a| a.as_str() == Some(alg_str)) { + let supported: Vec<&str> = algs.iter().filter_map(|a| a.as_str()).collect(); + return Err(AuthError::ClientCredentialsError(format!( + "Authorization server does not support signing algorithm '{}'. Supported: {:?}", + alg_str, supported + ))); + } + } + } + + Ok(()) + } + + /// Configure the OAuth2 client for the client credentials flow. + /// + /// Sets up the internal `BasicClient` with the token endpoint from metadata, + /// client_id, and optionally client_secret. Uses `AuthType::RequestBody` for + /// `client_secret` per SEP-1046. + pub fn configure_client_credentials( + &mut self, + config: &ClientCredentialsConfig, + ) -> Result<(), AuthError> { + let metadata = self + .metadata + .as_ref() + .ok_or(AuthError::NoAuthorizationSupport)?; + + let token_url = TokenUrl::new(metadata.token_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?; + + // auth_url is required by the type but won't be used for client credentials + let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?; + + let client_id = ClientId::new(config.client_id().to_string()); + + let mut client_builder = BasicClient::new(client_id) + .set_auth_uri(auth_url) + .set_token_uri(token_url); + + match config { + ClientCredentialsConfig::ClientSecret { client_secret, .. } => { + client_builder = + client_builder.set_client_secret(ClientSecret::new(client_secret.clone())); + // SEP-1046: credentials in request body + client_builder = client_builder.set_auth_type(AuthType::RequestBody); + } + #[cfg(feature = "auth-client-credentials-jwt")] + ClientCredentialsConfig::PrivateKeyJwt { .. } => { + // For JWT, we don't set client_secret; assertion is added as extra param + // No auth type needed since we handle auth via JWT assertion params + } + } + + self.oauth_client = Some(client_builder); + Ok(()) + } + + /// Exchange client credentials for an access token (SEP-1046). + /// + /// For `ClientSecret`: sends credentials in the request body with scopes and optional resource. + /// For `PrivateKeyJwt`: additionally sends `client_assertion_type` and `client_assertion`. + pub async fn exchange_client_credentials( + &self, + config: &ClientCredentialsConfig, + ) -> Result { + let oauth_client = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + let mut request = oauth_client.exchange_client_credentials(); + + // Add scopes + for scope in config.scopes() { + request = request.add_scope(Scope::new(scope.clone())); + } + + // Add resource parameter if specified + if let Some(resource) = config.resource() { + request = request.add_extra_param("resource", resource); + } + + // For private_key_jwt, add assertion parameters + #[cfg(feature = "auth-client-credentials-jwt")] + if let ClientCredentialsConfig::PrivateKeyJwt { + client_id, + signing_key, + signing_algorithm, + token_endpoint_audience, + .. + } = config + { + let metadata = self + .metadata + .as_ref() + .ok_or(AuthError::NoAuthorizationSupport)?; + + let audience = token_endpoint_audience + .as_deref() + .unwrap_or(&metadata.token_endpoint); + + let assertion = + Self::build_jwt_assertion(client_id, audience, signing_key, *signing_algorithm)?; + + request = request.add_extra_param( + "client_assertion_type", + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + ); + request = request.add_extra_param("client_assertion", assertion); + } + + let http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|e| AuthError::InternalError(e.to_string()))?; + + let token_result = match request + .request_async(&OAuthReqwestClient(http_client)) + .await + { + Ok(token) => token, + Err(RequestTokenError::Parse(_, body)) => { + match serde_json::from_slice::(&body) { + Ok(parsed) => { + warn!( + "client credentials token exchange failed to parse completely but included a valid token response. Accepting it." + ); + parsed + } + Err(parse_err) => { + return Err(AuthError::ClientCredentialsError(format!( + "Token exchange parse error: {}", + parse_err + ))); + } + } + } + Err(e) => { + return Err(AuthError::ClientCredentialsError(format!( + "Token exchange failed: {}", + e + ))); + } + }; + + debug!("client credentials token result: {:?}", token_result); + + let granted_scopes: Vec = token_result + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_default(); + + *self.current_scopes.write().await = granted_scopes.clone(); + + let client_id = config.client_id().to_string(); + let stored = StoredCredentials { + client_id, + token_response: Some(token_result.clone()), + granted_scopes, + token_received_at: Some(Self::now_epoch_secs()), + }; + self.credential_store.save(stored).await?; + + Ok(token_result) + } + + /// Build a JWT assertion per RFC 7523 for private_key_jwt authentication. + #[cfg(feature = "auth-client-credentials-jwt")] + fn build_jwt_assertion( + client_id: &str, + audience: &str, + signing_key: &[u8], + algorithm: JwtSigningAlgorithm, + ) -> Result { + use serde_json::json; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let jti = format!("{:x}{:x}", now, rand_u64()); + + let claims = json!({ + "iss": client_id, + "sub": client_id, + "aud": audience, + "iat": now, + "exp": now + 300, // 5 minutes + "jti": jti, + }); + + let header = jsonwebtoken::Header::new(algorithm.to_jsonwebtoken_algorithm()); + let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(signing_key).or_else(|_| { + jsonwebtoken::EncodingKey::from_ec_pem(signing_key).map_err(|e| { + AuthError::JwtSigningError(format!("Failed to parse signing key: {}", e)) + }) + })?; + + jsonwebtoken::encode(&header, &claims, &encoding_key) + .map_err(|e| AuthError::JwtSigningError(format!("Failed to sign JWT: {}", e))) + } +} + +/// Simple random u64 for JWT jti claim uniqueness +#[cfg(feature = "auth-client-credentials-jwt")] +fn rand_u64() -> u64 { + use std::{ + collections::hash_map::RandomState, + hash::{BuildHasher, Hasher}, + }; + RandomState::new().build_hasher().finish() } /// oauth2 authorization session, for guiding user to complete the authorization process @@ -1863,6 +2226,41 @@ impl OAuthState { _ => None, } } + + /// Authenticate using OAuth 2.0 Client Credentials flow (SEP-1046). + /// + /// Transitions directly from `Unauthorized` to `Authorized`, skipping the + /// interactive `Session` state entirely. Discovers metadata, configures the + /// client, and exchanges credentials for an access token. + pub async fn authenticate_client_credentials( + &mut self, + config: ClientCredentialsConfig, + ) -> Result<(), AuthError> { + let OAuthState::Unauthorized(mut manager) = std::mem::replace( + self, + OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?), + ) else { + return Err(AuthError::InternalError( + "Client credentials flow requires Unauthorized state".to_string(), + )); + }; + + // Discover metadata + let metadata = manager.discover_metadata().await?; + manager.metadata = Some(metadata); + + // Validate server supports the requested auth method + manager.validate_client_credentials_metadata(&config)?; + + // Configure OAuth client + manager.configure_client_credentials(&config)?; + + // Exchange credentials for token + manager.exchange_client_credentials(&config).await?; + + *self = OAuthState::Authorized(manager); + Ok(()) + } } #[cfg(test)] @@ -2792,4 +3190,130 @@ mod tests { "expected InternalError when OAuth client is not configured, got: {err:?}" ); } + + // -- client credentials (SEP-1046) -- + + #[tokio::test] + async fn configure_client_credentials_uses_request_body_auth_for_client_secret() { + let mut mgr = manager_with_metadata(None).await; + let config = super::ClientCredentialsConfig::ClientSecret { + client_id: "my-m2m-client".to_string(), + client_secret: "super-secret".to_string(), + scopes: vec!["read".to_string()], + resource: None, + }; + mgr.configure_client_credentials(&config).unwrap(); + let oauth_client = mgr.oauth_client.as_ref().unwrap(); + assert!(matches!(oauth_client.auth_type(), AuthType::RequestBody)); + } + + #[tokio::test] + async fn configure_client_credentials_sets_correct_client_id() { + let mut mgr = manager_with_metadata(None).await; + let config = super::ClientCredentialsConfig::ClientSecret { + client_id: "my-m2m-client".to_string(), + client_secret: "super-secret".to_string(), + scopes: vec!["read".to_string()], + resource: None, + }; + mgr.configure_client_credentials(&config).unwrap(); + let oauth_client = mgr.oauth_client.as_ref().unwrap(); + assert_eq!(oauth_client.client_id().as_str(), "my-m2m-client"); + } + + #[tokio::test] + async fn configure_client_credentials_returns_error_without_metadata() { + let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap(); + let config = super::ClientCredentialsConfig::ClientSecret { + client_id: "id".to_string(), + client_secret: "secret".to_string(), + scopes: vec![], + resource: None, + }; + let err = mgr.configure_client_credentials(&config).unwrap_err(); + assert!(matches!(err, AuthError::NoAuthorizationSupport)); + } + + #[tokio::test] + async fn validate_client_credentials_metadata_rejects_unsupported_method() { + let mut additional_fields = HashMap::new(); + additional_fields.insert( + "token_endpoint_auth_methods_supported".to_string(), + serde_json::json!(["client_secret_basic"]), + ); + let meta = AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + additional_fields, + ..Default::default() + }; + let mgr = manager_with_metadata(Some(meta)).await; + let config = super::ClientCredentialsConfig::ClientSecret { + client_id: "id".to_string(), + client_secret: "secret".to_string(), + scopes: vec![], + resource: None, + }; + let err = mgr + .validate_client_credentials_metadata(&config) + .unwrap_err(); + assert!(matches!(err, AuthError::ClientCredentialsError(_))); + } + + #[tokio::test] + async fn validate_client_credentials_metadata_accepts_supported_method() { + let mut additional_fields = HashMap::new(); + additional_fields.insert( + "token_endpoint_auth_methods_supported".to_string(), + serde_json::json!(["client_secret_post", "client_secret_basic"]), + ); + let meta = AuthorizationMetadata { + authorization_endpoint: "http://localhost/authorize".to_string(), + token_endpoint: "http://localhost/token".to_string(), + additional_fields, + ..Default::default() + }; + let mgr = manager_with_metadata(Some(meta)).await; + let config = super::ClientCredentialsConfig::ClientSecret { + client_id: "id".to_string(), + client_secret: "secret".to_string(), + scopes: vec![], + resource: None, + }; + mgr.validate_client_credentials_metadata(&config).unwrap(); + } + + #[tokio::test] + async fn validate_client_credentials_metadata_permits_when_field_absent() { + let mgr = manager_with_metadata(None).await; + let config = super::ClientCredentialsConfig::ClientSecret { + client_id: "id".to_string(), + client_secret: "secret".to_string(), + scopes: vec![], + resource: None, + }; + mgr.validate_client_credentials_metadata(&config).unwrap(); + } + + #[test] + fn client_credentials_config_returns_correct_accessor_values() { + let config = super::ClientCredentialsConfig::ClientSecret { + client_id: "test-id".to_string(), + client_secret: "test-secret".to_string(), + scopes: vec!["scope1".to_string(), "scope2".to_string()], + resource: Some("https://example.com".to_string()), + }; + assert_eq!(config.client_id(), "test-id"); + assert_eq!(config.scopes(), &["scope1", "scope2"]); + assert_eq!(config.resource(), Some("https://example.com")); + assert_eq!(config.auth_method(), "client_secret_post"); + } + + #[test] + fn extension_constant_matches_spec() { + assert_eq!( + super::EXTENSION_OAUTH_CLIENT_CREDENTIALS, + "io.modelcontextprotocol/oauth-client-credentials" + ); + } } diff --git a/crates/rmcp/tests/test_client_credentials.rs b/crates/rmcp/tests/test_client_credentials.rs new file mode 100644 index 00000000..3db98e03 --- /dev/null +++ b/crates/rmcp/tests/test_client_credentials.rs @@ -0,0 +1,197 @@ +use std::{convert::Infallible, net::SocketAddr}; + +use axum::{ + Router, + body::Body, + http::{Request, Response, StatusCode}, + routing::{get, post}, +}; +use rmcp::transport::auth::{ClientCredentialsConfig, OAuthState}; + +fn json_response(body: serde_json::Value) -> Response { + Response::builder() + .status(StatusCode::OK) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&body).unwrap())) + .unwrap() +} + +fn json_error(status: StatusCode, body: serde_json::Value) -> Response { + Response::builder() + .status(status) + .header("content-type", "application/json") + .body(Body::from(serde_json::to_string(&body).unwrap())) + .unwrap() +} + +async fn resource_metadata_handler(req: Request) -> Result, Infallible> { + let host = req.headers().get("host").unwrap().to_str().unwrap(); + let base_url = format!("http://{}", host); + Ok(json_response(serde_json::json!({ + "resource": base_url, + "authorization_servers": [base_url], + "scopes_supported": ["read", "write"] + }))) +} + +async fn auth_server_metadata_handler(req: Request) -> Result, Infallible> { + let host = req.headers().get("host").unwrap().to_str().unwrap(); + let base_url = format!("http://{}", host); + Ok(json_response(serde_json::json!({ + "issuer": base_url, + "authorization_endpoint": format!("{}/authorize", base_url), + "token_endpoint": format!("{}/token", base_url), + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + "grant_types_supported": ["client_credentials"], + "scopes_supported": ["read", "write"] + }))) +} + +async fn token_handler(req: Request) -> Result, Infallible> { + let body_bytes = axum::body::to_bytes(req.into_body(), 1024 * 64) + .await + .unwrap(); + let body_str = String::from_utf8(body_bytes.to_vec()).unwrap(); + + // Parse form-urlencoded body + let params: Vec<(String, String)> = url::form_urlencoded::parse(body_str.as_bytes()) + .into_owned() + .collect(); + + let get_param = |key: &str| -> Option { + params + .iter() + .find(|(k, _)| k == key) + .map(|(_, v)| v.clone()) + }; + + let grant_type = get_param("grant_type").unwrap_or_default(); + if grant_type != "client_credentials" { + return Ok(json_error( + StatusCode::BAD_REQUEST, + serde_json::json!({ + "error": "unsupported_grant_type", + "error_description": "Only client_credentials grant type is supported" + }), + )); + } + + let client_id = get_param("client_id").unwrap_or_default(); + if client_id != "test-m2m-client" { + return Ok(json_error( + StatusCode::UNAUTHORIZED, + serde_json::json!({ + "error": "invalid_client", + "error_description": "Unknown client_id" + }), + )); + } + + let client_secret = get_param("client_secret").unwrap_or_default(); + if client_secret != "test-m2m-secret" { + return Ok(json_error( + StatusCode::UNAUTHORIZED, + serde_json::json!({ + "error": "invalid_client", + "error_description": "Invalid client_secret" + }), + )); + } + + let scope = get_param("scope").unwrap_or_default(); + + let mut response = serde_json::json!({ + "access_token": "m2m-access-token-12345", + "token_type": "Bearer", + "expires_in": 3600 + }); + + if !scope.is_empty() { + response["scope"] = serde_json::Value::String(scope); + } + + Ok(json_response(response)) +} + +async fn start_mock_server() -> (String, SocketAddr) { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let base_url = format!("http://{}", addr); + + let app = Router::new() + .route( + "/.well-known/oauth-protected-resource", + get(resource_metadata_handler), + ) + .route( + "/.well-known/oauth-authorization-server", + get(auth_server_metadata_handler), + ) + .route("/token", post(token_handler)); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (base_url, addr) +} + +#[tokio::test] +async fn test_client_credentials_flow_client_secret() { + let (base_url, _addr) = start_mock_server().await; + + let mut oauth_state = OAuthState::new(&base_url, None).await.unwrap(); + + let config = ClientCredentialsConfig::ClientSecret { + client_id: "test-m2m-client".to_string(), + client_secret: "test-m2m-secret".to_string(), + scopes: vec!["read".to_string(), "write".to_string()], + resource: Some(base_url.clone()), + }; + + oauth_state + .authenticate_client_credentials(config) + .await + .unwrap(); + + let manager = oauth_state + .into_authorization_manager() + .expect("Should be in Authorized state"); + + let token = manager.get_access_token().await.unwrap(); + assert_eq!(token, "m2m-access-token-12345"); +} + +#[tokio::test] +async fn test_client_credentials_invalid_secret() { + let (base_url, _addr) = start_mock_server().await; + + let mut oauth_state = OAuthState::new(&base_url, None).await.unwrap(); + + let config = ClientCredentialsConfig::ClientSecret { + client_id: "test-m2m-client".to_string(), + client_secret: "wrong-secret".to_string(), + scopes: vec![], + resource: None, + }; + + let result = oauth_state.authenticate_client_credentials(config).await; + assert!(result.is_err(), "Should fail with invalid credentials"); +} + +#[tokio::test] +async fn test_client_credentials_invalid_client_id() { + let (base_url, _addr) = start_mock_server().await; + + let mut oauth_state = OAuthState::new(&base_url, None).await.unwrap(); + + let config = ClientCredentialsConfig::ClientSecret { + client_id: "unknown-client".to_string(), + client_secret: "test-m2m-secret".to_string(), + scopes: vec![], + resource: None, + }; + + let result = oauth_state.authenticate_client_credentials(config).await; + assert!(result.is_err(), "Should fail with unknown client_id"); +} diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index ea35b021..f2ac8dd7 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -57,3 +57,7 @@ path = "src/sampling_stdio.rs" [[example]] name = "clients_progress_client" path = "src/progress_client.rs" + +[[example]] +name = "clients_client_credentials" +path = "src/auth/client_credentials.rs" diff --git a/examples/clients/src/auth/client_credentials.rs b/examples/clients/src/auth/client_credentials.rs new file mode 100644 index 00000000..55aa6153 --- /dev/null +++ b/examples/clients/src/auth/client_credentials.rs @@ -0,0 +1,97 @@ +use std::env; + +use anyhow::{Context, Result}; +use rmcp::{ + ServiceExt, + model::ClientInfo, + transport::{ + StreamableHttpClientTransport, + auth::{AuthClient, ClientCredentialsConfig, OAuthState}, + streamable_http_client::StreamableHttpClientTransportConfig, + }, +}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +/// Example: OAuth 2.0 Client Credentials flow (SEP-1046) +/// +/// Usage: +/// cargo run -p mcp-client-examples --example clients_client_credentials -- +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let args: Vec = env::args().collect(); + let server_url = args + .get(1) + .context("Usage: ")? + .clone(); + let client_id = args + .get(2) + .context("Usage: ")? + .clone(); + let client_secret = args + .get(3) + .context("Usage: ")? + .clone(); + + tracing::info!("Connecting to MCP server at: {}", server_url); + tracing::info!("Using client_id: {}", client_id); + + // Initialize OAuth state and authenticate with client credentials + let mut oauth_state = OAuthState::new(&server_url, None) + .await + .context("Failed to initialize OAuth state")?; + + let config = ClientCredentialsConfig::ClientSecret { + client_id, + client_secret, + scopes: vec![], + resource: Some(server_url.clone()), + }; + + oauth_state + .authenticate_client_credentials(config) + .await + .context("Client credentials authentication failed")?; + + tracing::info!("Successfully authenticated with client credentials"); + + // Create authorized transport + let manager = oauth_state + .into_authorization_manager() + .context("Failed to get authorization manager")?; + let client = AuthClient::new(reqwest::Client::default(), manager); + let transport = StreamableHttpClientTransport::with_client( + client, + StreamableHttpClientTransportConfig::with_uri(server_url.as_str()), + ); + + // Connect to MCP server and list tools + let client_service = ClientInfo::default(); + let client = client_service.serve(transport).await?; + tracing::info!("Connected to MCP server"); + + match client.peer().list_all_tools().await { + Ok(tools) => { + println!("Available tools ({}):", tools.len()); + for tool in tools { + println!( + " - {} ({})", + tool.name, + tool.description.unwrap_or_default() + ); + } + } + Err(e) => { + tracing::error!("Failed to list tools: {}", e); + } + } + + Ok(()) +}