@@ -303,23 +303,11 @@ def _cosine_similarity(self, words1: List[str], words2: List[str]) -> float:
303303 return (dot_product / (magnitude1 * magnitude2 )) * 100
304304
305305 def _compute_structural_metrics (self , original : str , generated : str ) -> StructuralMetrics :
306- """Compute structural metrics."""
306+ """Compute structural metrics using AST when possible, regex as fallback ."""
307307 metrics = StructuralMetrics ()
308308
309- # Count elements
310- def count_elements (code : str ) -> Dict [str , int ]:
311- return {
312- 'classes' : len (re .findall (r'^class\s+\w+' , code , re .MULTILINE )),
313- 'functions' : len (re .findall (r'^(?:async\s+)?def\s+\w+' , code , re .MULTILINE )),
314- 'methods' : len (re .findall (r'^\s+(?:async\s+)?def\s+\w+' , code , re .MULTILINE )),
315- 'imports' : len (re .findall (r'^(?:from|import)\s+' , code , re .MULTILINE )),
316- # Capture both annotated attributes and simple assignments.
317- # This is still heuristic, but avoids undercounting common code.
318- 'attributes' : len (re .findall (r'^\s+\w+\s*(?::\s*[^=\n]+)?\s*=' , code , re .MULTILINE )),
319- }
320-
321- orig = count_elements (original )
322- gen = count_elements (generated )
309+ orig = self ._count_elements_ast (original )
310+ gen = self ._count_elements_ast (generated )
323311
324312 metrics .classes_original = orig ['classes' ]
325313 metrics .classes_generated = gen ['classes' ]
@@ -341,15 +329,15 @@ def count_elements(code: str) -> Dict[str, int]:
341329 metrics .attributes_generated = gen ['attributes' ]
342330 metrics .attributes_match = orig ['attributes' ] == gen ['attributes' ]
343331
344- # Structural score
345- matches = sum ([
346- metrics . classes_match ,
347- metrics . functions_match ,
348- metrics . methods_match ,
349- metrics . imports_match ,
350- metrics . attributes_match ,
351- ] )
352- metrics .structural_score = (matches / 5 ) * 100
332+ # Ratio-based structural score (partial credit instead of binary)
333+ total = 0.0
334+ for key in ( 'classes' , 'functions' , 'methods' , 'imports' , 'attributes' ):
335+ ov , gv = orig [ key ], gen [ key ]
336+ if ov == 0 and gv == 0 :
337+ total += 1.0
338+ elif max ( ov , gv ) > 0 :
339+ total += min ( ov , gv ) / max ( ov , gv )
340+ metrics .structural_score = (total / 5 ) * 100
353341
354342 # Element coverage
355343 total_orig = sum (orig .values ())
@@ -359,6 +347,59 @@ def count_elements(code: str) -> Dict[str, int]:
359347
360348 return metrics
361349
350+ @staticmethod
351+ def _count_elements_ast (code : str ) -> Dict [str , int ]:
352+ """Count structural elements using Python AST, with regex fallback."""
353+ import ast as _ast
354+
355+ try :
356+ tree = _ast .parse (code )
357+ except SyntaxError :
358+ # Fallback to regex for unparseable code
359+ return {
360+ 'classes' : len (re .findall (r'^class\s+\w+' , code , re .MULTILINE )),
361+ 'functions' : len (re .findall (r'^(?:async\s+)?def\s+\w+' , code , re .MULTILINE )),
362+ 'methods' : len (re .findall (r'^\s+(?:async\s+)?def\s+\w+' , code , re .MULTILINE )),
363+ 'imports' : len (re .findall (r'^(?:from|import)\s+' , code , re .MULTILINE )),
364+ 'attributes' : len (re .findall (r'^\s+\w+\s*(?::\s*[^=\n]+)?\s*=' , code , re .MULTILINE )),
365+ }
366+
367+ classes = 0
368+ functions = 0
369+ methods = 0
370+ imports = 0
371+ attributes = 0
372+
373+ for node in _ast .walk (tree ):
374+ if isinstance (node , _ast .ClassDef ):
375+ classes += 1
376+ # Count methods inside classes
377+ for item in node .body :
378+ if isinstance (item , (_ast .FunctionDef , _ast .AsyncFunctionDef )):
379+ methods += 1
380+ # Count class-level attributes (annotated or assigned)
381+ elif isinstance (item , (_ast .Assign , _ast .AnnAssign )):
382+ attributes += 1
383+ elif isinstance (node , (_ast .FunctionDef , _ast .AsyncFunctionDef )):
384+ # Only count as top-level function if not inside a class
385+ # (methods already counted above)
386+ pass
387+ elif isinstance (node , (_ast .Import , _ast .ImportFrom )):
388+ imports += 1
389+
390+ # Count top-level functions (not methods)
391+ for node in _ast .iter_child_nodes (tree ):
392+ if isinstance (node , (_ast .FunctionDef , _ast .AsyncFunctionDef )):
393+ functions += 1
394+
395+ return {
396+ 'classes' : classes ,
397+ 'functions' : functions ,
398+ 'methods' : methods ,
399+ 'imports' : imports ,
400+ 'attributes' : attributes ,
401+ }
402+
362403 def _compute_semantic_metrics (self , original : str , generated : str ) -> SemanticMetrics :
363404 """Compute semantic preservation metrics."""
364405 metrics = SemanticMetrics ()
0 commit comments