diff --git a/internal/poet/reserved.go b/internal/poet/reserved.go index d9d3198..6077ae6 100644 --- a/internal/poet/reserved.go +++ b/internal/poet/reserved.go @@ -4,25 +4,37 @@ import "slices" // TODO(quentin@escape.tech): check if this is complete var reservedKeywords = []string{ - "class", - "if", - "else", - "elif", - "not", - "for", "and", - "in", - "is", - "or", - "with", "as", "assert", + "async", + "await", "break", + "class", + "continue", + "def", + "del", + "elif", + "else", "except", "finally", - "try", + "for", + "from", + "global", + "if", + "import", + "in", + "is", + "lambda", + "nonlocal", + "not", + "or", + "pass", "raise", "return", + "try", + "while", + "with", "yield", } diff --git a/internal/printer/printer.go b/internal/printer/printer.go index 0660c6a..c5f2c4e 100644 --- a/internal/printer/printer.go +++ b/internal/printer/printer.go @@ -37,6 +37,22 @@ func (w *writer) printIndent(indent int32) { } } +func (w *writer) printCommentText(text string, indent int32) { + lines := strings.Split(text, "\n") + for _, line := range lines { + w.print("#") + // trim right space which is usually unintended, + // but leave left space untouched in case if it's intentionally formatted. + trimmed := strings.TrimRight(line, " ") + if trimmed != "" { + w.print(" ") + w.print(trimmed) + } + w.print("\n") + w.printIndent(indent) + } +} + func (w *writer) printNode(node *ast.Node, indent int32) { switch n := node.Node.(type) { @@ -132,10 +148,7 @@ func (w *writer) printNode(node *ast.Node, indent int32) { func (w *writer) printAnnAssign(aa *ast.AnnAssign, indent int32) { if aa.Comment != "" { - w.print("# ") - w.print(aa.Comment) - w.print("\n") - w.printIndent(indent) + w.printCommentText(aa.Comment, indent) } w.printName(aa.Target, indent) w.print(": ") @@ -255,10 +268,7 @@ func (w *writer) printClassDef(cd *ast.ClassDef, indent int32) { if i == 0 { if e, ok := node.Node.(*ast.Node_Expr); ok { if c, ok := e.Expr.Value.Node.(*ast.Node_Constant); ok { - w.print(`""`) - w.printConstant(c.Constant, indent) - w.print(`""`) - w.print("\n") + w.printDocString(c.Constant, indent) continue } } @@ -268,6 +278,33 @@ func (w *writer) printClassDef(cd *ast.ClassDef, indent int32) { } } +func (w *writer) printDocString(c *ast.Constant, indent int32) { + switch n := c.Value.(type) { + case *ast.Constant_Str: + w.print(`"""`) + lines := strings.Split(n.Str, "\n") + printedN := 0 + for n, line := range lines { + // trim right space which is usually unintended, + // but leave left space untouched in case if it's intentionally formatted. + trimmed := strings.TrimRight(line, " ") + if trimmed == "" { + continue + } + if printedN > 0 && n < len(lines)-1 { + w.print("\n") + w.printIndent(indent) + } + w.print(strings.ReplaceAll(trimmed, `"`, `\"`)) + printedN++ + } + w.print(`"""`) + w.print("\n") + default: + panic(n) + } +} + func (w *writer) printConstant(c *ast.Constant, indent int32) { switch n := c.Value.(type) { case *ast.Constant_Int: @@ -282,7 +319,7 @@ func (w *writer) printConstant(c *ast.Constant, indent int32) { str = `"""` } w.print(str) - w.print(n.Str) + w.print(strings.ReplaceAll(n.Str, `"`, `\"`)) w.print(str) default: @@ -291,9 +328,7 @@ func (w *writer) printConstant(c *ast.Constant, indent int32) { } func (w *writer) printComment(c *ast.Comment, indent int32) { - w.print("# ") - w.print(c.Text) - w.print("\n") + w.printCommentText(c.Text, 0) } func (w *writer) printCompare(c *ast.Compare, indent int32) {