Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions crates/socket-patch-cli/src/commands/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ pub struct ApplyArgs {
}

pub async fn run(args: ApplyArgs) -> i32 {
let api_token = std::env::var("SOCKET_API_TOKEN").ok();
let org_slug = std::env::var("SOCKET_ORG_SLUG").ok();
let (telemetry_client, _) = get_api_client_from_env(None).await;
let api_token = telemetry_client.api_token().cloned();
let org_slug = telemetry_client.org_slug().cloned();

let manifest_path = if Path::new(&args.manifest_path).is_absolute() {
PathBuf::from(&args.manifest_path)
Expand Down Expand Up @@ -156,7 +157,7 @@ async fn apply_patches_inner(
println!("Downloading {} missing blob(s)...", missing_blobs.len());
}

let (client, _) = get_api_client_from_env(None);
let (client, _) = get_api_client_from_env(None).await;
let fetch_result = fetch_missing_blobs(&manifest, &blobs_path, &client, None).await;

if !args.silent {
Expand Down
19 changes: 5 additions & 14 deletions crates/socket-patch-cli/src/commands/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,12 @@ pub async fn run(args: GetArgs) -> i32 {
std::env::set_var("SOCKET_API_TOKEN", token);
}

let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref());
let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref()).await;

if !use_public_proxy && args.org.is_none() {
eprintln!("Error: --org is required when using SOCKET_API_TOKEN. Provide an organization slug.");
return 1;
}

let effective_org_slug = if use_public_proxy {
let effective_org_slug: Option<&str> = if use_public_proxy {
None
} else {
args.org.as_deref()
None // org slug is already stored in the client
};

// Determine identifier type
Expand Down Expand Up @@ -517,12 +512,8 @@ async fn save_and_apply_patch(
_org_slug: Option<&str>,
) -> i32 {
// For UUID mode, fetch and save
let (api_client, _) = get_api_client_from_env(args.org.as_deref());
let effective_org = if args.org.is_some() {
args.org.as_deref()
} else {
None
};
let (api_client, _) = get_api_client_from_env(args.org.as_deref()).await;
let effective_org: Option<&str> = None; // org slug is already stored in the client

let patch = match api_client.fetch_patch(effective_org, uuid).await {
Ok(Some(p)) => p,
Expand Down
6 changes: 4 additions & 2 deletions crates/socket-patch-cli/src/commands/remove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ pub struct RemoveArgs {
}

pub async fn run(args: RemoveArgs) -> i32 {
let api_token = std::env::var("SOCKET_API_TOKEN").ok();
let org_slug = std::env::var("SOCKET_ORG_SLUG").ok();
let (telemetry_client, _) =
socket_patch_core::api::client::get_api_client_from_env(None).await;
let api_token = telemetry_client.api_token().cloned();
let org_slug = telemetry_client.org_slug().cloned();

let manifest_path = if Path::new(&args.manifest_path).is_absolute() {
PathBuf::from(&args.manifest_path)
Expand Down
2 changes: 1 addition & 1 deletion crates/socket-patch-cli/src/commands/repair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async fn repair_inner(args: &RepairArgs, manifest_path: &Path) -> Result<(), Str
}
} else {
println!("\nDownloading missing blobs...");
let (client, _) = get_api_client_from_env(None);
let (client, _) = get_api_client_from_env(None).await;
let fetch_result = fetch_missing_blobs(&manifest, &blobs_path, &client, None).await;
println!("{}", format_fetch_result(&fetch_result));
}
Expand Down
29 changes: 12 additions & 17 deletions crates/socket-patch-cli/src/commands/rollback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,29 +136,24 @@ async fn get_missing_before_blobs(
}

pub async fn run(args: RollbackArgs) -> i32 {
let api_token = args
.api_token
.clone()
.or_else(|| std::env::var("SOCKET_API_TOKEN").ok());
let org_slug = args
.org
.clone()
.or_else(|| std::env::var("SOCKET_ORG_SLUG").ok());

// Validate one-off requires identifier
if args.one_off && args.identifier.is_none() {
eprintln!("Error: --one-off requires an identifier (UUID or PURL)");
return 1;
}

// Override env vars if CLI options provided
// Override env vars if CLI options provided (before building client)
if let Some(ref url) = args.api_url {
std::env::set_var("SOCKET_API_URL", url);
}
if let Some(ref token) = args.api_token {
std::env::set_var("SOCKET_API_TOKEN", token);
}

let (telemetry_client, _) = get_api_client_from_env(args.org.as_deref()).await;
let api_token = telemetry_client.api_token().cloned();
let org_slug = telemetry_client.org_slug().cloned();

// Validate one-off requires identifier
if args.one_off && args.identifier.is_none() {
eprintln!("Error: --one-off requires an identifier (UUID or PURL)");
return 1;
}

// Handle one-off mode
if args.one_off {
// One-off mode not fully implemented yet - placeholder
Expand Down Expand Up @@ -314,7 +309,7 @@ async fn rollback_patches_inner(
println!("Downloading {} missing blob(s)...", missing_blobs.len());
}

let (client, _) = get_api_client_from_env(None);
let (client, _) = get_api_client_from_env(None).await;
let fetch_result = fetch_blobs_by_hash(&missing_blobs, &blobs_path, &client, None).await;

if !args.silent {
Expand Down
11 changes: 3 additions & 8 deletions crates/socket-patch-cli/src/commands/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,12 @@ pub async fn run(args: ScanArgs) -> i32 {
std::env::set_var("SOCKET_API_TOKEN", token);
}

let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref());
let (api_client, use_public_proxy) = get_api_client_from_env(args.org.as_deref()).await;

if !use_public_proxy && args.org.is_none() {
eprintln!("Error: --org is required when using SOCKET_API_TOKEN. Provide an organization slug.");
return 1;
}

let effective_org_slug = if use_public_proxy {
let effective_org_slug: Option<&str> = if use_public_proxy {
None
} else {
args.org.as_deref()
None // org slug is already stored in the client
};

let crawler_options = CrawlerOptions {
Expand Down
83 changes: 78 additions & 5 deletions crates/socket-patch-core/src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ impl ApiClient {
}
}

/// Returns the API token, if set.
pub fn api_token(&self) -> Option<&String> {
self.api_token.as_ref()
}

/// Returns the org slug, if set.
pub fn org_slug(&self) -> Option<&String> {
self.org_slug.as_ref()
}

// ── Internal helpers ──────────────────────────────────────────────

/// Internal GET that deserialises JSON. Returns `Ok(None)` on 404.
Expand Down Expand Up @@ -397,6 +407,46 @@ impl ApiClient {
})
}

/// Fetch organizations accessible to the current API token.
pub async fn fetch_organizations(
&self,
) -> Result<Vec<crate::api::types::OrganizationInfo>, ApiError> {
let path = "/v0/organizations";
match self
.get_json::<crate::api::types::OrganizationsResponse>(path)
.await?
{
Some(resp) => Ok(resp.organizations.into_values().collect()),
None => Ok(Vec::new()),
}
}

/// Resolve the org slug from the API token by querying `/v0/organizations`.
///
/// If there is exactly one org, returns its slug.
/// If there are multiple, picks the first and prints a warning.
/// If there are none, returns an error.
pub async fn resolve_org_slug(&self) -> Result<String, ApiError> {
let orgs = self.fetch_organizations().await?;
match orgs.len() {
0 => Err(ApiError::Other(
"No organizations found for this API token.".into(),
)),
1 => Ok(orgs.into_iter().next().unwrap().slug),
_ => {
let slugs: Vec<_> = orgs.iter().map(|o| o.slug.as_str()).collect();
let first = orgs[0].slug.clone();
eprintln!(
"Multiple organizations found: {}. Using \"{}\". \
Pass --org to select a different one.",
slugs.join(", "),
first
);
Ok(first)
}
}
}

/// Fetch a blob by its SHA-256 hash.
///
/// Returns the raw binary content, or `Ok(None)` if not found.
Expand Down Expand Up @@ -490,6 +540,10 @@ impl ApiClient {
/// API proxy which provides free access to free-tier patches without
/// authentication.
///
/// When `SOCKET_API_TOKEN` is set but no org slug is provided (neither via
/// argument nor `SOCKET_ORG_SLUG` env var), the function will attempt to
/// auto-resolve the org slug by querying `GET /v0/organizations`.
///
/// # Environment variables
///
/// | Variable | Purpose |
Expand All @@ -500,7 +554,7 @@ impl ApiClient {
/// | `SOCKET_ORG_SLUG` | Organization slug |
///
/// Returns `(client, use_public_proxy)`.
pub fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
pub async fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
let api_token = std::env::var("SOCKET_API_TOKEN").ok();
let resolved_org_slug = org_slug
.map(String::from)
Expand All @@ -524,11 +578,30 @@ pub fn get_api_client_from_env(org_slug: Option<&str>) -> (ApiClient, bool) {
let api_url =
std::env::var("SOCKET_API_URL").unwrap_or_else(|_| DEFAULT_SOCKET_API_URL.to_string());

// Auto-resolve org slug if not provided
let final_org_slug = if resolved_org_slug.is_some() {
resolved_org_slug
} else {
let temp_client = ApiClient::new(ApiClientOptions {
api_url: api_url.clone(),
api_token: api_token.clone(),
use_public_proxy: false,
org_slug: None,
});
match temp_client.resolve_org_slug().await {
Ok(slug) => Some(slug),
Err(e) => {
eprintln!("Warning: Could not auto-detect organization: {e}");
None
}
}
};

let client = ApiClient::new(ApiClientOptions {
api_url,
api_token,
use_public_proxy: false,
org_slug: resolved_org_slug,
org_slug: final_org_slug,
});
(client, false)
}
Expand Down Expand Up @@ -714,11 +787,11 @@ mod tests {
assert_eq!(info.title, "Test vulnerability");
}

#[test]
fn test_get_api_client_from_env_no_token() {
#[tokio::test]
async fn test_get_api_client_from_env_no_token() {
// Clear token to ensure public proxy mode
std::env::remove_var("SOCKET_API_TOKEN");
let (client, is_public) = get_api_client_from_env(None);
let (client, is_public) = get_api_client_from_env(None).await;
assert!(is_public);
assert!(client.use_public_proxy);
}
Expand Down
16 changes: 16 additions & 0 deletions crates/socket-patch-core/src/api/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Organization info returned by the `/v0/organizations` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct OrganizationInfo {
pub id: String,
pub name: Option<String>,
pub image: Option<String>,
pub plan: String,
pub slug: String,
}

/// Response from `GET /v0/organizations`.
#[derive(Debug, Clone, Deserialize)]
pub struct OrganizationsResponse {
pub organizations: HashMap<String, OrganizationInfo>,
}

/// Full patch response with blob content (from view endpoint).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down
Loading