diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 9e0683bdd7b20..c2b595cd2697f 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -51,6 +51,11 @@ pub trait Dialect: Send + Sync { /// Return the character used to quote identifiers. fn identifier_quote_style(&self, _identifier: &str) -> Option; + /// Whether array literals should be rendered with the `ARRAY[...]` keyword. + fn use_array_keyword_for_array_literals(&self) -> bool { + false + } + /// Does the dialect support specifying `NULLS FIRST/LAST` in `ORDER BY` clauses? fn supports_nulls_first_in_sort(&self) -> bool { true @@ -321,6 +326,10 @@ impl Dialect for DefaultDialect { pub struct PostgreSqlDialect {} impl Dialect for PostgreSqlDialect { + fn use_array_keyword_for_array_literals(&self) -> bool { + true + } + fn supports_qualify(&self) -> bool { false } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 3601febe744c9..e7a7d04e58b75 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -604,7 +604,7 @@ impl Unparser<'_> { .collect::>>()?; Ok(ast::Expr::Array(Array { elem: args, - named: false, + named: self.dialect.use_array_keyword_for_array_literals(), })) } @@ -615,7 +615,10 @@ impl Unparser<'_> { elem.push(self.scalar_to_sql(&value)?); } - Ok(ast::Expr::Array(Array { elem, named: false })) + Ok(ast::Expr::Array(Array { + elem, + named: self.dialect.use_array_keyword_for_array_literals(), + })) } fn array_element_to_sql(&self, args: &[Expr]) -> Result { @@ -3042,6 +3045,61 @@ mod tests { } } + #[test] + fn test_array_literal_scalar_value_to_sql_postgres() -> Result<()> { + let dialect: Arc = Arc::new(PostgreSqlDialect {}); + let unparser = Unparser::new(dialect.as_ref()); + + let expr = Expr::Literal( + ScalarValue::List(ScalarValue::new_list_nullable( + &[ + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(2)), + ScalarValue::Int32(Some(3)), + ], + &DataType::Int32, + )), + None, + ); + + let ast = unparser.expr_to_sql(&expr)?; + assert_eq!(ast.to_string(), "ARRAY[1, 2, 3]"); + + Ok(()) + } + + #[test] + fn test_nested_array_literal_scalar_value_to_sql_postgres() -> Result<()> { + let dialect: Arc = Arc::new(PostgreSqlDialect {}); + let unparser = Unparser::new(dialect.as_ref()); + + let inner_type = DataType::Int32; + let nested_type = + DataType::List(Arc::new(Field::new_list_field(inner_type.clone(), true))); + + let expr = Expr::Literal( + ScalarValue::List(ScalarValue::new_list_nullable( + &[ + ScalarValue::List(ScalarValue::new_list_nullable( + &[ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(2))], + &inner_type, + )), + ScalarValue::List(ScalarValue::new_list_nullable( + &[ScalarValue::Int32(Some(3)), ScalarValue::Int32(Some(4))], + &inner_type, + )), + ], + &nested_type, + )), + None, + ); + + let ast = unparser.expr_to_sql(&expr)?; + assert_eq!(ast.to_string(), "ARRAY[ARRAY[1, 2], ARRAY[3, 4]]"); + + Ok(()) + } + #[test] fn test_round_scalar_fn_to_expr() -> Result<()> { let default_dialect: Arc = Arc::new( diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 0dad48b168976..fa4161168bac1 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -2728,6 +2728,17 @@ fn test_unparse_window() -> Result<()> { Ok(()) } +#[test] +fn test_array_to_sql_postgres() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT [1, 2, 3, 4, 5]", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserPostgreSqlDialect {}, + expected: @"SELECT ARRAY[1, 2, 3, 4, 5]", + ); + Ok(()) +} + #[test] fn test_like_filter() { let statement = generate_round_trip_statement(