diff --git a/pgdog/src/frontend/client/query_engine/test/set_schema_sharding.rs b/pgdog/src/frontend/client/query_engine/test/set_schema_sharding.rs index 671c7f977..c7bcdf7a3 100644 --- a/pgdog/src/frontend/client/query_engine/test/set_schema_sharding.rs +++ b/pgdog/src/frontend/client/query_engine/test/set_schema_sharding.rs @@ -37,3 +37,27 @@ async fn test_set_works_cross_shard_disabled() { let reply = client.read_until('Z').await.unwrap(); assert_eq!(reply.len(), 2); } + +#[tokio::test] +async fn test_ambiguous_schema_sharded_query_errors_when_cross_shard_disabled() { + let table = "schema_shard_ambiguous_test"; + + let mut setup = TestClient::new_sharded(Parameters::default()).await; + for stmt in [ + "CREATE SCHEMA IF NOT EXISTS acustomer".to_string(), + "CREATE SCHEMA IF NOT EXISTS bcustomer".to_string(), + format!("CREATE TABLE IF NOT EXISTS acustomer.{table} (id INT)"), + format!("CREATE TABLE IF NOT EXISTS bcustomer.{table} (id INT)"), + ] { + setup.send_simple(Query::new(&stmt)).await; + setup.read_until('Z').await.unwrap(); + } + + let mut client = TestClient::new_cross_shard_disabled(Parameters::default()).await; + client + .send_simple(Query::new(&format!("SELECT * FROM {table}"))) + .await; + let err = client.read_until('Z').await.unwrap_err(); + assert_eq!(err.code, "58000"); + assert_eq!(err.message, "cross-shard queries are disabled"); +} diff --git a/pgdog/src/frontend/router/parser/query/delete.rs b/pgdog/src/frontend/router/parser/query/delete.rs index ddba1b9e3..6a9dea93d 100644 --- a/pgdog/src/frontend/router/parser/query/delete.rs +++ b/pgdog/src/frontend/router/parser/query/delete.rs @@ -13,12 +13,6 @@ impl QueryParser { self.recorder_mut(), ); - let is_sharded = parser.is_sharded( - &context.router_context.schema, - context.router_context.cluster.user(), - context.router_context.parameter_hints.search_path, - ); - let shard = parser.shard()?; if let Some(shard) = shard { @@ -32,14 +26,34 @@ impl QueryParser { .shards_calculator .push(ShardWithPriority::new_table(shard)); } else { - if let Some(recorder) = self.recorder_mut() { - recorder.record_entry(None, "DELETE fell back to broadcast"); - } - if is_sharded { + let schema_shard_state = parser.schema_shard_state( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ); + let is_sharded = parser.is_sharded( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ); + if let SchemaShardState::Resolved { shard, schema } = schema_shard_state { + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(Some(shard.clone()), "DELETE matched schema"); + } + context + .shards_calculator + .push(ShardWithPriority::new_search_path(shard, &schema)); + } else if matches!(schema_shard_state, SchemaShardState::Ambiguous) || is_sharded { + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(None, "DELETE fell back to broadcast"); + } context .shards_calculator .push(ShardWithPriority::new_table(Shard::All)); } else { + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(None, "DELETE fell back to omnisharded"); + } context .shards_calculator .push(ShardWithPriority::new_rr_omni(Shard::All)); diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 36eb8c11f..d405c8257 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -17,6 +17,7 @@ use crate::{ plugin::plugins, }; +use super::statement::SchemaShardState; use super::{ explain_trace::{ExplainRecorder, ExplainSummary}, *, @@ -530,18 +531,29 @@ impl QueryParser { ) .with_schema_lookup(schema_lookup); - let is_sharded = parser.is_sharded( - &context.router_context.schema, - context.router_context.cluster.user(), - context.router_context.parameter_hints.search_path, - ); - - let shard = parser.shard()?.unwrap_or(Shard::All); + let shard = parser.shard()?; - context.shards_calculator.push(if is_sharded { - ShardWithPriority::new_table(shard.clone()) + context.shards_calculator.push(if let Some(shard) = shard { + ShardWithPriority::new_table(shard) } else { - ShardWithPriority::new_table_omni(shard) + let schema_shard_state = parser.schema_shard_state( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ); + let is_sharded = parser.is_sharded( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ); + + if let SchemaShardState::Resolved { shard, schema } = schema_shard_state { + ShardWithPriority::new_search_path(shard, &schema) + } else if matches!(schema_shard_state, SchemaShardState::Ambiguous) || is_sharded { + ShardWithPriority::new_table(Shard::All) + } else { + ShardWithPriority::new_table_omni(Shard::All) + } }); let shard = context.shards_calculator.shard(); diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index bcfc38370..5c8e17660 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -57,7 +57,7 @@ impl QueryParser { let mut shards = HashSet::new(); - let (shard, is_sharded, tables, advisory_locks) = { + let (shard, schema_shard_state, is_sharded, tables, advisory_locks) = { let mut statement_parser = StatementParser::from_select( stmt, context.router_context.bind, @@ -72,10 +72,16 @@ impl QueryParser { let advisory_locks = statement_parser.extract_advisory_locks(); if shard.is_some() { - (shard, true, vec![], advisory_locks) + (shard, SchemaShardState::None, true, vec![], advisory_locks) } else { + let schema_shard_state = statement_parser.schema_shard_state( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ); ( None, + schema_shard_state, statement_parser.is_sharded( &context.router_context.schema, context.router_context.cluster.user(), @@ -148,6 +154,18 @@ impl QueryParser { context .shards_calculator .push(ShardWithPriority::new_table(shard)); + } else if let SchemaShardState::Resolved { shard, schema } = schema_shard_state { + debug!("resolved schema-sharded query via search_path/default schema"); + + context + .shards_calculator + .push(ShardWithPriority::new_search_path(shard, &schema)); + } else if matches!(schema_shard_state, SchemaShardState::Ambiguous) { + debug!("schema-sharded query is ambiguous, routing as cross-shard"); + + context + .shards_calculator + .push(ShardWithPriority::new_table(Shard::All)); } else if is_sharded { debug!("table is sharded, but no sharding key detected"); diff --git a/pgdog/src/frontend/router/parser/query/update.rs b/pgdog/src/frontend/router/parser/query/update.rs index f4e78f4f0..17034f5a7 100644 --- a/pgdog/src/frontend/router/parser/query/update.rs +++ b/pgdog/src/frontend/router/parser/query/update.rs @@ -13,12 +13,6 @@ impl QueryParser { self.recorder_mut(), ); - let is_sharded = parser.is_sharded( - &context.router_context.schema, - context.router_context.cluster.user(), - context.router_context.parameter_hints.search_path, - ); - let shard = parser.shard()?; if let Some(shard) = shard { if let Some(recorder) = self.recorder_mut() { @@ -31,14 +25,34 @@ impl QueryParser { .shards_calculator .push(ShardWithPriority::new_table(shard)); } else { - if let Some(recorder) = self.recorder_mut() { - recorder.record_entry(None, "UPDATE fell back to broadcast"); - } - if is_sharded { + let schema_shard_state = parser.schema_shard_state( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ); + let is_sharded = parser.is_sharded( + &context.router_context.schema, + context.router_context.cluster.user(), + context.router_context.parameter_hints.search_path, + ); + if let SchemaShardState::Resolved { shard, schema } = schema_shard_state { + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(Some(shard.clone()), "UPDATE matched schema"); + } + context + .shards_calculator + .push(ShardWithPriority::new_search_path(shard, &schema)); + } else if matches!(schema_shard_state, SchemaShardState::Ambiguous) || is_sharded { + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(None, "UPDATE fell back to broadcast"); + } context .shards_calculator .push(ShardWithPriority::new_table(Shard::All)); } else { + if let Some(recorder) = self.recorder_mut() { + recorder.record_entry(None, "UPDATE fell back to omnisharded"); + } context .shards_calculator .push(ShardWithPriority::new_table_omni(Shard::All)); diff --git a/pgdog/src/frontend/router/parser/statement.rs b/pgdog/src/frontend/router/parser/statement.rs index cec05c806..3b3ecb087 100644 --- a/pgdog/src/frontend/router/parser/statement.rs +++ b/pgdog/src/frontend/router/parser/statement.rs @@ -407,6 +407,13 @@ enum Statement<'a> { Insert(&'a InsertStmt), } +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum SchemaShardState { + None, + Resolved { shard: Shard, schema: String }, + Ambiguous, +} + /// Context for looking up table columns from the database schema. /// Used for INSERT statements without explicit column lists. pub struct SchemaLookupContext<'a> { @@ -604,6 +611,52 @@ impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { Ok(None) } + pub(crate) fn schema_shard_state( + &mut self, + db_schema: &Schema, + user: &str, + search_path: Option<&ParameterValue>, + ) -> SchemaShardState { + if self.schema.schemas.is_empty() { + return SchemaShardState::None; + } + + let tables = self.tables().to_vec(); + let mut schema_sharder = SchemaSharder::default(); + let default_schema_mapping = self.schema.schemas.get(None).is_some(); + let mut ambiguous = false; + + for table in tables { + if let Some(schema) = table.schema { + schema_sharder.resolve(Some(schema.into()), &self.schema.schemas); + continue; + } + + if let Some(relation) = db_schema.table(table, user, search_path) { + schema_sharder.resolve(Some(relation.schema().into()), &self.schema.schemas); + continue; + } + + ambiguous |= default_schema_mapping + || self + .schema + .schemas + .keys() + .any(|schema| db_schema.get(schema, table.name).is_some()); + } + + if ambiguous { + SchemaShardState::Ambiguous + } else if let Some((shard, schema)) = schema_sharder.get() { + SchemaShardState::Resolved { + shard, + schema: schema.to_owned(), + } + } else { + SchemaShardState::None + } + } + /// Check that the query references a table that contains a sharded /// column. This check is needed in case sharded tables config /// doesn't specify a table name and should short-circuit if it does. @@ -2549,6 +2602,86 @@ mod test { assert_eq!(result2, Some(Shard::Direct(2))); } + fn make_test_schema_with_sharded_relations() -> crate::backend::Schema { + let relations = HashMap::from([ + ( + ("sales".into(), "products".into()), + Relation::test_table("sales", "products", IndexMap::new()), + ), + ( + ("inventory".into(), "products".into()), + Relation::test_table("inventory", "products", IndexMap::new()), + ), + ( + ("public".into(), "unsharded_table".into()), + Relation::test_table("public", "unsharded_table", IndexMap::new()), + ), + ]); + crate::backend::Schema::from_parts(vec!["$user".into(), "public".into()], relations) + } + + fn run_schema_shard_state_test( + stmt: &str, + search_path: Option, + ) -> Result { + let schema = ShardingSchema { + shards: 3, + schemas: ShardedSchemas::new(vec![ + ShardedSchema { + database: "test".to_string(), + name: Some("sales".to_string()), + shard: 1, + all: false, + }, + ShardedSchema { + database: "test".to_string(), + name: Some("inventory".to_string()), + shard: 2, + all: false, + }, + ]), + ..Default::default() + }; + let db_schema = make_test_schema_with_sharded_relations(); + let raw = pg_query::parse(stmt) + .unwrap() + .protobuf + .stmts + .first() + .cloned() + .unwrap(); + let mut parser = StatementParser::from_raw(&raw, None, &schema, None)?; + Ok(parser.schema_shard_state(&db_schema, "pgdog", search_path.as_ref())) + } + + #[test] + fn test_schema_shard_state_ambiguous_without_search_path() { + let result = run_schema_shard_state_test("SELECT * FROM products", None).unwrap(); + assert_eq!(result, SchemaShardState::Ambiguous); + } + + #[test] + fn test_schema_shard_state_resolved_from_search_path() { + let result = run_schema_shard_state_test( + "SELECT * FROM products", + Some(ParameterValue::Tuple(vec!["sales".into(), "public".into()])), + ) + .unwrap(); + assert_eq!( + result, + SchemaShardState::Resolved { + shard: Shard::Direct(1), + schema: "sales".into(), + } + ); + } + + #[test] + fn test_schema_shard_state_none_for_unsharded_table() { + let result = run_schema_shard_state_test("SELECT * FROM unsharded_table", None).unwrap(); + assert_eq!(result, SchemaShardState::None); + } + // Column-only sharded table detection tests (using loaded schema) fn run_test_column_only(stmt: &str, bind: Option<&Bind>) -> Result, Error> {