Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions code_review_graph/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,6 +1938,22 @@ def _extract_classes(
if not name:
return False

# Swift: detect the actual type keyword (class/struct/enum/actor/extension)
# and store it in extra["swift_kind"] for richer downstream analysis.
# Tree-sitter maps struct/enum/actor/extension all to class_declaration;
# protocol uses its own protocol_declaration node type.
extra: dict = {}
if language == "swift":
if child.type == "class_declaration":
_swift_keywords = {"class", "struct", "enum", "actor", "extension"}
for kw_child in child.children:
kw_text = kw_child.text.decode("utf-8", errors="replace")
if kw_text in _swift_keywords:
extra["swift_kind"] = kw_text
break
elif child.type == "protocol_declaration":
extra["swift_kind"] = "protocol"

node = NodeInfo(
kind="Class",
name=name,
Expand All @@ -1946,6 +1962,7 @@ def _extract_classes(
line_end=child.end_point[0] + 1,
language=language,
parent_name=enclosing_class,
extra=extra,
)
nodes.append(node)

Expand Down Expand Up @@ -3182,6 +3199,14 @@ def _get_name(self, node, language: str, kind: str) -> Optional[str]:
for child in node.children:
if child.type == "field_identifier":
return child.text.decode("utf-8", errors="replace")
# Swift extensions: name is inside user_type > type_identifier
# (e.g. `extension MyClass: Protocol { ... }`)
if language == "swift" and node.type == "class_declaration":
for child in node.children:
if child.type == "user_type":
for sub in child.children:
if sub.type == "type_identifier":
return sub.text.decode("utf-8", errors="replace")
# Most languages use a 'name' child
for child in node.children:
if child.type in (
Expand Down Expand Up @@ -3341,6 +3366,19 @@ def _get_bases(self, node, language: str, source: bytes) -> list[str]:
for sub in child.children:
if sub.type == "type_identifier":
bases.append(sub.text.decode("utf-8", errors="replace"))
elif language == "swift":
# Swift: class Foo: Bar, Baz { ... } / extension Foo: Protocol { ... }
# AST: inheritance_specifier > user_type > type_identifier
for child in node.children:
if child.type == "inheritance_specifier":
for sub in child.children:
if sub.type == "user_type":
for ident in sub.children:
if ident.type == "type_identifier":
bases.append(
ident.text.decode("utf-8", errors="replace")
)
break
return bases

def _extract_import(self, node, language: str, source: bytes) -> list[str]:
Expand Down
29 changes: 29 additions & 0 deletions tests/fixtures/sample.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,35 @@ class InMemoryRepo: UserRepository {
}
}

enum Direction: String {
case north
case south
case east
case west
}

actor DataStore {
private var cache: [String: User] = [:]

func get(_ key: String) -> User? {
return cache[key]
}

func set(_ key: String, user: User) {
cache[key] = user
}
}

extension InMemoryRepo: CustomStringConvertible {
var description: String {
return "InMemoryRepo with \(users.count) users"
}

func clear() {
users.removeAll()
}
}

func createUser(repo: UserRepository, name: String, email: String) -> User {
let user = User(id: 1, name: name, email: email)
repo.save(user)
Expand Down
51 changes: 50 additions & 1 deletion tests/test_multilang.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,62 @@ def test_detects_language(self):
def test_finds_classes(self):
classes = [n for n in self.nodes if n.kind == "Class"]
names = {c.name for c in classes}
assert "User" in names or "InMemoryRepo" in names
assert "User" in names
assert "InMemoryRepo" in names

def test_finds_functions(self):
funcs = [n for n in self.nodes if n.kind == "Function"]
names = {f.name for f in funcs}
assert "createUser" in names or "findById" in names or "save" in names

def test_finds_enum(self):
classes = [n for n in self.nodes if n.kind == "Class"]
names = {c.name for c in classes}
assert "Direction" in names

def test_finds_actor(self):
classes = [n for n in self.nodes if n.kind == "Class"]
names = {c.name for c in classes}
assert "DataStore" in names

def test_finds_extension(self):
"""Extensions should be detected and linked to the extended type."""
classes = [n for n in self.nodes if n.kind == "Class"]
# Extension of InMemoryRepo should produce a Class node named InMemoryRepo
# with swift_kind == "extension"
ext_nodes = [c for c in classes if c.extra.get("swift_kind") == "extension"]
assert len(ext_nodes) >= 1
assert ext_nodes[0].name == "InMemoryRepo"

def test_finds_protocol(self):
classes = [n for n in self.nodes if n.kind == "Class"]
names = {c.name for c in classes}
assert "UserRepository" in names

def test_swift_kind_extra(self):
"""Each Swift type should have the correct swift_kind in extra."""
classes = {n.name: n for n in self.nodes if n.kind == "Class"}
assert classes["User"].extra.get("swift_kind") == "struct"
assert classes["Direction"].extra.get("swift_kind") == "enum"
assert classes["DataStore"].extra.get("swift_kind") == "actor"
assert classes["UserRepository"].extra.get("swift_kind") == "protocol"
# InMemoryRepo appears twice (class + extension); check at least one is "class"
repo_nodes = [n for n in self.nodes if n.kind == "Class" and n.name == "InMemoryRepo"]
kinds = {n.extra.get("swift_kind") for n in repo_nodes}
assert "class" in kinds
assert "extension" in kinds

def test_inheritance_edges(self):
"""Swift inheritance / conformance should produce INHERITS edges."""
inherits = [e for e in self.edges if e.kind == "INHERITS"]
targets = {e.target for e in inherits}
# InMemoryRepo: UserRepository
assert "UserRepository" in targets
# Direction: String
assert "String" in targets
# extension InMemoryRepo: CustomStringConvertible
assert "CustomStringConvertible" in targets


class TestScalaParsing:
def setup_method(self):
Expand Down