Skip to content

Commit 6db990d

Browse files
authored
Chore: Refactor the intervals check and move progress reporting into the scheduler (#4879)
1 parent 0edfea7 commit 6db990d

File tree

5 files changed

+117
-64
lines changed

5 files changed

+117
-64
lines changed

sqlmesh/core/model/definition.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,15 +627,23 @@ def _render(e: exp.Expression) -> str | int | float | bool:
627627
{k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name
628628
]
629629

630-
def render_signal_calls(self) -> t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]:
631-
return {
630+
def render_signal_calls(self) -> EvaluatableSignals:
    """Renders the kwargs of every named signal and bundles them with the Python environment.

    Signals whose name is empty are skipped. Each kwarg value is rendered through
    a renderer created for it, and the first element of the render result is kept
    (an empty list is substituted when rendering yields nothing).

    Returns:
        An EvaluatableSignals object holding the rendered kwargs per signal name,
        the model's Python environment and its prepared counterpart.
    """
    python_env = self.python_env
    prepared_env = prepare_env(python_env)

    rendered_calls: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]] = {}
    for signal_name, signal_kwargs in self.signals:
        if not signal_name:
            continue

        rendered_kwargs: t.Dict[str, t.Optional[exp.Expression]] = {}
        for arg_name, arg_value in signal_kwargs.items():
            rendered = self._create_renderer(arg_value).render() or []
            rendered_kwargs[arg_name] = seq_get(rendered, 0)

        rendered_calls[signal_name] = rendered_kwargs

    return EvaluatableSignals(
        signals_to_kwargs=rendered_calls,
        python_env=python_env,
        prepared_python_env=prepared_env,
    )
646+
639647
def render_merge_filter(
640648
self,
641649
*,
@@ -1857,6 +1865,15 @@ class AuditResult(PydanticModel):
18571865
blocking: bool = True
18581866

18591867

1868+
class EvaluatableSignals(PydanticModel):
    """Rendered signal calls together with the Python environment used to evaluate them."""

    signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]
    """A mapping of signal names to the kwargs passed to the signal."""
    python_env: t.Dict[str, Executable]
    """The Python environment that should be used to evaluate the rendered signal calls."""
    prepared_python_env: t.Dict[str, t.Any]
    """The prepared Python environment that should be used to evaluate the rendered signal calls."""
1875+
1876+
18601877
def _extract_blueprints(blueprints: t.Any, path: Path) -> t.List[t.Any]:
18611878
if not blueprints:
18621879
return [None]

sqlmesh/core/scheduler.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
import logging
33
import typing as t
4+
import time
45
from sqlglot import exp
56
from sqlmesh.core import constants as c
67
from sqlmesh.core.console import Console, get_console
@@ -24,6 +25,7 @@
2425
snapshots_to_dag,
2526
Intervals,
2627
)
28+
from sqlmesh.core.snapshot.definition import check_ready_intervals
2729
from sqlmesh.core.snapshot.definition import (
2830
Interval,
2931
expand_range,
@@ -39,7 +41,16 @@
3941
to_timestamp,
4042
validate_date_range,
4143
)
42-
from sqlmesh.utils.errors import AuditError, NodeAuditsErrors, CircuitBreakerError, SQLMeshError
44+
from sqlmesh.utils.errors import (
45+
AuditError,
46+
NodeAuditsErrors,
47+
CircuitBreakerError,
48+
SQLMeshError,
49+
SignalEvalError,
50+
)
51+
52+
if t.TYPE_CHECKING:
53+
from sqlmesh.core.context import ExecutionContext
4354

4455
logger = logging.getLogger(__name__)
4556
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
@@ -304,12 +315,11 @@ def batch_intervals(
304315
default_catalog=self.default_catalog,
305316
)
306317

307-
intervals = snapshot.check_ready_intervals(
318+
intervals = self._check_ready_intervals(
319+
snapshot,
308320
intervals,
309321
context,
310-
console=self.console,
311-
default_catalog=self.default_catalog,
312-
environment_naming_info=environment_naming_info,
322+
environment_naming_info,
313323
)
314324
unready -= set(intervals)
315325

@@ -709,6 +719,76 @@ def _audit_snapshot(
709719

710720
return audit_results
711721

722+
def _check_ready_intervals(
    self,
    snapshot: Snapshot,
    intervals: Intervals,
    context: ExecutionContext,
    environment_naming_info: EnvironmentNamingInfo,
) -> Intervals:
    """Checks if the intervals are ready for evaluation for the given snapshot.

    This implementation also includes the signal progress tracking.
    Note that this will handle gaps in the provided intervals. The returned intervals
    may introduce new gaps.

    Args:
        snapshot: The snapshot to check.
        intervals: The intervals to check.
        context: The context to use.
        environment_naming_info: The environment naming info to use.

    Returns:
        The intervals that are ready for evaluation. The provided intervals are
        returned as-is when the snapshot is not a model or has no named signals.

    Raises:
        SignalEvalError: If checking a signal raises a SQLMeshError.
    """
    signals = snapshot.is_model and snapshot.model.render_signal_calls()

    # render_signal_calls() returns an EvaluatableSignals object rather than a plain
    # mapping, so the object's own truthiness can't be relied on to tell whether any
    # signals exist. Check the mapping explicitly to avoid starting signal progress
    # tracking for snapshots that have nothing to check.
    if not (signals and signals.signals_to_kwargs):
        return intervals

    signals_to_kwargs = signals.signals_to_kwargs
    total_signals = len(signals_to_kwargs)

    self.console.start_signal_progress(
        snapshot,
        self.default_catalog,
        environment_naming_info or EnvironmentNamingInfo(),
    )

    for signal_idx, (signal_name, kwargs) in enumerate(signals_to_kwargs.items()):
        # Capture intervals before signal check for display
        intervals_to_check = merge_intervals(intervals)

        signal_start_ts = time.perf_counter()

        try:
            intervals = check_ready_intervals(
                signals.prepared_python_env[signal_name],
                intervals,
                context,
                python_env=signals.python_env,
                dialect=snapshot.model.dialect,
                path=snapshot.model._path,
                kwargs=kwargs,
            )
        except SQLMeshError as e:
            # Chain explicitly so the original failure is reported as the direct cause.
            raise SignalEvalError(
                f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}"
            ) from e

        duration = time.perf_counter() - signal_start_ts

        self.console.update_signal_progress(
            snapshot=snapshot,
            signal_name=signal_name,
            signal_idx=signal_idx,
            total_signals=total_signals,
            ready_intervals=merge_intervals(intervals),
            check_intervals=intervals_to_check,
            duration=duration,
        )

    self.console.stop_signal_progress()

    return intervals
791+
712792

713793
def merged_missing_intervals(
714794
snapshots: t.Collection[Snapshot],

sqlmesh/core/snapshot/definition.py

Lines changed: 6 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import sys
4-
import time
54
import typing as t
65
from collections import defaultdict
76
from datetime import datetime, timedelta
@@ -42,16 +41,13 @@
4241
)
4342
from sqlmesh.utils.errors import SQLMeshError, SignalEvalError
4443
from sqlmesh.utils.metaprogramming import (
45-
prepare_env,
46-
print_exception,
4744
format_evaluated_code_exception,
4845
Executable,
4946
)
5047
from sqlmesh.utils.hashing import hash_data
5148
from sqlmesh.utils.pydantic import PydanticModel, field_validator
5249

5350
if t.TYPE_CHECKING:
54-
from sqlmesh.core.console import Console
5551
from sqlglot.dialects.dialect import DialectType
5652
from sqlmesh.core.environment import EnvironmentNamingInfo
5753
from sqlmesh.core.context import ExecutionContext
@@ -971,69 +967,31 @@ def check_ready_intervals(
971967
self,
972968
intervals: Intervals,
973969
context: ExecutionContext,
974-
console: t.Optional[Console] = None,
975-
default_catalog: t.Optional[str] = None,
976-
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
977970
) -> Intervals:
978971
"""Returns a list of intervals that are considered ready by the provided signal.
979972
980973
Note that this will handle gaps in the provided intervals. The returned intervals
981974
may introduce new gaps.
982975
"""
983976
signals = self.is_model and self.model.render_signal_calls()
984-
985977
if not signals:
986978
return intervals
987979

988-
python_env = self.model.python_env
989-
env = prepare_env(python_env)
990-
991-
if console:
992-
console.start_signal_progress(
993-
self,
994-
default_catalog,
995-
environment_naming_info or EnvironmentNamingInfo(),
996-
)
997-
998-
for signal_idx, (signal_name, kwargs) in enumerate(signals.items()):
999-
# Capture intervals before signal check for display
1000-
intervals_to_check = merge_intervals(intervals)
1001-
1002-
signal_start_ts = time.perf_counter()
1003-
980+
for signal_name, kwargs in signals.signals_to_kwargs.items():
1004981
try:
1005-
intervals = _check_ready_intervals(
1006-
env[signal_name],
982+
intervals = check_ready_intervals(
983+
signals.prepared_python_env[signal_name],
1007984
intervals,
1008985
context,
1009-
python_env=python_env,
986+
python_env=signals.python_env,
1010987
dialect=self.model.dialect,
1011988
path=self.model._path,
1012989
kwargs=kwargs,
1013990
)
1014991
except SQLMeshError as e:
1015-
print_exception(e, python_env)
1016-
raise SQLMeshError(
992+
raise SignalEvalError(
1017993
f"{e} '{signal_name}' for '{self.model.name}' at {self.model._path}"
1018994
)
1019-
1020-
duration = time.perf_counter() - signal_start_ts
1021-
1022-
if console:
1023-
console.update_signal_progress(
1024-
snapshot=self,
1025-
signal_name=signal_name,
1026-
signal_idx=signal_idx,
1027-
total_signals=len(signals),
1028-
ready_intervals=merge_intervals(intervals),
1029-
check_intervals=intervals_to_check,
1030-
duration=duration,
1031-
)
1032-
1033-
# Stop signal progress tracking
1034-
if console:
1035-
console.stop_signal_progress()
1036-
1037995
return intervals
1038996

1039997
def categorize_as(self, category: SnapshotChangeCategory) -> None:
@@ -2229,7 +2187,7 @@ def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]:
22292187
return contiguous_intervals
22302188

22312189

2232-
def _check_ready_intervals(
2190+
def check_ready_intervals(
22332191
check: t.Callable,
22342192
intervals: Intervals,
22352193
context: ExecutionContext,

tests/core/test_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2118,7 +2118,7 @@ def test_check_intervals(sushi_context, mocker):
21182118
):
21192119
sushi_context.check_intervals(environment="dev", no_signals=False, select_models=[])
21202120

2121-
spy = mocker.spy(sqlmesh.core.snapshot.definition, "_check_ready_intervals")
2121+
spy = mocker.spy(sqlmesh.core.snapshot.definition, "check_ready_intervals")
21222122
intervals = sushi_context.check_intervals(environment=None, no_signals=False, select_models=[])
21232123

21242124
min_intervals = 19

tests/core/test_snapshot.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
apply_auto_restatements,
6060
display_name,
6161
get_next_model_interval_start,
62-
_check_ready_intervals,
62+
check_ready_intervals,
6363
_contiguous_intervals,
6464
)
6565
from sqlmesh.utils import AttributeDict
@@ -2540,7 +2540,7 @@ def test_contiguous_intervals():
25402540
def test_check_ready_intervals(mocker: MockerFixture):
25412541
def assert_always_signal(intervals):
25422542
assert (
2543-
_check_ready_intervals(lambda _: True, intervals, mocker.Mock(), mocker.Mock())
2543+
check_ready_intervals(lambda _: True, intervals, mocker.Mock(), mocker.Mock())
25442544
== intervals
25452545
)
25462546

@@ -2550,17 +2550,15 @@ def assert_always_signal(intervals):
25502550
assert_always_signal([(0, 1), (2, 3)])
25512551

25522552
def assert_never_signal(intervals):
2553-
assert (
2554-
_check_ready_intervals(lambda _: False, intervals, mocker.Mock(), mocker.Mock()) == []
2555-
)
2553+
assert check_ready_intervals(lambda _: False, intervals, mocker.Mock(), mocker.Mock()) == []
25562554

25572555
assert_never_signal([])
25582556
assert_never_signal([(0, 1)])
25592557
assert_never_signal([(0, 1), (1, 2)])
25602558
assert_never_signal([(0, 1), (2, 3)])
25612559

25622560
def assert_empty_signal(intervals):
2563-
assert _check_ready_intervals(lambda _: [], intervals, mocker.Mock(), mocker.Mock()) == []
2561+
assert check_ready_intervals(lambda _: [], intervals, mocker.Mock(), mocker.Mock()) == []
25642562

25652563
assert_empty_signal([])
25662564
assert_empty_signal([(0, 1)])
@@ -2577,7 +2575,7 @@ def assert_check_intervals(
25772575
):
25782576
mock = mocker.Mock()
25792577
mock.side_effect = [to_intervals(r) for r in ready]
2580-
_check_ready_intervals(mock, intervals, mocker.Mock(), mocker.Mock()) == expected
2578+
check_ready_intervals(mock, intervals, mocker.Mock(), mocker.Mock()) == expected
25812579

25822580
assert_check_intervals([], [], [])
25832581
assert_check_intervals([(0, 1)], [[]], [])
@@ -2618,7 +2616,7 @@ def assert_check_intervals(
26182616
)
26192617

26202618
with pytest.raises(SignalEvalError):
2621-
_check_ready_intervals(
2619+
check_ready_intervals(
26222620
lambda _: (_ for _ in ()).throw(MemoryError("Some exception")),
26232621
[(0, 1), (1, 2)],
26242622
mocker.Mock(),

0 commit comments

Comments
 (0)