diff --git a/api/analyzers/kotlin/__init__.py b/api/analyzers/kotlin/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/api/analyzers/kotlin/analyzer.py b/api/analyzers/kotlin/analyzer.py
new file mode 100644
--- /dev/null
+++ b/api/analyzers/kotlin/analyzer.py
@@ -0,0 +1,155 @@
+from pathlib import Path
+from ...entities import Entity, File
+from typing import Optional
+from ..analyzer import AbstractAnalyzer
+
+from multilspy import SyncLanguageServer
+
+import tree_sitter_kotlin as tskotlin
+from tree_sitter import Language, Node
+
+import logging
+logger = logging.getLogger('code_graph')
+
+# Path components that mark build output or dependency caches rather than project sources.
+DEPENDENCY_DIRS = frozenset({"build", ".gradle", "cache"})
+
+class KotlinAnalyzer(AbstractAnalyzer):
+    """Tree-sitter based analyzer that extracts classes, objects and functions from Kotlin sources."""
+
+    def __init__(self) -> None:
+        super().__init__(Language(tskotlin.language()))
+
+    def add_dependencies(self, path: Path, files: list[Path]):
+        """Dependency resolution is not implemented for Kotlin yet, so this is a no-op."""
+        # TODO: parse build.gradle(.kts) or pom.xml to discover dependencies.
+        pass
+
+    def get_entity_label(self, node: Node) -> str:
+        """Return the graph label (Interface, Class, Object, Method or Function) for a declaration node.
+
+        Raises:
+            ValueError: If the node is not a class, object or function declaration.
+        """
+        if node.type == 'class_declaration':
+            # Interfaces share the class_declaration node type and are told apart by an `interface` keyword child.
+            if any(child.type == 'interface' for child in node.children):
+                return "Interface"
+            return "Class"
+        if node.type == 'object_declaration':
+            return "Object"
+        if node.type == 'function_declaration':
+            # A function declared directly inside a class body is a method; anything else is a plain function.
+            parent = node.parent
+            if parent and parent.type == 'class_body':
+                return "Method"
+            return "Function"
+        raise ValueError(f"Unknown entity type: {node.type}")
+
+    def get_entity_name(self, node: Node) -> str:
+        """Return the declared name of a class, object or function node.
+
+        Raises:
+            ValueError: If the node type is unsupported or it has no identifier child.
+        """
+        if node.type in ['class_declaration', 'object_declaration']:
+            identifier_type = 'type_identifier'
+        elif node.type == 'function_declaration':
+            identifier_type = 'simple_identifier'
+        else:
+            raise ValueError(f"Cannot extract name from entity type: {node.type}")
+        for child in node.children:
+            if child.type == identifier_type:
+                return child.text.decode('utf-8')
+        raise ValueError(f"Cannot extract name from entity type: {node.type}")
+
+    def get_entity_docstring(self, node: Node) -> Optional[str]:
+        """Return the KDoc comment (/** ... */) that is the previous sibling of the declaration, or None.
+
+        Raises:
+            ValueError: If the node is not a class, object or function declaration.
+        """
+        if node.type not in ['class_declaration', 'object_declaration', 'function_declaration']:
+            raise ValueError(f"Unknown entity type: {node.type}")
+        previous = node.prev_sibling
+        if previous and previous.type == "multiline_comment":
+            comment_text = previous.text.decode('utf-8')
+            # A plain /* ... */ block comment is not documentation; only /** ... */ counts as KDoc.
+            if comment_text.startswith('/**'):
+                return comment_text
+        return None
+
+    def get_entity_types(self) -> list[str]:
+        """Return the tree-sitter node types that are turned into entities."""
+        return ['class_declaration', 'object_declaration', 'function_declaration']
+
+    def _add_captures(self, entity: Entity, key: str, pattern: str) -> None:
+        """Run a query with a single @symbol capture over the entity and record every hit under key."""
+        query = self.language.query(pattern)
+        for node in query.captures(entity.node).get('symbol', []):
+            entity.add_symbol(key, node)
+
+    def add_symbols(self, entity: Entity) -> None:
+        """Collect the unresolved symbols of an entity: supertypes, calls, parameter types and return type."""
+        node_type = entity.node.type
+        if node_type in ['class_declaration', 'object_declaration']:
+            # Kotlin lists the superclass and the interfaces in one comma separated list. A superclass
+            # is written with a constructor call (`Base()`), which the grammar wraps in a
+            # constructor_invocation node; a bare type name is treated as an implemented interface.
+            self._add_captures(entity, "base_class", "(delegation_specifier (constructor_invocation (user_type (type_identifier) @symbol)))")
+            self._add_captures(entity, "implement_interface", "(delegation_specifier (user_type (type_identifier) @symbol))")
+        elif node_type == 'function_declaration':
+            self._add_captures(entity, "call", "(call_expression) @symbol")
+            # NOTE(review): both queries below rely on the grammar exposing a `type` field - confirm against tree-sitter-kotlin.
+            self._add_captures(entity, "parameters", "(parameter type: (user_type (type_identifier) @symbol))")
+            self._add_captures(entity, "return_type", "(function_declaration type: (user_type (type_identifier) @symbol))")
+
+    def is_dependency(self, file_path: str) -> bool:
+        """Return True when the file lives under a build output or dependency cache directory."""
+        # Compare whole path components so that e.g. "rebuild/" is not mistaken for "build/".
+        return not DEPENDENCY_DIRS.isdisjoint(Path(file_path).parent.parts)
+
+    def resolve_path(self, file_path: str, path: Path) -> str:
+        """Return file_path unchanged; no path rewriting is done for Kotlin for now."""
+        return file_path
+
+    def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]:
+        """Return the class/object entities enclosing each definition the type reference resolves to."""
+        res = []
+        for file, resolved_node in self.resolve(files, lsp, file_path, path, node):
+            type_dec = self.find_parent(resolved_node, ['class_declaration', 'object_declaration'])
+            if type_dec in file.entities:
+                res.append(file.entities[type_dec])
+        return res
+
+    def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]:
+        """Return the function entities a call expression resolves to; empty for any other node type."""
+        res = []
+        if node.type != 'call_expression':
+            return res
+        for child in node.children:
+            if child.type not in ['simple_identifier', 'navigation_expression']:
+                continue
+            for file, resolved_node in self.resolve(files, lsp, file_path, path, child):
+                method_dec = self.find_parent(resolved_node, ['function_declaration', 'class_declaration', 'object_declaration'])
+                # A definition whose nearest enclosing declaration is a class or object is not a function; skip it.
+                if method_dec and method_dec.type in ['class_declaration', 'object_declaration']:
+                    continue
+                if method_dec in file.entities:
+                    res.append(file.entities[method_dec])
+            # Stop after the first identifier/navigation child has been processed.
+            break
+        return res
+
+    def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]:
+        """Resolve a symbol recorded by add_symbols to its candidate entities.
+
+        Raises:
+            ValueError: If key is not a symbol kind handled by this analyzer.
+        """
+        if key in ["implement_interface", "base_class", "parameters", "return_type"]:
+            return self.resolve_type(files, lsp, file_path, path, symbol)
+        elif key in ["call"]:
+            return self.resolve_method(files, lsp, file_path, path, symbol)
+        else:
+            raise ValueError(f"Unknown key {key}")
diff --git a/api/analyzers/source_analyzer.py b/api/analyzers/source_analyzer.py
--- a/api/analyzers/source_analyzer.py
+++ b/api/analyzers/source_analyzer.py
@@ -9,8 +9,9 @@
 from .analyzer import AbstractAnalyzer
 # from .c.analyzer import CAnalyzer
 from .java.analyzer import JavaAnalyzer
+from .kotlin.analyzer import KotlinAnalyzer
 from .python.analyzer import PythonAnalyzer
 from .csharp.analyzer import CSharpAnalyzer
 
 from multilspy import SyncLanguageServer
 from multilspy.multilspy_config import MultilspyConfig
@@ -26,7 +27,9 @@
     # '.h': CAnalyzer(),
     '.py': PythonAnalyzer(),
     '.java': JavaAnalyzer(),
-    '.cs': CSharpAnalyzer()}
+    '.cs': CSharpAnalyzer(),
+    '.kt': KotlinAnalyzer(),
+    '.kts': KotlinAnalyzer()}
 
 class NullLanguageServer:
     def start_server(self):
@@ -138,12 +141,15 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
             lsps[".py"] = SyncLanguageServer.create(config, logger, str(path))
         else:
             lsps[".py"] = NullLanguageServer()
         if any(path.rglob('*.cs')):
             config = MultilspyConfig.from_dict({"code_language": "csharp"})
             lsps[".cs"] = SyncLanguageServer.create(config, logger, str(path))
         else:
             lsps[".cs"] = NullLanguageServer()
-        with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server():
+        # No Kotlin language server is wired up yet, so Kotlin files always get the null server.
+        lsps[".kt"] = NullLanguageServer()
+        lsps[".kts"] = NullLanguageServer()
+        with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server(), lsps[".kt"].start_server(), lsps[".kts"].start_server():
             files_len = len(self.files)
             for i, file_path in enumerate(files):
                 file = self.files[file_path]
@@ -174,6 +180,6 @@ def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None:
 
     def analyze_sources(self, path: Path, ignore: list[str], graph: Graph) -> None:
         path = path.resolve()
-        files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + list(path.rglob("*.cs"))
+        files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + list(path.rglob("*.cs")) + list(path.rglob("*.kt")) + list(path.rglob("*.kts"))
         # First pass analysis of the source code
         self.first_pass(path, files, ignore, graph)
diff --git a/pyproject.toml b/pyproject.toml
index 07e1db5..7b1a698 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
     "tree-sitter-c>=0.24.1,<0.25.0",
     "tree-sitter-python>=0.25.0,<0.26.0",
     "tree-sitter-java>=0.23.5,<0.24.0",
+"tree-sitter-kotlin>=1.1.0,<2.0.0",
     "tree-sitter-c-sharp>=0.23.1,<0.24.0",
     "flask>=3.1.0,<4.0.0",
     "python-dotenv>=1.0.1,<2.0.0",