From cfc3608bed845cab23f0cce0a61649f70311e4f7 Mon Sep 17 00:00:00 2001 From: Atharva Singh Date: Sun, 12 Apr 2026 02:26:54 +0530 Subject: [PATCH] fix: parse python relative import shorthand --- internal/graph/graph_test.go | 19 +++++++++++++++++++ internal/parser/python_test.go | 25 ++++++++++++++++++++++++- internal/parser/ts_python.go | 34 +++++++++++++++++++++++++++++++++- 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/internal/graph/graph_test.go b/internal/graph/graph_test.go index 8d6950c..eac88f5 100644 --- a/internal/graph/graph_test.go +++ b/internal/graph/graph_test.go @@ -222,6 +222,25 @@ func TestGoImportWithDotsNotMangled(t *testing.T) { } } +// TestPythonRelativeImportShorthand verifies that "from . import module" style +// imports resolve to dependency edges. +func TestPythonRelativeImportShorthand(t *testing.T) { + files := []*parser.FileInfo{ + {Path: "pkg/api/views.py", Language: "python", Imports: []string{".models", "..shared"}, LineCount: 20}, + {Path: "pkg/api/models/user.py", Language: "python", LineCount: 10}, + {Path: "pkg/shared/util.py", Language: "python", LineCount: 10}, + } + + g := Build(files, BuildOptions{MaxDepth: 4}) + + apiMod := g.Module("pkg/api") + if apiMod == nil { + t.Fatal("expected module pkg/api to exist") + } + assertStringSliceContains(t, apiMod.DependsOn, "pkg/api/models") + assertStringSliceContains(t, apiMod.DependsOn, "pkg/shared") +} + // moduleNames is a helper to extract names for error messages. func moduleNames(mods []*Module) []string { names := make([]string, len(mods)) diff --git a/internal/parser/python_test.go b/internal/parser/python_test.go index 5b92b0e..60b9b81 100644 --- a/internal/parser/python_test.go +++ b/internal/parser/python_test.go @@ -6,7 +6,6 @@ import ( "testing" ) - func TestPythonParserParse_Imports(t *testing.T) { p := &TreeSitterParser{} @@ -126,3 +125,27 @@ func TestPythonParserParse_LineCount(t *testing.T) { t.Errorf("LineCount = %d, want 4", info.LineCount) } } + +func TestPythonParserParse_RelativeImportShorthand(t *testing.T) { + p := &TreeSitterParser{} + + content := []byte("from . import models\nfrom .. import utils\n") + info, err := p.Parse("pkg/api/views.py", content) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + wantImports := []string{".models", "..utils"} + for _, want := range wantImports { + found := false + for _, got := range info.Imports { + if got == want { + found = true + break + } + } + if !found { + t.Errorf("import %q not found in %v", want, info.Imports) + } + } +} diff --git a/internal/parser/ts_python.go b/internal/parser/ts_python.go index 61ac172..298cb2e 100644 --- a/internal/parser/ts_python.go +++ b/internal/parser/ts_python.go @@ -31,7 +31,11 @@ func extractPython(root *gts.Node, lang *gts.Language, src []byte, path string) if mod != nil { name := nodeText(mod, src) if name != "" { - info.Imports = dedup(append(info.Imports, name)) + if strings.Trim(name, ".") == "" { + info.Imports = dedup(append(info.Imports, extractPyRelativeFromImports(node, lang, src, name)...)) + } else { + info.Imports = dedup(append(info.Imports, name)) + } } } return gts.WalkSkipChildren @@ -92,6 +96,34 @@ func extractPython(root *gts.Node, lang *gts.Language, src []byte, path string) return info } +// extractPyRelativeFromImports expands shorthand relative imports such as +// "from . import models" into module paths like ".models". +func extractPyRelativeFromImports(node *gts.Node, lang *gts.Language, src []byte, prefix string) []string { + var imports []string + for i := 0; i < node.ChildCount(); i++ { + child := node.Child(i) + if child == nil { + continue + } + switch child.Type(lang) { + case "identifier", "dotted_name": + name := nodeText(child, src) + if name != "" { + imports = append(imports, prefix+name) + } + case "aliased_import": + nameNode := child.ChildByFieldName("name", lang) + if nameNode != nil { + name := nodeText(nameNode, src) + if name != "" { + imports = append(imports, prefix+name) + } + } + } + } + return imports +} + // extractPyImportNames handles import_statement nodes and appends module names // to info.Imports. Handles: import os, import os.path, import os as o. func extractPyImportNames(node *gts.Node, lang *gts.Language, src []byte, info *FileInfo) {