Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/refresh-long-running-helper-tokens.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@googleworkspace/cli": patch
---

Refresh OAuth access tokens for long-running Gmail watch and Workspace Events subscribe helpers before each Pub/Sub and Gmail request.
80 changes: 80 additions & 0 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,73 @@ enum Credential {
ServiceAccount(yup_oauth2::ServiceAccountKey),
}

/// Fetches access tokens for a fixed set of scopes.
///
/// Long-running helpers use this trait so they can request a fresh token before
/// each API call instead of holding a single token string until it expires.
#[async_trait::async_trait]
pub trait AccessTokenProvider: Send + Sync {
async fn access_token(&self) -> anyhow::Result<String>;
}

/// A token provider backed by [`get_token`].
///
/// This keeps the scope list in one place so call sites can ask for a fresh
/// token whenever they need to make another request.
#[derive(Debug, Clone)]
pub struct ScopedTokenProvider {
scopes: Vec<String>,
}

impl ScopedTokenProvider {
pub fn new(scopes: &[&str]) -> Self {
Self {
scopes: scopes.iter().map(|scope| (*scope).to_string()).collect(),
}
}
}

#[async_trait::async_trait]
impl AccessTokenProvider for ScopedTokenProvider {
async fn access_token(&self) -> anyhow::Result<String> {
let scopes: Vec<&str> = self.scopes.iter().map(String::as_str).collect();
get_token(&scopes).await
}
}

pub fn token_provider(scopes: &[&str]) -> ScopedTokenProvider {
ScopedTokenProvider::new(scopes)
}

/// A fake [`AccessTokenProvider`] for tests that returns tokens from a queue.
#[cfg(test)]
pub struct FakeTokenProvider {
tokens: std::sync::Arc<tokio::sync::Mutex<std::collections::VecDeque<String>>>,
}

#[cfg(test)]
impl FakeTokenProvider {
pub fn new(tokens: impl IntoIterator<Item = &'static str>) -> Self {
Self {
tokens: std::sync::Arc::new(tokio::sync::Mutex::new(
tokens.into_iter().map(|t| t.to_string()).collect(),
)),
}
}
}

#[cfg(test)]
#[async_trait::async_trait]
impl AccessTokenProvider for FakeTokenProvider {
async fn access_token(&self) -> anyhow::Result<String> {
self.tokens
.lock()
.await
.pop_front()
.ok_or_else(|| anyhow::anyhow!("no test token remaining"))
}
}

/// Builds an OAuth2 authenticator and returns an access token.
///
/// Tries credentials in order:
Expand Down Expand Up @@ -544,6 +611,19 @@ mod tests {
assert_eq!(result.unwrap(), "my-test-token");
}

#[tokio::test]
#[serial_test::serial]
async fn test_scoped_token_provider_uses_get_token() {
let _token_guard = EnvVarGuard::set("GOOGLE_WORKSPACE_CLI_TOKEN", "provider-token");
let provider = token_provider(&["https://www.googleapis.com/auth/drive"]);

let first = provider.access_token().await.unwrap();
let second = provider.access_token().await.unwrap();

assert_eq!(first, "provider-token");
assert_eq!(second, "provider-token");
}

#[tokio::test]
async fn test_load_credentials_encrypted_file() {
// Simulate an encrypted credentials file
Expand Down
164 changes: 143 additions & 21 deletions src/helpers/events/subscribe.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::*;
use crate::auth::AccessTokenProvider;
use std::path::PathBuf;

const PUBSUB_API_BASE: &str = "https://pubsub.googleapis.com/v1";

#[derive(Debug, Clone, Default, Builder)]
#[builder(setter(into))]
pub struct SubscribeConfig {
Expand Down Expand Up @@ -110,6 +113,7 @@ pub(super) async fn handle_subscribe(
}

let client = crate::client::build_client()?;
let pubsub_token_provider = auth::token_provider(&[PUBSUB_SCOPE]);

// Get Pub/Sub token
let pubsub_token = auth::get_token(&[PUBSUB_SCOPE])
Expand Down Expand Up @@ -248,29 +252,38 @@ pub(super) async fn handle_subscribe(
};

// Pull loop
let result = pull_loop(&client, &pubsub_token, &pubsub_subscription, config.clone()).await;
let result = pull_loop(
&client,
&pubsub_token_provider,
&pubsub_subscription,
config.clone(),
PUBSUB_API_BASE,
)
.await;

// On exit, print reconnection info or cleanup
if created_resources {
if config.cleanup {
eprintln!("\nCleaning up Pub/Sub resources...");
// Delete Pub/Sub subscription
let _ = client
.delete(format!(
"https://pubsub.googleapis.com/v1/{pubsub_subscription}"
))
.bearer_auth(&pubsub_token)
.send()
.await;
// Delete Pub/Sub topic
if let Some(ref topic) = topic_name {
if let Ok(pubsub_token) = pubsub_token_provider.access_token().await {
let _ = client
.delete(format!("https://pubsub.googleapis.com/v1/{topic}"))
.delete(format!("{PUBSUB_API_BASE}/{pubsub_subscription}"))
.bearer_auth(&pubsub_token)
.send()
.await;
// Delete Pub/Sub topic
if let Some(ref topic) = topic_name {
let _ = client
.delete(format!("{PUBSUB_API_BASE}/{topic}"))
.bearer_auth(&pubsub_token)
.send()
.await;
}
eprintln!("Cleanup complete.");
} else {
eprintln!("Warning: failed to refresh token for cleanup. Resources may need manual deletion.");
}
eprintln!("Cleanup complete.");
} else {
eprintln!("\n--- Reconnection Info ---");
eprintln!(
Expand Down Expand Up @@ -301,21 +314,24 @@ pub(super) async fn handle_subscribe(
/// Pulls messages from a Pub/Sub subscription in a loop.
async fn pull_loop(
client: &reqwest::Client,
token: &str,
token_provider: &dyn auth::AccessTokenProvider,
subscription: &str,
config: SubscribeConfig,
pubsub_api_base: &str,
) -> Result<(), GwsError> {
let mut file_counter: u64 = 0;
loop {
let token = token_provider
.access_token()
.await
.map_err(|e| GwsError::Auth(format!("Failed to get Pub/Sub token: {e}")))?;
let pull_body = json!({
"maxMessages": config.max_messages,
});

let pull_future = client
.post(format!(
"https://pubsub.googleapis.com/v1/{subscription}:pull"
))
.bearer_auth(token)
.post(format!("{pubsub_api_base}/{subscription}:pull"))
.bearer_auth(&token)
.header("Content-Type", "application/json")
.json(&pull_body)
.timeout(std::time::Duration::from_secs(config.poll_interval.max(10)))
Expand Down Expand Up @@ -379,10 +395,8 @@ async fn pull_loop(
});

let _ = client
.post(format!(
"https://pubsub.googleapis.com/v1/{subscription}:acknowledge"
))
.bearer_auth(token)
.post(format!("{pubsub_api_base}/{subscription}:acknowledge"))
.bearer_auth(&token)
.header("Content-Type", "application/json")
.json(&ack_body)
.send()
Expand Down Expand Up @@ -526,6 +540,76 @@ fn derive_slug_from_event_types(event_types: &[&str]) -> String {
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::FakeTokenProvider;
use base64::Engine as _;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;

async fn spawn_subscribe_server() -> (
String,
Arc<Mutex<Vec<(String, String)>>>,
tokio::task::JoinHandle<()>,
) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let requests = Arc::new(Mutex::new(Vec::new()));
let recorded_requests = Arc::clone(&requests);

let handle = tokio::spawn(async move {
for _ in 0..2 {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = [0_u8; 8192];
let bytes_read = stream.read(&mut buf).await.unwrap();
let request = String::from_utf8_lossy(&buf[..bytes_read]);
let path = request
.lines()
.next()
.and_then(|line| line.split_whitespace().nth(1))
.unwrap_or("")
.to_string();
let auth_header = request
.lines()
.find(|line| line.to_ascii_lowercase().starts_with("authorization:"))
.unwrap_or("")
.trim()
.to_string();
recorded_requests
.lock()
.await
.push((path.clone(), auth_header));

let body = match path.as_str() {
"/v1/projects/test/subscriptions/demo:pull" => json!({
"receivedMessages": [{
"ackId": "ack-1",
"message": {
"attributes": {
"type": "google.workspace.chat.message.v1.created",
"source": "//chat/spaces/A"
},
"data": base64::engine::general_purpose::STANDARD
.encode(json!({ "id": "evt-1" }).to_string())
}
}]
})
.to_string(),
"/v1/projects/test/subscriptions/demo:acknowledge" => json!({}).to_string(),
other => panic!("unexpected request path: {other}"),
};

let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nConnection: close\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
stream.write_all(response.as_bytes()).await.unwrap();
}
});

(format!("http://{addr}/v1"), requests, handle)
}

fn make_matches_subscribe(args: &[&str]) -> ArgMatches {
let cmd = Command::new("test")
Expand Down Expand Up @@ -753,4 +837,42 @@ mod tests {
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("--project is required"));
}

#[tokio::test]
async fn test_pull_loop_refreshes_pubsub_token_between_requests() {
let client = reqwest::Client::new();
let token_provider = FakeTokenProvider::new(["pubsub-token"]);
let (pubsub_base, requests, server) = spawn_subscribe_server().await;
let config = SubscribeConfigBuilder::default()
.subscription(Some(SubscriptionName(
"projects/test/subscriptions/demo".to_string(),
)))
.max_messages(1_u32)
.poll_interval(1_u64)
.once(true)
.build()
.unwrap();

pull_loop(
&client,
&token_provider,
"projects/test/subscriptions/demo",
config,
&pubsub_base,
)
.await
.unwrap();

server.await.unwrap();

let requests = requests.lock().await;
assert_eq!(requests.len(), 2);
assert_eq!(requests[0].0, "/v1/projects/test/subscriptions/demo:pull");
assert_eq!(requests[0].1, "authorization: Bearer pubsub-token");
assert_eq!(
requests[1].0,
"/v1/projects/test/subscriptions/demo:acknowledge"
);
assert_eq!(requests[1].1, "authorization: Bearer pubsub-token");
}
}
Loading
Loading