
feat(evaluation): add VLMMetrics #545

Open
davidberenstein1957 wants to merge 40 commits into main from feat/metrics-vlm-support

Conversation

@davidberenstein1957
Member

Add ImageRewardMetric for evaluating image-text alignment using ImageReward library.

@davidberenstein1957 changed the title from feat(evaluation): add ImageRewardMetric to feat(evaluation): add VLMMetrics on Feb 21, 2026
Member

@begumcig begumcig left a comment


Thank you so much David! I have asked some questions because I am not really familiar with litellm and the requirements that using this specific framework brings. I think the metric classes themselves look really good; they just need some tweaks here and there! And really, apologies for the delay in reviewing 💞

@codacy-production

codacy-production bot commented Apr 1, 2026

Not up to standards ⛔

🔴 Issues 1 critical · 57 high · 21 medium · 1 minor

Alerts:
⚠ 80 issues (quality gate: ≤ 0 issues of at least minor severity)

Results:
80 new issues

| Category   | Results               |
| ---------- | --------------------- |
| UnusedCode | 1 medium              |
| Security   | 57 high               |
| CodeStyle  | 1 minor               |
| Complexity | 1 critical, 20 medium |

View in Codacy

🟢 Metrics 401 complexity · 83 duplication

| Metric      | Results |
| ----------- | ------- |
| Complexity  | 401     |
| Duplication | 83      |

View in Codacy

Tip: This summary will be updated as you push new changes.

Member

@begumcig begumcig left a comment


Hey David, great updates on these! I left some specific comments below, but the general theme is that there’s a bit of a mismatch between the current implementation and the original research methodologies.

The VLM infrastructure you’ve built is a great foundation, but we’re currently missing some of the logic that actually defines these benchmarks. Without those steps, our results might be hard to compare with the official papers. Let me know what you think!!

@davidberenstein1957 force-pushed the feat/metrics-vlm-support branch from f1d0d73 to 33d9135 on April 5, 2026 at 05:08
… support

- Add vlm_base.py with LitellmVLM and TransformersVLM
- Add metrics_vlm.py with VLM-based metrics:
  - VQAMetric
  - AlignmentScoreMetric
  - ImageEditScoreMetric
  - QAAccuracyMetric
  - TextScoreMetric
  - VieScoreMetric
- Uses litellm (default gpt-4o) or local transformers models
ARNIQA is not available in torchmetrics 1.7.4. Implementing
simplified version with optional pretrained weight loading.
- Use scores: List[float] instead of tensor total/count
- Add default_call_type and runs_on attributes
- Match SharpnessMetric pattern
The async version was returning a coroutine instead of the actual
response, causing all VLM metrics to silently fail.
- Add pydantic models for structured output (VQAnswer, ScoreOutput)
- LitellmVLM: Use response_format parameter for stable outputs
- TransformersVLM: Add outlines support for constrained decoding
- Add structured_output flag to all VLM metrics
- Add proper paper references (VQAScore, VieScore)
- Add pydantic>=2.0.0 to dependencies
- Add docstrings to update/compute methods
- Fix type hints
- Add ruff fixes
- Add PIL import at top
- Fix type hints
- D205 docstring issues are from multi-line examples
The metrics_vlm module uses a different docstring pattern for VLM
parameters that doesn't fit numpydoc's PR01 check. Skip this check
for the new VLM metrics.
- Added detailed parameter descriptions to VQAnswer, ScoreOutput, and various metric classes in metrics_vlm.py.
- Updated docstrings in base classes of vlm_base.py to include parameter details and return types.
- Improved clarity and consistency across all metric-related docstrings.
- Added new metrics: AlignmentScoreMetric, ImageEditScoreMetric, QAAccuracyMetric, TextScoreMetric, VieScoreMetric, and VQAMetric for comprehensive evaluation of image-text alignment and quality.
- Implemented integration test script for VLM metrics, allowing testing against both Litellm and Transformers backends.
- Updated pyproject.toml to reflect new dependencies and changes in optional dependencies.
- Added documentation for prompt comparisons between Pruna and InferBench implementations.
…m docstrings

- VieScore: docstring arXiv:2312.14867, TIGER-AI-Lab/VIEScore
- Image Edit Score: docstring EditScore, ADIEE
- VQA: docstring arXiv:2404.01291, use_probability=True default
- vlm_base: full Parameters/Returns for score(), _score_with_logprobs

Made-with: Cursor
- Added docstrings to the update and compute methods for AlignmentScoreMetric, ImageEditScoreMetric, QAAccuracyMetric, TextScoreMetric, VieScoreMetric, and VQAMetric to improve clarity on their functionality.
- Updated the test suite to ensure compatibility with new metric requirements.
- Enhanced the type hints for the response_format parameter in BaseVLM, LitellmVLM, and TransformersVLM classes to include Literal types ("integer", "yes_no") alongside the existing Type[BaseModel].
- Updated docstrings to reflect the new response_format options, improving clarity on expected input types and usage.
- Introduced a new variable `use_pydantic` to clarify the condition for checking if the content result is an instance of the specified response_format type.
- Improved code readability by breaking down the condition into a more understandable format.
- Updated the response_format parameter in BaseVLM, LitellmVLM, and TransformersVLM classes to include "json" as a valid option alongside existing types.
- Adjusted docstrings to reflect the new response_format options for improved clarity on expected input types.
- Included the "pruna[evaluation]" package in the development dependencies for enhanced evaluation capabilities.
- Updated the `vlm_base.py` file to suppress type checking for model device assignment.
- Cleaned up the test suite by removing unnecessary imports and conditions related to VLM metrics.
…s metric classes

- Refactored docstrings for update and compute methods in AlignmentScoreMetric, ImageEditScoreMetric, QAAccuracyMetric, TextScoreMetric, VieScoreMetric, and VQAMetric to enhance clarity and consistency.
- Updated parameter descriptions in the VLM utility classes to provide clearer documentation for structured outputs.
- Reformatted import statements in several metric files for improved readability.
- Replaced YesNoAnswer and ScoreOutput with VQAnswer and FloatOutput in multiple metric classes for consistency in structured outputs.
- Enhanced the metric_vlm_utils.py file by introducing get_answer_from_response and get_text_from_response functions for better response handling.
- Updated the TextScoreMetric to accept List[str] for ground truth, improving flexibility in input types.
- Adjusted the update method in the test suite to accommodate new metric requirements and ensure compatibility with structured outputs.
…ask creation

- Replaced `prompt_collate` with `prompt_with_auxiliaries_collate` in dataset configurations to support auxiliary data.
- Removed the old `prompt_collate` function and updated related metric classes to handle inputs with auxiliary information.
- Introduced a new class method `from_benchmark` in the Task class to facilitate task creation from benchmark names, improving usability and integration with the BenchmarkRegistry.
- Updated various metrics to utilize the new input structure, ensuring compatibility with benchmarks that provide auxiliary data.
…est-only benchmarks

- Changed the seed parameter in PrunaDataModule and various dataset setup functions to accept None, allowing for more flexible seed management.
- Introduced a warning mechanism for test-only benchmarks to inform users when the seed is ignored, ensuring clarity in dataset behavior.
- Updated docstrings to reflect the new optional seed parameter and its implications for dataset setup.
- Added multiple OneIG dataset setups for anime stylization, general objects, knowledge reasoning, multilingualism, portraits, and text rendering.
- Updated the dataset initialization to include new dataset configurations in the `__init__.py` file.
- Introduced new benchmark classes for OneIG subsets in the benchmarks registry, ensuring comprehensive evaluation capabilities.
- Enhanced metric classes to support new evaluation metrics and updated the handling of device compatibility across metrics.
- Added tests for OneIG dataset loading and processing to ensure functionality and correctness.
…tric

- Added OneIGTextScoreMetric for OCR-based composite scoring, providing a higher-is-better metric.
- Updated TextScoreMetric to include descriptive registry aliases and improved docstring clarity.
- Enhanced initialization parameters for both metrics to support better configuration and compatibility.
- Added tests for OneIGTextScoreMetric to validate functionality and ensure correct behavior with ground truth comparisons.
- Introduced OneIGAlignmentMetric to implement alignment scoring with dependency masking, enhancing the evaluation of question dependencies.
- Added utility functions for applying dependency masks and aggregating scores per grid cell.
- Updated the metrics registry to include the new metric and modified the __init__.py file accordingly.
- Implemented unit tests to validate the functionality of the OneIG alignment metric and its dependency handling.
- Introduced the OneIGReasoningMetric for evaluating text-image similarity using LLM2CLIP, enhancing the scoring capabilities for knowledge reasoning tasks.
- Updated the OneIG dataset setup to support reasoning language selection (EN/ZH) for improved flexibility in handling multilingual datasets.
- Added new benchmark classes for OneIG subsets, ensuring comprehensive evaluation across various categories.
- Enhanced the metrics registry to include the new reasoning metric and updated related utility functions for better integration.
- Implemented tests to validate the functionality of the OneIG reasoning metric and its interaction with the dataset.
@davidberenstein1957 force-pushed the feat/metrics-vlm-support branch from 33d9135 to 8de57ad on April 5, 2026 at 05:44
…resh benchmark docs

- Task.from_benchmark: special-case GenEval with qa_accuracy + clip_score
- Benchmarks: GenEval/Long Text/GEdit descriptions; vie_score metric id
- Add test for GenEval task metric wiring

Made-with: Cursor
…ured_output

- get_vlm passes structured_output to TransformersVLM as use_outlines
- Remove use_outlines from VLM metrics and task routing kwargs
- Minor test/docstring updates

Made-with: Cursor
- Rename metric_vlm_utils to vlm_utils; add score parsing tests
- pyproject: core tqdm/realesrgan; simplify torch routing kwargs type
- BaseMetric runs_on includes mps; drop redundant runs_on; unify vlm_kwargs docs
- img_edit_score uses get_score_from_response for structured outputs

Made-with: Cursor
- Join LongText-Bench list text_content before OCR scoring
- Reduce datamodule benchmark tests (category smoke, prompt aux merge)
- Trim VLM metric tests; drop slow mark on mocked GenEval task test

Made-with: Cursor
- Added detailed docstrings for class methods to clarify functionality and usage.
- Simplified error messages for unsupported model configurations.
- Improved file handling for loading configuration files with explicit encoding.
- Streamlined code formatting for better readability and consistency.

Made-with: Cursor
- Added detailed docstrings for functions and classes to clarify their purpose and usage.
- Updated version check functions to specify the required version of the `transformers` package.
- Introduced new classes for modified Llama attention and decoder layers to support bidirectional encoding.
- Improved error handling in the Llama encoder model for unsupported transformer versions.

Made-with: Cursor
…rkRegistry

- Updated type hints in the BenchmarkRegistry and LLM2Vec classes for better clarity and compatibility.
- Enhanced the batch_to_device function to accept both device strings and device types.
- Improved handling of optional parameters in LLM2Vec methods to prevent potential errors.
- Added type casting for better type safety in the bidirectional Llama model.

Made-with: Cursor
- Updated import path for LlamaBiModel to reflect new module structure.
- Improved docstrings across various classes and methods to provide clearer descriptions and parameter details.
- Ensured consistency in return type annotations and parameter specifications for better code readability.

Made-with: Cursor
- Renamed and refactored dataset setup functions for clarity and consistency, including the introduction of `_setup_oneig_subset_with_fixed_category`.
- Added new functions for loading specific OneIG datasets with fixed categories, improving usability.
- Introduced a new module for VLM benchmark integration, providing shared helpers and metrics for evaluation.
- Enhanced docstrings across various functions to clarify parameters and return types, ensuring better documentation and understanding.

Made-with: Cursor
@llcnt
Collaborator

llcnt commented Apr 10, 2026

@cursor review


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 3 potential issues.


Bugbot Autofix prepared fixes for all 3 issues found in the latest run.

  • ✅ Fixed: aggregation parameter stored but never applied in scoring
    • QAAccuracyMetric now validates and applies aggregation, using strict binary scoring for all_or_nothing instead of always averaging per-question scores.
  • ✅ Fixed: Chinese language heuristic misclassifies EN-only rows
    • The OneIG language heuristic now only infers Chinese from prompt-only rows when actual CJK characters are present, preventing EN prompt-only rows from being routed to _zh Q_D files.
  • ✅ Fixed: Text score uses unnormalized distance favoring short texts
    • TextScoreMetric now stores normalized Levenshtein distance by dividing by ground-truth character length so the reported mean matches character error rate behavior across varying text lengths.


Or push these changes by commenting:

@cursor push 6e88aefdac
Preview (6e88aefdac)
diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py
--- a/src/pruna/data/datasets/prompt.py
+++ b/src/pruna/data/datasets/prompt.py
@@ -139,11 +139,15 @@
     lang = row.get("language") or row.get("lang")
     if isinstance(lang, str) and lang.lower() in {"zh", "zh-cn", "zh_cn", "chinese", "cn"}:
         return True
-    if row.get("prompt_zh"):
+    if row.get("prompt_zh") or row.get("prompt_cn"):
         return True
     prompt = row.get("prompt")
     prompt_en = row.get("prompt_en")
-    return bool(prompt and not (isinstance(prompt_en, str) and prompt_en.strip()))
+    if not (isinstance(prompt, str) and prompt.strip()):
+        return False
+    if isinstance(prompt_en, str) and prompt_en.strip():
+        return False
+    return any("\u4e00" <= ch <= "\u9fff" for ch in prompt)
 
 
 def _oneig_qd_prefix(row: dict) -> str:

diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py
--- a/src/pruna/evaluation/metrics/metric_qa_accuracy.py
+++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py
@@ -105,6 +105,11 @@
         self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
         self.add_state("scores", [])
         self.aggregation = kwargs.pop("aggregation", "mean")
+        if self.aggregation not in {"mean", "all_or_nothing"}:
+            raise ValueError(
+                "qa_accuracy aggregation must be one of {'mean', 'all_or_nothing'}. "
+                f"Got: {self.aggregation!r}."
+            )
 
     def _extract_questions(self, gt: Any, n: int) -> List[List[str]]:
         if isinstance(gt, (list, tuple)) and len(gt) >= n:
@@ -151,7 +156,10 @@
                 ["Yes"] * len(questions),
                 response_format=self.response_format,
             )
-            score = float(np.mean(scores))
+            if self.aggregation == "all_or_nothing":
+                score = float(all(float(s) == 1.0 for s in scores))
+            else:
+                score = float(np.mean(scores))
             self.scores.append(score)
 
     def compute(self) -> MetricResult:

diff --git a/src/pruna/evaluation/metrics/metric_text_score.py b/src/pruna/evaluation/metrics/metric_text_score.py
--- a/src/pruna/evaluation/metrics/metric_text_score.py
+++ b/src/pruna/evaluation/metrics/metric_text_score.py
@@ -172,7 +172,7 @@
 @MetricRegistry.register("text_score")
 class TextScoreMetric(_BaseVLMOCRTextMetric):
     """
-    OCR then mean Levenshtein distance to ground truth (lower is better).
+    OCR then mean normalized Levenshtein distance (character error rate, lower is better).
 
     Registry: ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy).
 
@@ -240,7 +240,8 @@
     def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
         norm_gt = normalize_text_simple(text_gt)
         norm_ocr = normalize_text_simple(ocr_text)
-        self.scores.append(levenshtein(norm_ocr, norm_gt))
+        gt_len = max(len(norm_gt), 1)
+        self.scores.append(float(levenshtein(norm_ocr, norm_gt) / gt_len))
 
     def _compute_result_value(self) -> float:
         if not self.scores:

diff --git a/tests/data/test_oneig_loader.py b/tests/data/test_oneig_loader.py
--- a/tests/data/test_oneig_loader.py
+++ b/tests/data/test_oneig_loader.py
@@ -34,6 +34,18 @@
     assert prompt_mod._oneig_qd_prefix(row) == "anime_zh"
 
 
+def test_oneig_qd_prefix_prompt_only_en_row_stays_en() -> None:
+    """Prompt-only EN rows must not be misclassified as Chinese."""
+    row = {
+        "category": "General_Object",
+        "id": "001",
+        "prompt": "a red apple on a table",
+        "prompt_en": "",
+        "class": "None",
+    }
+    assert prompt_mod._oneig_qd_prefix(row) == "object"
+
+
 def test_to_oneig_record_multilingualism_fills_questions() -> None:
     """Synthetic Multilingualism row resolves Q_D from merged index."""
     qb = {"multilingualism_zh_000": {"questions": {"1": "现场是不是颁奖典礼?"}, "dependencies": {"1": [0]}}}

diff --git a/tests/evaluation/test_vlm_metrics.py b/tests/evaluation/test_vlm_metrics.py
--- a/tests/evaluation/test_vlm_metrics.py
+++ b/tests/evaluation/test_vlm_metrics.py
@@ -146,6 +146,22 @@
 
 
 @pytest.mark.cpu
+def test_qa_accuracy_aggregation_modes() -> None:
+    mock_vlm = MagicMock(spec=BaseVLM)
+    mock_vlm.score.return_value = [1.0, 0.0]
+    images = _dummy_image(batch=1)
+    aux = [{"questions": {"1": "Q1", "2": "Q2"}}]
+
+    mean_metric = QAAccuracyMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", aggregation="mean")
+    mean_metric.update(["a prompt"], aux, images)
+    assert mean_metric.compute().result == pytest.approx(0.5)
+
+    strict_metric = QAAccuracyMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", aggregation="all_or_nothing")
+    strict_metric.update(["a prompt"], aux, images)
+    assert strict_metric.compute().result == pytest.approx(0.0)
+
+
+@pytest.mark.cpu
 def test_get_vlm_returns_custom() -> None:
     custom = MagicMock(spec=BaseVLM)
     out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o")
@@ -183,6 +199,19 @@
 
 
 @pytest.mark.cpu
+def test_text_score_uses_normalized_edit_distance() -> None:
+    mock_vlm = MagicMock(spec=BaseVLM)
+    mock_vlm.generate.side_effect = [["abxde"], ["ax"]]
+    metric = TextScoreMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu")
+
+    metric.update(["p1"], ["abcde"], _dummy_image(batch=1))
+    metric.update(["p2"], ["ab"], _dummy_image(batch=1))
+
+    assert metric.scores == pytest.approx([0.2, 0.5])
+    assert metric.compute().result == pytest.approx(0.35)
+
+
+@pytest.mark.cpu
 def test_text_score_registry_aliases() -> None:
     from pruna.evaluation.metrics.registry import MetricRegistry



Reviewed by Cursor Bugbot for commit 7435679.

["Yes"] * len(questions),
response_format=self.response_format,
)
score = float(np.mean(scores))


aggregation parameter stored but never applied in scoring

High Severity

QAAccuracyMetric accepts and stores self.aggregation (e.g. "all_or_nothing" for GenEval) but never reads it. The update method always uses np.mean(scores) on line 154, giving partial credit to every image. For GenEval, where Task.from_benchmark explicitly passes aggregation="all_or_nothing", this produces inflated scores instead of the official binary pass/fail (1 only if every atomic question passes, 0 otherwise).
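The gap between the two aggregation modes can be seen in miniature (illustrative values only, not the PR's code):

```python
from statistics import mean

# Per-question yes/no scores for one image; one atomic question failed.
per_question = [1.0, 1.0, 0.0]

# What the current update() does: partial credit via the mean.
mean_score = float(mean(per_question))

# What GenEval's all_or_nothing wiring expects: binary pass/fail,
# 1.0 only if every atomic question passes.
strict_score = float(all(s == 1.0 for s in per_question))
```

For GenEval the strict score is the official behavior, so the mean-based path inflates results whenever any single question fails.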

Additional Locations (2)


```python
        return True
    prompt = row.get("prompt")
    prompt_en = row.get("prompt_en")
    return bool(prompt and not (isinstance(prompt_en, str) and prompt_en.strip()))
```


Chinese language heuristic misclassifies EN-only rows

Medium Severity

_oneig_alignment_language_zh falls through to a heuristic on line 146 that returns True (Chinese) when prompt is non-empty but prompt_en is absent or empty. Rows in the EN config (OneIG-Bench) that use prompt as their primary text field without a separate prompt_en column would be misclassified as Chinese, causing _oneig_qd_prefix to select *_zh question-dependency files instead of the English ones.
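The autofix's approach, in isolation: only infer Chinese from a prompt-only row when the prompt actually contains CJK Unified Ideographs (the U+4E00–U+9FFF range used in the diff above):

```python
def contains_cjk(text: str) -> bool:
    # True iff any character falls in the basic CJK Unified Ideographs block.
    return any("\u4e00" <= ch <= "\u9fff" for ch in text)

assert contains_cjk("现场是不是颁奖典礼?")          # ZH prompt → True
assert not contains_cjk("a red apple on a table")  # EN prompt-only row → False
```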



```python
    def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
        norm_gt = normalize_text_simple(text_gt)
        norm_ocr = normalize_text_simple(ocr_text)
        self.scores.append(levenshtein(norm_ocr, norm_gt))
```


Text score uses unnormalized distance favoring short texts

Medium Severity

TextScoreMetric._accumulate_sample stores raw Levenshtein edit distance without normalizing by text length. Since the metric is lower_is_better, longer ground-truth texts inherently produce higher (worse) scores than shorter texts with the same number of errors, making scores incomparable across samples of different length. The benchmark description mentions "mean character error rate" but the implementation returns raw edit counts.
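To illustrate the length bias (with a stand-in `levenshtein` helper; the PR's own implementation may differ):

```python
def levenshtein(a: str, b: str) -> int:
    # Classic two-row dynamic-programming edit distance.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]

# One wrong character in each pair: the raw distances are identical...
d_long = levenshtein("abxde", "abcde")
d_short = levenshtein("ax", "ab")

# ...but dividing by ground-truth length gives a character-error-rate-style
# score that stays comparable across sample lengths (0.2 vs 0.5 here).
cer_long = d_long / max(len("abcde"), 1)
cer_short = d_short / max(len("ab"), 1)
```

These are the same 0.2 and 0.5 values the autofix's `test_text_score_uses_normalized_edit_distance` asserts in the diff above.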



@llcnt
Collaborator

llcnt commented Apr 10, 2026

The code change (>5K lines) is too large for me :( Could you point me to some place where my review can be particularly useful? I can definitely dedicate some time to reviewing some files, but not all 😓

@davidberenstein1957
Member Author

Hi @llcnt, if you have time, you can take a look at the LLM2CLIP implementation :)

@davidberenstein1957
Member Author


Code review

Found 2 issues:

  1. qa_accuracy ignores aggregation (including GenEval all_or_nothing). The metric stores aggregation in __init__ but update always uses np.mean(scores) per image and compute always uses np.mean(self.scores), so partial credit remains and the final score is not strict per-image all-or-nothing. This conflicts with the GenEval benchmark text and Task.from_benchmark wiring that pass aggregation="all_or_nothing".

```python
self.aggregation = kwargs.pop("aggregation", "mean")

def _extract_questions(self, gt: Any, n: int) -> List[List[str]]:
    if isinstance(gt, (list, tuple)) and len(gt) >= n:
        out = []
        for i in range(n):
            v = gt[i]
            if isinstance(v, dict) and "questions" in v:
                qs = v["questions"]
                out.append(list(qs.values()) if isinstance(qs, dict) else list(qs))
            else:
                out.append([])
        return out
    return [[]] * n

def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
    """
    Update the metric with new batch data.

    Parameters
    ----------
    x : List[Any] | torch.Tensor
        The input data.
    gt : torch.Tensor
        The ground truth (questions per image).
    outputs : torch.Tensor
        The output images.
    """
    inputs = metric_data_processor(x, gt, outputs, self.call_type)
    images = _process_images(inputs[0])
    auxiliaries = inputs[1] if len(inputs) > 1 else []
    questions_per_image = self._extract_questions(auxiliaries, len(images))
    for i, image in enumerate(images):
        questions = questions_per_image[i] if i < len(questions_per_image) else []
        if not questions:
            aux = auxiliaries[i] if i < len(auxiliaries) else {}
            raise ValueError(
                "qa_accuracy requires 'questions' in auxiliaries. "
                "Use a benchmark that provides it (e.g. GenEval, DPG, OneIG). "
                f"Got aux keys: {list(aux.keys()) if isinstance(aux, dict) else 'not a dict'}."
            )
        scores = self.vlm.score(
            [image] * len(questions),
            questions,
            ["Yes"] * len(questions),
            response_format=self.response_format,
        )
        score = float(np.mean(scores))
        self.scores.append(score)

def compute(self) -> MetricResult:
    """
    Compute the QA accuracy score.

    Returns
    -------
    MetricResult
        The mean QA accuracy across all updates.
    """
    if not self.scores:
        return MetricResult(self.metric_name, self.__dict__, 0.0)
    return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores)))
```

```python
Benchmark(
    name="GenEval",
    description=(
        "Compositional text-to-image benchmark with 6 categories: single object, two object, "
        "counting, colors, position, color attributes. Uses atomic yes/no questions per prompt; "
        "``Task.from_benchmark`` wires ``qa_accuracy`` with strict per-image aggregation "
        "(all questions must pass) plus ``clip_score``. For holistic VQAScore-style scoring "
        "use GenAI Bench with ``vqa``."
    ),
```

  2. _extract_questions uses [[]] * n, which aliases the same inner list n times. If the fallback branch runs with n > 1, all indices share one list object (classic Python pitfall); use e.g. [[] for _ in range(n)].

```python
def _extract_questions(self, gt: Any, n: int) -> List[List[str]]:
    if isinstance(gt, (list, tuple)) and len(gt) >= n:
        out = []
        for i in range(n):
            v = gt[i]
            if isinstance(v, dict) and "questions" in v:
                qs = v["questions"]
                out.append(list(qs.values()) if isinstance(qs, dict) else list(qs))
            else:
                out.append([])
        return out
    return [[]] * n
```
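The aliasing pitfall is easy to reproduce in isolation:

```python
# `[[]] * 3` copies one list reference three times; mutating any slot
# mutates all of them.
rows_aliased = [[]] * 3
rows_aliased[0].append("q1")
# rows_aliased is now [['q1'], ['q1'], ['q1']]

# A comprehension builds three independent lists instead.
rows_fresh = [[] for _ in range(3)]
rows_fresh[0].append("q1")
# rows_fresh is now [['q1'], [], []]
```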

(No CLAUDE.md in repo root or modified directories; review used PR head 7435679005957aaa13413714e7af9aca27938812.)

Generated with Claude Code

@davidberenstein1957
Member Author

Update on review #545 (comment)

  1. aggregation / GenEval all_or_nothing — Already implemented on this branch: in QAAccuracyMetric.update, with all_or_nothing the per-image score is 1.0 only when every per-question score passes; otherwise the default mean over question scores is used. Covered by tests/evaluation/test_vlm_metrics.py (test_qa_accuracy_all_or_nothing_*).

  2. _extract_questions fallback — Fixed: replaced return [[]] * n with return [[] for _ in range(n)] so the fallback path does not alias one empty list across n slots (src/pruna/evaluation/metrics/metric_qa_accuracy.py).
