diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index d6da8368e..261c51b5a 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -3522,6 +3522,28 @@ impl fmt::Display for CreateDomain { } } +/// The return type of a `CREATE FUNCTION` statement. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum FunctionReturnType { + /// `RETURNS ` + DataType(DataType), + /// `RETURNS SETOF ` + /// + /// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html) + SetOf(DataType), +} + +impl fmt::Display for FunctionReturnType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FunctionReturnType::DataType(data_type) => write!(f, "{data_type}"), + FunctionReturnType::SetOf(data_type) => write!(f, "SETOF {data_type}"), + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -3542,7 +3564,7 @@ pub struct CreateFunction { /// List of arguments for the function. pub args: Option>, /// The return type of the function. - pub return_type: Option, + pub return_type: Option, /// The expression that defines the function. /// /// Examples: diff --git a/src/ast/mod.rs b/src/ast/mod.rs index e201f7842..3bdf1b751 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -72,13 +72,13 @@ pub use self::ddl::{ CreatePolicyCommand, CreatePolicyType, CreateTable, CreateTrigger, CreateView, Deduplicate, DeferrableInitial, DistStyle, DropBehavior, DropExtension, DropFunction, DropOperator, DropOperatorClass, DropOperatorFamily, DropOperatorSignature, DropPolicy, DropTrigger, - ForValues, GeneratedAs, GeneratedExpressionMode, IdentityParameters, IdentityProperty, - IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder, IndexColumn, - IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption, OperatorArgTypes, - OperatorClassItem, OperatorFamilyDropItem, OperatorFamilyItem, OperatorOption, OperatorPurpose, - Owner, Partition, PartitionBoundValue, ProcedureParam, ReferentialAction, RenameTableNameKind, - ReplicaIdentity, TagsColumnOption, TriggerObjectKind, Truncate, - UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength, + ForValues, FunctionReturnType, GeneratedAs, GeneratedExpressionMode, IdentityParameters, + IdentityProperty, IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder, + IndexColumn, IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption, + OperatorArgTypes, OperatorClassItem, OperatorFamilyDropItem, OperatorFamilyItem, + OperatorOption, OperatorPurpose, Owner, Partition, PartitionBoundValue, ProcedureParam, + ReferentialAction, RenameTableNameKind, ReplicaIdentity, TagsColumnOption, TriggerObjectKind, + Truncate, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength, UserDefinedTypeRangeOption, UserDefinedTypeRepresentation, UserDefinedTypeSqlDefinitionOption, UserDefinedTypeStorage, ViewColumnDef, }; diff --git a/src/keywords.rs b/src/keywords.rs index 9ea85fd3a..65cff9c10 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -936,6 +936,7 @@ define_keywords!( SESSION_USER, SET, SETERROR, + SETOF, SETS, SETTINGS, SHARE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 274449ff7..57cdd1867 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5581,7 +5581,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; let return_type = if self.parse_keyword(Keyword::RETURNS) { - Some(self.parse_data_type()?) + Some(self.parse_function_return_type()?) } else { None }; @@ -5761,7 +5761,7 @@ impl<'a> Parser<'a> { let (name, args) = self.parse_create_function_name_and_params()?; let return_type = if self.parse_keyword(Keyword::RETURNS) { - Some(self.parse_data_type()?) + Some(self.parse_function_return_type()?) } else { None }; @@ -5864,11 +5864,11 @@ impl<'a> Parser<'a> { }) })?; - let return_type = if return_table.is_some() { - return_table - } else { - Some(self.parse_data_type()?) + let data_type = match return_table { + Some(table_type) => table_type, + None => self.parse_data_type()?, }; + let return_type = Some(FunctionReturnType::DataType(data_type)); let _ = self.parse_keyword(Keyword::AS); @@ -5920,6 +5920,14 @@ impl<'a> Parser<'a> { }) } + fn parse_function_return_type(&mut self) -> Result { + if self.parse_keyword(Keyword::SETOF) { + Ok(FunctionReturnType::SetOf(self.parse_data_type()?)) + } else { + Ok(FunctionReturnType::DataType(self.parse_data_type()?)) + } + } + fn parse_create_function_name_and_params( &mut self, ) -> Result<(ObjectName, Vec), ParserError> { @@ -8561,7 +8569,7 @@ impl<'a> Parser<'a> { } } - /// Parse a single [PartitionBoundValue]. + /// Parse a single partition bound value (MINVALUE, MAXVALUE, or expression). fn parse_partition_bound_value(&mut self) -> Result { if self.parse_keyword(Keyword::MINVALUE) { Ok(PartitionBoundValue::MinValue) diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index a6b0906ff..bfc8f95e7 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -2289,7 +2289,7 @@ fn test_bigquery_create_function() { Ident::new("myfunction"), ]), args: Some(vec![OperateFunctionArg::with_name("x", DataType::Float64),]), - return_type: Some(DataType::Float64), + return_type: Some(FunctionReturnType::DataType(DataType::Float64)), function_body: Some(CreateFunctionBody::AsAfterOptions(Expr::Value( number("42").with_empty_span() ))), diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index da6ecace6..bb8bed2a5 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -255,7 +255,7 @@ fn parse_create_function() { default_expr: None, }, ]), - return_type: Some(DataType::Int(None)), + return_type: Some(FunctionReturnType::DataType(DataType::Int(None))), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { begin_token: AttachedToken::empty(), statements: vec![Statement::Return(ReturnStatement { @@ -430,7 +430,7 @@ fn parse_create_function_parameter_default_values() { data_type: DataType::Int(None), default_expr: Some(Expr::Value((number("42")).with_empty_span())), },]), - return_type: Some(DataType::Int(None)), + return_type: Some(FunctionReturnType::DataType(DataType::Int(None))), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { begin_token: AttachedToken::empty(), statements: vec![Statement::Return(ReturnStatement { diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 60aca14b3..a5d55df04 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -4441,7 +4441,7 @@ $$"#; DataType::Varchar(None), ), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4484,7 +4484,7 @@ $$"#; DataType::Int(None) ) ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4531,7 +4531,7 @@ $$"#; DataType::Int(None) ), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4578,7 +4578,7 @@ $$"#; DataType::Int(None) ), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4618,7 +4618,7 @@ $$"#; ), OperateFunctionArg::with_name("b", DataType::Varchar(None)), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4661,7 +4661,7 @@ fn parse_create_function() { OperateFunctionArg::unnamed(DataType::Integer(None)), OperateFunctionArg::unnamed(DataType::Integer(None)), ]), - return_type: Some(DataType::Integer(None)), + return_type: Some(FunctionReturnType::DataType(DataType::Integer(None))), language: Some("SQL".into()), behavior: Some(FunctionBehavior::Immutable), called_on_null: Some(FunctionCalledOnNull::Strict), @@ -4698,6 +4698,30 @@ fn parse_create_function_detailed() { ); } +#[test] +fn parse_create_function_returns_setof() { + pg_and_generic().verified_stmt( + "CREATE FUNCTION get_users() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM users'", + ); + pg_and_generic().verified_stmt( + "CREATE FUNCTION get_ids() RETURNS SETOF INTEGER LANGUAGE sql AS 'SELECT id FROM users'", + ); + pg_and_generic().verified_stmt( + r#"CREATE FUNCTION get_all() RETURNS SETOF my_schema."MyType" LANGUAGE sql AS 'SELECT * FROM t'"#, + ); + pg_and_generic().verified_stmt( + "CREATE FUNCTION get_rows() RETURNS SETOF RECORD LANGUAGE sql AS 'SELECT * FROM t'", + ); + + let sql = "CREATE FUNCTION get_names() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM t'"; + match pg_and_generic().verified_stmt(sql) { + Statement::CreateFunction(CreateFunction { return_type, .. }) => { + assert_eq!(return_type, Some(FunctionReturnType::SetOf(DataType::Text))); + } + _ => panic!("Expected CreateFunction"), + } +} + #[test] fn parse_create_function_with_security() { let sql = @@ -4773,10 +4797,10 @@ fn parse_create_function_c_with_module_pathname() { "input", DataType::Custom(ObjectName::from(vec![Ident::new("cstring")]), vec![]), ),]), - return_type: Some(DataType::Custom( + return_type: Some(FunctionReturnType::DataType(DataType::Custom( ObjectName::from(vec![Ident::new("cas")]), vec![] - )), + ))), language: Some("c".into()), behavior: Some(FunctionBehavior::Immutable), called_on_null: None, @@ -6491,7 +6515,7 @@ fn parse_trigger_related_functions() { if_not_exists: false, name: ObjectName::from(vec![Ident::new("emp_stamp")]), args: Some(vec![]), - return_type: Some(DataType::Trigger), + return_type: Some(FunctionReturnType::DataType(DataType::Trigger)), function_body: Some( CreateFunctionBody::AsBeforeOptions { body: Expr::Value((