Skip to content

Commit 03a0568

Browse files
authored
Evaluation: Added type for dataset (#641)
1 parent a0c1f24 commit 03a0568

4 files changed

Lines changed: 355 additions & 1 deletion

File tree

backend/app/crud/evaluations/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from app.crud.evaluations.score import EvaluationScore
1212
from app.models import EvaluationRun
1313
from app.models.llm.request import ConfigBlob, LLMCallConfig
14+
from app.models.stt_evaluation import EvaluationType
1415
from app.services.llm.jobs import resolve_config_blob
1516

1617
from app.core.db import engine
@@ -80,6 +81,7 @@ def create_evaluation_run(
8081
run_name=run_name,
8182
dataset_name=dataset_name,
8283
dataset_id=dataset_id,
84+
type=EvaluationType.TEXT.value,
8385
config_id=config_id,
8486
config_version=config_version,
8587
status="pending",
@@ -129,6 +131,7 @@ def list_evaluation_runs(
129131
select(EvaluationRun)
130132
.where(EvaluationRun.organization_id == organization_id)
131133
.where(EvaluationRun.project_id == project_id)
134+
.where(EvaluationRun.type == EvaluationType.TEXT.value)
132135
.order_by(EvaluationRun.inserted_at.desc())
133136
.limit(limit)
134137
.offset(offset)
@@ -167,6 +170,7 @@ def get_evaluation_run_by_id(
167170
.where(EvaluationRun.id == evaluation_id)
168171
.where(EvaluationRun.organization_id == organization_id)
169172
.where(EvaluationRun.project_id == project_id)
173+
.where(EvaluationRun.type == EvaluationType.TEXT.value)
170174
)
171175

172176
eval_run = session.exec(statement).first()

backend/app/crud/evaluations/dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from app.core.util import now
2424
from app.models import EvaluationDataset, EvaluationRun
25+
from app.models.stt_evaluation import EvaluationType
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -60,6 +61,7 @@ def create_evaluation_dataset(
6061
dataset = EvaluationDataset(
6162
name=name,
6263
description=description,
64+
type=EvaluationType.TEXT.value,
6365
dataset_metadata=dataset_metadata,
6466
object_store_url=object_store_url,
6567
langfuse_dataset_id=langfuse_dataset_id,
@@ -122,6 +124,7 @@ def get_dataset_by_id(
122124
.where(EvaluationDataset.id == dataset_id)
123125
.where(EvaluationDataset.organization_id == organization_id)
124126
.where(EvaluationDataset.project_id == project_id)
127+
.where(EvaluationDataset.type == EvaluationType.TEXT.value)
125128
)
126129

127130
dataset = session.exec(statement).first()
@@ -158,6 +161,7 @@ def get_dataset_by_name(
158161
.where(EvaluationDataset.name == name)
159162
.where(EvaluationDataset.organization_id == organization_id)
160163
.where(EvaluationDataset.project_id == project_id)
164+
.where(EvaluationDataset.type == EvaluationType.TEXT.value)
161165
)
162166

163167
dataset = session.exec(statement).first()
@@ -194,6 +198,7 @@ def list_datasets(
194198
select(EvaluationDataset)
195199
.where(EvaluationDataset.organization_id == organization_id)
196200
.where(EvaluationDataset.project_id == project_id)
201+
.where(EvaluationDataset.type == EvaluationType.TEXT.value)
197202
.order_by(EvaluationDataset.inserted_at.desc())
198203
.limit(limit)
199204
.offset(offset)
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
from uuid import uuid4
2+
3+
from sqlmodel import Session, select
4+
5+
from app.core.util import now
6+
from app.crud.evaluations.core import (
7+
create_evaluation_run,
8+
get_evaluation_run_by_id,
9+
list_evaluation_runs,
10+
)
11+
from app.crud.evaluations.dataset import create_evaluation_dataset
12+
from app.models import EvaluationRun, Organization, Project
13+
from app.models.stt_evaluation import EvaluationType
14+
15+
16+
def _create_config(db: Session, project_id: int) -> tuple[int, int]:
    """Create a Config with one ConfigVersion for use by evaluation-run tests.

    Args:
        db: Active database session (pytest fixture).
        project_id: Project the config belongs to.

    Returns:
        A ``(config_id, config_version)`` pair suitable for passing straight
        to ``create_evaluation_run``.
    """
    # Imported locally so this helper module stays importable even if
    # app.models.config has heavy import-time dependencies.
    from app.models.config import Config, ConfigVersion

    config = Config(
        name="test_config",
        project_id=project_id,
        inserted_at=now(),
        updated_at=now(),
    )
    db.add(config)
    db.commit()
    # Refresh to populate the DB-generated primary key before we reference it.
    db.refresh(config)

    config_version = ConfigVersion(
        config_id=config.id,
        version=1,
        config_blob={"completion": {"params": {"model": "gpt-4o"}}},
        inserted_at=now(),
        updated_at=now(),
    )
    db.add(config_version)
    db.commit()
    db.refresh(config_version)

    return config.id, config_version.version
44+
class TestCreateEvaluationRun:
    """Tests for ``create_evaluation_run``."""

    def test_create_evaluation_run_sets_text_type(self, db: Session) -> None:
        """A freshly created evaluation run is tagged with the TEXT type."""
        # Reuse the organization/project already seeded in the test database.
        org = db.exec(select(Organization)).first()
        project = db.exec(
            select(Project).where(Project.organization_id == org.id)
        ).first()

        dataset = create_evaluation_dataset(
            session=db,
            name="test_dataset_run_type",
            dataset_metadata={"original_items_count": 10},
            organization_id=org.id,
            project_id=project.id,
        )
        config_id, config_version = _create_config(db, project.id)

        run = create_evaluation_run(
            session=db,
            run_name="test_run",
            dataset_name=dataset.name,
            dataset_id=dataset.id,
            config_id=config_id,
            config_version=config_version,
            organization_id=org.id,
            project_id=project.id,
        )

        assert run.id is not None
        assert run.type == EvaluationType.TEXT.value
        assert run.status == "pending"
        assert run.run_name == "test_run"
class TestGetEvaluationRunById:
    """Tests for ``get_evaluation_run_by_id``."""

    @staticmethod
    def _org_and_project(db: Session):
        """Return the seeded ``(organization, project)`` pair."""
        org = db.exec(select(Organization)).first()
        project = db.exec(
            select(Project).where(Project.organization_id == org.id)
        ).first()
        return org, project

    def test_get_evaluation_run_by_id_success(self, db: Session) -> None:
        """An existing TEXT run is retrievable by its ID."""
        org, project = self._org_and_project(db)

        dataset = create_evaluation_dataset(
            session=db,
            name="test_dataset_get_run",
            dataset_metadata={"original_items_count": 10},
            organization_id=org.id,
            project_id=project.id,
        )
        config_id, config_version = _create_config(db, project.id)
        created = create_evaluation_run(
            session=db,
            run_name="test_get_run",
            dataset_name=dataset.name,
            dataset_id=dataset.id,
            config_id=config_id,
            config_version=config_version,
            organization_id=org.id,
            project_id=project.id,
        )

        fetched = get_evaluation_run_by_id(
            session=db,
            evaluation_id=created.id,
            organization_id=org.id,
            project_id=project.id,
        )

        assert fetched is not None
        assert fetched.id == created.id
        assert fetched.run_name == "test_get_run"

    def test_get_evaluation_run_by_id_not_found(self, db: Session) -> None:
        """An unknown ID yields None."""
        org, project = self._org_and_project(db)

        missing = get_evaluation_run_by_id(
            session=db,
            evaluation_id=99999,
            organization_id=org.id,
            project_id=project.id,
        )

        assert missing is None

    def test_get_evaluation_run_by_id_excludes_non_text_type(self, db: Session) -> None:
        """Runs whose type is not TEXT are invisible to the getter."""
        org, project = self._org_and_project(db)

        dataset = create_evaluation_dataset(
            session=db,
            name="test_dataset_exclude_run",
            dataset_metadata={"original_items_count": 10},
            organization_id=org.id,
            project_id=project.id,
        )
        config_id, config_version = _create_config(db, project.id)
        run = create_evaluation_run(
            session=db,
            run_name="test_stt_run",
            dataset_name=dataset.name,
            dataset_id=dataset.id,
            config_id=config_id,
            config_version=config_version,
            organization_id=org.id,
            project_id=project.id,
        )

        # Flip the stored type to STT after creation to simulate a non-text run.
        run.type = EvaluationType.STT.value
        db.add(run)
        db.commit()

        assert (
            get_evaluation_run_by_id(
                session=db,
                evaluation_id=run.id,
                organization_id=org.id,
                project_id=project.id,
            )
            is None
        )
class TestListEvaluationRuns:
    """Tests for ``list_evaluation_runs``."""

    def test_list_evaluation_runs_empty(self, db: Session) -> None:
        """With no runs persisted, the listing comes back empty."""
        org = db.exec(select(Organization)).first()
        project = db.exec(
            select(Project).where(Project.organization_id == org.id)
        ).first()

        runs = list_evaluation_runs(
            session=db, organization_id=org.id, project_id=project.id
        )

        assert len(runs) == 0

    def test_list_evaluation_runs_excludes_non_text_type(self, db: Session) -> None:
        """Only TEXT-typed runs appear in the listing."""
        org = db.exec(select(Organization)).first()
        project = db.exec(
            select(Project).where(Project.organization_id == org.id)
        ).first()

        dataset = create_evaluation_dataset(
            session=db,
            name="test_dataset_list_runs",
            dataset_metadata={"original_items_count": 10},
            organization_id=org.id,
            project_id=project.id,
        )
        config_id, config_version = _create_config(db, project.id)

        # Keyword arguments shared by every run created below.
        shared = dict(
            session=db,
            dataset_name=dataset.name,
            dataset_id=dataset.id,
            config_id=config_id,
            config_version=config_version,
            organization_id=org.id,
            project_id=project.id,
        )

        # Three TEXT runs that should all be returned.
        for idx in range(3):
            create_evaluation_run(run_name=f"text_run_{idx}", **shared)

        # One run mutated to STT after creation; it must be filtered out.
        other = create_evaluation_run(run_name="stt_run", **shared)
        other.type = EvaluationType.STT.value
        db.add(other)
        db.commit()

        runs = list_evaluation_runs(
            session=db, organization_id=org.id, project_id=project.id
        )

        assert len(runs) == 3
        assert all(r.type == EvaluationType.TEXT.value for r in runs)

0 commit comments

Comments (0)