diff --git a/crates/socket-patch-cli/src/commands/apply.rs b/crates/socket-patch-cli/src/commands/apply.rs index 24aeb7d..8b2ad29 100644 --- a/crates/socket-patch-cli/src/commands/apply.rs +++ b/crates/socket-patch-cli/src/commands/apply.rs @@ -49,8 +49,9 @@ pub struct ApplyArgs { } pub async fn run(args: ApplyArgs) -> i32 { - let api_token = std::env::var("SOCKET_API_TOKEN").ok(); - let org_slug = std::env::var("SOCKET_ORG_SLUG").ok(); + let (telemetry_client, _) = get_api_client_from_env(None).await; + let api_token = telemetry_client.api_token().cloned(); + let org_slug = telemetry_client.org_slug().cloned(); let manifest_path = if Path::new(&args.manifest_path).is_absolute() { PathBuf::from(&args.manifest_path) @@ -156,7 +157,7 @@ async fn apply_patches_inner( println!("Downloading {} missing blob(s)...", missing_blobs.len()); } - let (client, _) = get_api_client_from_env(None); + let (client, _) = get_api_client_from_env(None).await; let fetch_result = fetch_missing_blobs(&manifest, &blobs_path, &client, None).await; if !args.silent { diff --git a/crates/socket-patch-cli/src/commands/get.rs b/crates/socket-patch-cli/src/commands/get.rs index 30987be..1b68096 100644 --- a/crates/socket-patch-cli/src/commands/get.rs +++ b/crates/socket-patch-cli/src/commands/get.rs @@ -121,17 +121,12 @@ pub async fn run(args: GetArgs) -> i32 { std::env::set_var("SOCKET_API_TOKEN", token); } - let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref()); + let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref()).await; - if !use_public_proxy && args.org.is_none() { - eprintln!("Error: --org is required when using SOCKET_API_TOKEN. Provide an organization slug."); - return 1; - } - - let effective_org_slug = if use_public_proxy { + let effective_org_slug: Option<&str> = if use_public_proxy { None } else { - args.org.as_deref() + None // org slug is already stored in the client }; // Determine identifier type @@ -517,12 +512,8 @@ async fn save_and_apply_patch( _org_slug: Option<&str>, ) -> i32 { // For UUID mode, fetch and save - let (api_client, _) = get_api_client_from_env(args.org.as_deref()); - let effective_org = if args.org.is_some() { - args.org.as_deref() - } else { - None - }; + let (api_client, _) = get_api_client_from_env(args.org.as_deref()).await; + let effective_org: Option<&str> = None; // org slug is already stored in the client let patch = match api_client.fetch_patch(effective_org, uuid).await { Ok(Some(p)) => p, diff --git a/crates/socket-patch-cli/src/commands/remove.rs b/crates/socket-patch-cli/src/commands/remove.rs index f05379a..444fa15 100644 --- a/crates/socket-patch-cli/src/commands/remove.rs +++ b/crates/socket-patch-cli/src/commands/remove.rs @@ -35,8 +35,10 @@ pub struct RemoveArgs { } pub async fn run(args: RemoveArgs) -> i32 { - let api_token = std::env::var("SOCKET_API_TOKEN").ok(); - let org_slug = std::env::var("SOCKET_ORG_SLUG").ok(); + let (telemetry_client, _) = + socket_patch_core::api::client::get_api_client_from_env(None).await; + let api_token = telemetry_client.api_token().cloned(); + let org_slug = telemetry_client.org_slug().cloned(); let manifest_path = if Path::new(&args.manifest_path).is_absolute() { PathBuf::from(&args.manifest_path) diff --git a/crates/socket-patch-cli/src/commands/repair.rs b/crates/socket-patch-cli/src/commands/repair.rs index 581b608..c783cd5 100644 --- a/crates/socket-patch-cli/src/commands/repair.rs +++ b/crates/socket-patch-cli/src/commands/repair.rs @@ -78,7 +78,7 @@ async fn repair_inner(args: &RepairArgs, manifest_path: &Path) -> Result<(), Str } } else { println!("\nDownloading missing blobs..."); - let (client, _) = get_api_client_from_env(None); + let (client, _) = get_api_client_from_env(None).await; let fetch_result = fetch_missing_blobs(&manifest, &blobs_path, &client, None).await; println!("{}", format_fetch_result(&fetch_result)); } diff --git a/crates/socket-patch-cli/src/commands/rollback.rs b/crates/socket-patch-cli/src/commands/rollback.rs index c09f161..e19817c 100644 --- a/crates/socket-patch-cli/src/commands/rollback.rs +++ b/crates/socket-patch-cli/src/commands/rollback.rs @@ -136,22 +136,7 @@ async fn get_missing_before_blobs( } pub async fn run(args: RollbackArgs) -> i32 { - let api_token = args - .api_token - .clone() - .or_else(|| std::env::var("SOCKET_API_TOKEN").ok()); - let org_slug = args - .org - .clone() - .or_else(|| std::env::var("SOCKET_ORG_SLUG").ok()); - - // Validate one-off requires identifier - if args.one_off && args.identifier.is_none() { - eprintln!("Error: --one-off requires an identifier (UUID or PURL)"); - return 1; - } - - // Override env vars if CLI options provided + // Override env vars if CLI options provided (before building client) if let Some(ref url) = args.api_url { std::env::set_var("SOCKET_API_URL", url); } @@ -159,6 +144,16 @@ pub async fn run(args: RollbackArgs) -> i32 { std::env::set_var("SOCKET_API_TOKEN", token); } + let (telemetry_client, _) = get_api_client_from_env(args.org.as_deref()).await; + let api_token = telemetry_client.api_token().cloned(); + let org_slug = telemetry_client.org_slug().cloned(); + + // Validate one-off requires identifier + if args.one_off && args.identifier.is_none() { + eprintln!("Error: --one-off requires an identifier (UUID or PURL)"); + return 1; + } + // Handle one-off mode if args.one_off { // One-off mode not fully implemented yet - placeholder @@ -314,7 +309,7 @@ async fn rollback_patches_inner( println!("Downloading {} missing blob(s)...", missing_blobs.len()); } - let (client, _) = get_api_client_from_env(None); + let (client, _) = get_api_client_from_env(None).await; let fetch_result = fetch_blobs_by_hash(&missing_blobs, &blobs_path, &client, None).await; if !args.silent { diff --git a/crates/socket-patch-cli/src/commands/scan.rs b/crates/socket-patch-cli/src/commands/scan.rs index 81478aa..3d5ef49 100644 --- a/crates/socket-patch-cli/src/commands/scan.rs +++ b/crates/socket-patch-cli/src/commands/scan.rs @@ -51,17 +51,12 @@ pub async fn run(args: ScanArgs) -> i32 { std::env::set_var("SOCKET_API_TOKEN", token); } - let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref()); + let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref()).await; - if !use_public_proxy && args.org.is_none() { - eprintln!("Error: --org is required when using SOCKET_API_TOKEN. Provide an organization slug."); - return 1; - } - - let effective_org_slug = if use_public_proxy { + let effective_org_slug: Option<&str> = if use_public_proxy { None } else { - args.org.as_deref() + None // org slug is already stored in the client }; let crawler_options = CrawlerOptions { diff --git a/crates/socket-patch-core/src/api/client.rs b/crates/socket-patch-core/src/api/client.rs index e1757e8..72a0fae 100644 --- a/crates/socket-patch-core/src/api/client.rs +++ b/crates/socket-patch-core/src/api/client.rs @@ -111,6 +111,16 @@ impl ApiClient { } } + /// Returns the API token, if set. + pub fn api_token(&self) -> Option<&String> { + self.api_token.as_ref() + } + + /// Returns the org slug, if set. + pub fn org_slug(&self) -> Option<&String> { + self.org_slug.as_ref() + } + // ── Internal helpers ────────────────────────────────────────────── /// Internal GET that deserialises JSON. Returns `Ok(None)` on 404. @@ -397,6 +407,46 @@ impl ApiClient { }) } + /// Fetch organizations accessible to the current API token. + pub async fn fetch_organizations( + &self, + ) -> Result, ApiError> { + let path = "/v0/organizations"; + match self + .get_json::(path) + .await? + { + Some(resp) => Ok(resp.organizations.into_values().collect()), + None => Ok(Vec::new()), + } + } + + /// Resolve the org slug from the API token by querying `/v0/organizations`. + /// + /// If there is exactly one org, returns its slug. + /// If there are multiple, picks the first and prints a warning. + /// If there are none, returns an error. + pub async fn resolve_org_slug(&self) -> Result { + let orgs = self.fetch_organizations().await?; + match orgs.len() { + 0 => Err(ApiError::Other( + "No organizations found for this API token.".into(), + )), + 1 => Ok(orgs.into_iter().next().unwrap().slug), + _ => { + let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect(); + let first = orgs[0].slug.clone(); + eprintln!( + "Multiple organizations found: {}. Using \"{}\". \ + Pass --org to select a different one.", + slugs.join(", "), + first + ); + Ok(first) + } + } + } + /// Fetch a blob by its SHA-256 hash. /// /// Returns the raw binary content, or `Ok(None)` if not found. @@ -490,6 +540,10 @@ impl ApiClient { /// API proxy which provides free access to free-tier patches without /// authentication. /// +/// When `SOCKET_API_TOKEN` is set but no org slug is provided (neither via +/// argument nor `SOCKET_ORG_SLUG` env var), the function will attempt to +/// auto-resolve the org slug by querying `GET /v0/organizations`. +/// /// # Environment variables /// /// | Variable | Purpose | @@ -500,7 +554,7 @@ impl ApiClient { /// | `SOCKET_ORG_SLUG` | Organization slug | /// /// Returns `(client, use_public_proxy)`. -pub fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) { +pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) { let api_token = std::env::var("SOCKET_API_TOKEN").ok(); let resolved_org_slug = org_slug .map(String::from) @@ -524,11 +578,30 @@ pub fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) { let api_url = std::env::var("SOCKET_API_URL").unwrap_or_else(|_| DEFAULT_SOCKET_API_URL.to_string()); + // Auto-resolve org slug if not provided + let final_org_slug = if resolved_org_slug.is_some() { + resolved_org_slug + } else { + let temp_client = ApiClient::new(ApiClientOptions { + api_url: api_url.clone(), + api_token: api_token.clone(), + use_public_proxy: false, + org_slug: None, + }); + match temp_client.resolve_org_slug().await { + Ok(slug) => Some(slug), + Err(e) => { + eprintln!("Warning: Could not auto-detect organization: {e}"); + None + } + } + }; + let client = ApiClient::new(ApiClientOptions { api_url, api_token, use_public_proxy: false, - org_slug: resolved_org_slug, + org_slug: final_org_slug, }); (client, false) } @@ -714,11 +787,11 @@ mod tests { assert_eq!(info.title, "Test vulnerability"); } - #[test] - fn test_get_api_client_from_env_no_token() { + #[tokio::test] + async fn test_get_api_client_from_env_no_token() { // Clear token to ensure public proxy mode std::env::remove_var("SOCKET_API_TOKEN"); - let (client, is_public) = get_api_client_from_env(None); + let (client, is_public) = get_api_client_from_env(None).await; assert!(is_public); assert!(client.use_public_proxy); } diff --git a/crates/socket-patch-core/src/api/types.rs b/crates/socket-patch-core/src/api/types.rs index 688bc7c..f09c31d 100644 --- a/crates/socket-patch-core/src/api/types.rs +++ b/crates/socket-patch-core/src/api/types.rs @@ -1,6 +1,22 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; +/// Organization info returned by the `/v0/organizations` endpoint. +#[derive(Debug, Clone, Deserialize)] +pub struct OrganizationInfo { + pub id: String, + pub name: Option, + pub image: Option, + pub plan: String, + pub slug: String, +} + +/// Response from `GET /v0/organizations`. +#[derive(Debug, Clone, Deserialize)] +pub struct OrganizationsResponse { + pub organizations: HashMap, +} + /// Full patch response with blob content (from view endpoint). #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")]