diff --git a/README.md b/README.md index 17c24c9..e14b459 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,19 @@ for _, stmt := range statements { } ``` +- Convenience helpers for single statements and expression fragments + +```Go +// Parse exactly one statement (errors if the input has zero or multiple statements) +stmt, err := clickhouse.ParseStmt("SELECT a FROM t") + +// Parse an expression fragment, e.g. a column expression or function call +expr, err := clickhouse.ParseExpr("toDate(created_at) + 1") + +// Format AST into multi-line indented SQL (same output as the -beautify CLI flag) +fmt.Println(clickhouse.FormatBeautify(stmt)) +``` + ## AST Traversal ### Walk Pattern (Recommended) diff --git a/main.go b/main.go index d9cf7c1..abc0dce 100644 --- a/main.go +++ b/main.go @@ -69,10 +69,7 @@ func main() { } else { // format SQL for _, stmt := range stmts { if options.beautify { - formatter := clickhouse.NewFormatter() - formatter.WithBeautify() - formatter.WriteExpr(stmt) - fmt.Println(formatter.String()) + fmt.Println(clickhouse.FormatBeautify(stmt)) } else { fmt.Println(clickhouse.Format(stmt)) } diff --git a/parser/api.go b/parser/api.go new file mode 100644 index 0000000..0473eda --- /dev/null +++ b/parser/api.go @@ -0,0 +1,50 @@ +package parser + +import ( + "errors" + "fmt" +) + +// ParseStmt parses exactly one statement from the input. It returns an error +// if the input contains no statement or more than one statement. To parse a +// script with multiple statements, use NewParser(sql).ParseStmts(). 
+func ParseStmt(sql string) (Expr, error) {
+	parsed, err := NewParser(sql).ParseStmts()
+	if err != nil {
+		return nil, err
+	}
+	switch count := len(parsed); { // anything other than exactly one statement is an error
+	case count == 0:
+		return nil, errors.New("no statement found in input")
+	case count > 1:
+		return nil, fmt.Errorf("expected exactly one statement, but found %d", count)
+	}
+	return parsed[0], nil
+}
+
+// ParseExpr parses sql as a lone expression rather than a statement, for
+// example a column reference, a function call or `toDate(created_at) + 1`.
+// It fails if any token remains once the expression has been read.
+func ParseExpr(sql string) (Expr, error) {
+	ps := NewParser(sql)
+	if lexErr := ps.lexer.consumeToken(); lexErr != nil {
+		return nil, ps.wrapError(lexErr)
+	}
+	node, parseErr := ps.parseExpr(ps.Pos())
+	switch {
+	case parseErr != nil:
+		return nil, ps.wrapError(parseErr)
+	case ps.last() != nil: // a leftover token means the input was not a single expression
+		return nil, ps.wrapError(fmt.Errorf("unexpected token after expression: %q", ps.lastTokenString()))
+	}
+	return node, nil
+}
+
+// FormatBeautify renders an expression into multi-line indented SQL. It is a
+// convenience for NewFormatter().WithBeautify(); use Format for compact
+// single-line SQL.
+func FormatBeautify(expr Expr) string {
+	f := NewFormatter().WithBeautify() // fresh formatter with beautify mode switched on
+	f.WriteExpr(expr)
+	return f.String()
+}
diff --git a/parser/api_test.go b/parser/api_test.go
new file mode 100644
index 0000000..76f5194
--- /dev/null
+++ b/parser/api_test.go
@@ -0,0 +1,93 @@
+package parser
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestParseStmt(t *testing.T) {
+	t.Run("single statement", func(t *testing.T) {
+		got, parseErr := ParseStmt("SELECT a FROM t")
+		require.NoError(t, parseErr)
+		require.IsType(t, (*SelectQuery)(nil), got)
+	})
+
+	t.Run("single statement with trailing semicolon", func(t *testing.T) {
+		got, parseErr := ParseStmt("SELECT a FROM t;")
+		require.NoError(t, parseErr)
+		require.IsType(t, (*SelectQuery)(nil), got)
+	})
+
+	t.Run("empty input", func(t *testing.T) {
+		_, parseErr := ParseStmt("")
+		require.ErrorContains(t, parseErr, "no statement")
+	})
+
+	t.Run("multiple statements", func(t *testing.T) {
+		_, parseErr := ParseStmt("SELECT 1; SELECT 2")
+		require.ErrorContains(t, parseErr, "exactly one statement")
+	})
+
+	t.Run("syntax error", func(t *testing.T) {
+		_, parseErr := ParseStmt("SELECT a FROM WHERE")
+		require.Error(t, parseErr)
+	})
+}
+
+func TestParseExpr(t *testing.T) {
+	t.Run("column reference", func(t *testing.T) {
+		node, parseErr := ParseExpr("a")
+		require.NoError(t, parseErr)
+		require.IsType(t, (*Ident)(nil), node)
+	})
+
+	t.Run("function call", func(t *testing.T) {
+		node, parseErr := ParseExpr("toDate(created_at) + 1")
+		require.NoError(t, parseErr)
+		require.Equal(t, "toDate(created_at) + 1", Format(node)) // compact form equals the input text
+	})
+
+	t.Run("case expression", func(t *testing.T) {
+		node, parseErr := ParseExpr("CASE WHEN a > 1 THEN 'x' ELSE 'y' END")
+		require.NoError(t, parseErr)
+		require.IsType(t, (*CaseExpr)(nil), node)
+	})
+
+	t.Run("empty input", func(t *testing.T) {
+		_, parseErr := ParseExpr("")
+		require.Error(t, parseErr)
+	})
+
+	t.Run("trailing tokens", func(t *testing.T) {
+		_, parseErr := ParseExpr("a + 1 b")
+		require.ErrorContains(t, parseErr, "unexpected token after expression")
+	})
+
+	t.Run("syntax error", func(t *testing.T) {
+		_, parseErr := ParseExpr("f(")
+		require.Error(t, parseErr)
+	})
+}
+
+func TestFormatBeautify(t *testing.T) {
+	stmt, err := ParseStmt("SELECT a, b FROM t WHERE a = 1")
+	require.NoError(t, err)
+
+	pretty := FormatBeautify(stmt)
+	require.Contains(t, pretty, "\n")
+
+	// The helper has to agree with the explicit formatter calls it stands in for.
+	manual := NewFormatter().WithBeautify()
+	manual.WriteExpr(stmt)
+	require.Equal(t, manual.String(), pretty)
+
+	// Round trip: the beautified text parses back to the same compact SQL.
+	again, err := ParseStmt(pretty)
+	require.NoError(t, err)
+	require.Equal(t, Format(stmt), Format(again))
+
+	require.Empty(t, FormatBeautify(nil)) // a nil expression is expected to format to ""
+	require.False(t, strings.Contains(Format(stmt), "\n"))
+}