Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/ast/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3522,6 +3522,28 @@ impl fmt::Display for CreateDomain {
}
}

/// The return type of a `CREATE FUNCTION` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionReturnType {
/// `RETURNS <type>`
DataType(DataType),
/// `RETURNS SETOF <type>`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we instead move the link to the postgres docs here? since the enum itself applies to all dialects, whereas only this variant is pg specific

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

///
/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html)
SetOf(DataType),
}

impl fmt::Display for FunctionReturnType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionReturnType::DataType(data_type) => write!(f, "{data_type}"),
FunctionReturnType::SetOf(data_type) => write!(f, "SETOF {data_type}"),
}
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand All @@ -3542,7 +3564,7 @@ pub struct CreateFunction {
/// List of arguments for the function.
pub args: Option<Vec<OperateFunctionArg>>,
/// The return type of the function.
pub return_type: Option<DataType>,
pub return_type: Option<FunctionReturnType>,
/// The expression that defines the function.
///
/// Examples:
Expand Down
14 changes: 7 additions & 7 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ pub use self::ddl::{
CreatePolicyCommand, CreatePolicyType, CreateTable, CreateTrigger, CreateView, Deduplicate,
DeferrableInitial, DistStyle, DropBehavior, DropExtension, DropFunction, DropOperator,
DropOperatorClass, DropOperatorFamily, DropOperatorSignature, DropPolicy, DropTrigger,
ForValues, GeneratedAs, GeneratedExpressionMode, IdentityParameters, IdentityProperty,
IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder, IndexColumn,
IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption, OperatorArgTypes,
OperatorClassItem, OperatorFamilyDropItem, OperatorFamilyItem, OperatorOption, OperatorPurpose,
Owner, Partition, PartitionBoundValue, ProcedureParam, ReferentialAction, RenameTableNameKind,
ReplicaIdentity, TagsColumnOption, TriggerObjectKind, Truncate,
UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength,
ForValues, FunctionReturnType, GeneratedAs, GeneratedExpressionMode, IdentityParameters,
IdentityProperty, IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder,
IndexColumn, IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption,
OperatorArgTypes, OperatorClassItem, OperatorFamilyDropItem, OperatorFamilyItem,
OperatorOption, OperatorPurpose, Owner, Partition, PartitionBoundValue, ProcedureParam,
ReferentialAction, RenameTableNameKind, ReplicaIdentity, TagsColumnOption, TriggerObjectKind,
Truncate, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength,
UserDefinedTypeRangeOption, UserDefinedTypeRepresentation, UserDefinedTypeSqlDefinitionOption,
UserDefinedTypeStorage, ViewColumnDef,
};
Expand Down
1 change: 1 addition & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,7 @@ define_keywords!(
SESSION_USER,
SET,
SETERROR,
SETOF,
SETS,
SETTINGS,
SHARE,
Expand Down
22 changes: 15 additions & 7 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5581,7 +5581,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::RParen)?;

let return_type = if self.parse_keyword(Keyword::RETURNS) {
Some(self.parse_data_type()?)
Some(self.parse_function_return_type()?)
} else {
None
};
Expand Down Expand Up @@ -5761,7 +5761,7 @@ impl<'a> Parser<'a> {
let (name, args) = self.parse_create_function_name_and_params()?;

let return_type = if self.parse_keyword(Keyword::RETURNS) {
Some(self.parse_data_type()?)
Some(self.parse_function_return_type()?)
} else {
None
};
Expand Down Expand Up @@ -5864,11 +5864,11 @@ impl<'a> Parser<'a> {
})
})?;

let return_type = if return_table.is_some() {
return_table
} else {
Some(self.parse_data_type()?)
let data_type = match return_table {
Some(table_type) => table_type,
None => self.parse_data_type()?,
};
let return_type = Some(FunctionReturnType::DataType(data_type));

let _ = self.parse_keyword(Keyword::AS);

Expand Down Expand Up @@ -5920,6 +5920,14 @@ impl<'a> Parser<'a> {
})
}

fn parse_function_return_type(&mut self) -> Result<FunctionReturnType, ParserError> {
if self.parse_keyword(Keyword::SETOF) {
Ok(FunctionReturnType::SetOf(self.parse_data_type()?))
} else {
Ok(FunctionReturnType::DataType(self.parse_data_type()?))
}
}

fn parse_create_function_name_and_params(
&mut self,
) -> Result<(ObjectName, Vec<OperateFunctionArg>), ParserError> {
Expand Down Expand Up @@ -8561,7 +8569,7 @@ impl<'a> Parser<'a> {
}
}

/// Parse a single [PartitionBoundValue].
/// Parse a single partition bound value (MINVALUE, MAXVALUE, or expression).
fn parse_partition_bound_value(&mut self) -> Result<PartitionBoundValue, ParserError> {
if self.parse_keyword(Keyword::MINVALUE) {
Ok(PartitionBoundValue::MinValue)
Expand Down
2 changes: 1 addition & 1 deletion tests/sqlparser_bigquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2289,7 +2289,7 @@ fn test_bigquery_create_function() {
Ident::new("myfunction"),
]),
args: Some(vec![OperateFunctionArg::with_name("x", DataType::Float64),]),
return_type: Some(DataType::Float64),
return_type: Some(FunctionReturnType::DataType(DataType::Float64)),
function_body: Some(CreateFunctionBody::AsAfterOptions(Expr::Value(
number("42").with_empty_span()
))),
Expand Down
4 changes: 2 additions & 2 deletions tests/sqlparser_mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ fn parse_create_function() {
default_expr: None,
},
]),
return_type: Some(DataType::Int(None)),
return_type: Some(FunctionReturnType::DataType(DataType::Int(None))),
function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
begin_token: AttachedToken::empty(),
statements: vec![Statement::Return(ReturnStatement {
Expand Down Expand Up @@ -430,7 +430,7 @@ fn parse_create_function_parameter_default_values() {
data_type: DataType::Int(None),
default_expr: Some(Expr::Value((number("42")).with_empty_span())),
},]),
return_type: Some(DataType::Int(None)),
return_type: Some(FunctionReturnType::DataType(DataType::Int(None))),
function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
begin_token: AttachedToken::empty(),
statements: vec![Statement::Return(ReturnStatement {
Expand Down
42 changes: 33 additions & 9 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4441,7 +4441,7 @@ $$"#;
DataType::Varchar(None),
),
]),
return_type: Some(DataType::Boolean),
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
language: Some("plpgsql".into()),
behavior: None,
called_on_null: None,
Expand Down Expand Up @@ -4484,7 +4484,7 @@ $$"#;
DataType::Int(None)
)
]),
return_type: Some(DataType::Boolean),
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
language: Some("plpgsql".into()),
behavior: None,
called_on_null: None,
Expand Down Expand Up @@ -4531,7 +4531,7 @@ $$"#;
DataType::Int(None)
),
]),
return_type: Some(DataType::Boolean),
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
language: Some("plpgsql".into()),
behavior: None,
called_on_null: None,
Expand Down Expand Up @@ -4578,7 +4578,7 @@ $$"#;
DataType::Int(None)
),
]),
return_type: Some(DataType::Boolean),
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
language: Some("plpgsql".into()),
behavior: None,
called_on_null: None,
Expand Down Expand Up @@ -4618,7 +4618,7 @@ $$"#;
),
OperateFunctionArg::with_name("b", DataType::Varchar(None)),
]),
return_type: Some(DataType::Boolean),
return_type: Some(FunctionReturnType::DataType(DataType::Boolean)),
language: Some("plpgsql".into()),
behavior: None,
called_on_null: None,
Expand Down Expand Up @@ -4661,7 +4661,7 @@ fn parse_create_function() {
OperateFunctionArg::unnamed(DataType::Integer(None)),
OperateFunctionArg::unnamed(DataType::Integer(None)),
]),
return_type: Some(DataType::Integer(None)),
return_type: Some(FunctionReturnType::DataType(DataType::Integer(None))),
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable),
called_on_null: Some(FunctionCalledOnNull::Strict),
Expand Down Expand Up @@ -4698,6 +4698,30 @@ fn parse_create_function_detailed() {
);
}

#[test]
fn parse_create_function_returns_setof() {
pg_and_generic().verified_stmt(
"CREATE FUNCTION get_users() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM users'",
);
pg_and_generic().verified_stmt(
"CREATE FUNCTION get_ids() RETURNS SETOF INTEGER LANGUAGE sql AS 'SELECT id FROM users'",
);
pg_and_generic().verified_stmt(
r#"CREATE FUNCTION get_all() RETURNS SETOF my_schema."MyType" LANGUAGE sql AS 'SELECT * FROM t'"#,
);
pg_and_generic().verified_stmt(
"CREATE FUNCTION get_rows() RETURNS SETOF RECORD LANGUAGE sql AS 'SELECT * FROM t'",
);

let sql = "CREATE FUNCTION get_names() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM t'";
match pg_and_generic().verified_stmt(sql) {
Statement::CreateFunction(CreateFunction { return_type, .. }) => {
assert_eq!(return_type, Some(FunctionReturnType::SetOf(DataType::Text)));
}
_ => panic!("Expected CreateFunction"),
}
}

#[test]
fn parse_create_function_with_security() {
let sql =
Expand Down Expand Up @@ -4773,10 +4797,10 @@ fn parse_create_function_c_with_module_pathname() {
"input",
DataType::Custom(ObjectName::from(vec![Ident::new("cstring")]), vec![]),
),]),
return_type: Some(DataType::Custom(
return_type: Some(FunctionReturnType::DataType(DataType::Custom(
ObjectName::from(vec![Ident::new("cas")]),
vec![]
)),
))),
language: Some("c".into()),
behavior: Some(FunctionBehavior::Immutable),
called_on_null: None,
Expand Down Expand Up @@ -6491,7 +6515,7 @@ fn parse_trigger_related_functions() {
if_not_exists: false,
name: ObjectName::from(vec![Ident::new("emp_stamp")]),
args: Some(vec![]),
return_type: Some(DataType::Trigger),
return_type: Some(FunctionReturnType::DataType(DataType::Trigger)),
function_body: Some(
CreateFunctionBody::AsBeforeOptions {
body: Expr::Value((
Expand Down
Loading