diff --git a/src/memos/utils.py b/src/memos/utils.py index f7111f8ad..bbd7d0e7c 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -165,6 +165,7 @@ def wrapper(*args, **kwargs): exc_message = None result = None success_flag = False + exception_to_raise: Exception | None = None try: result = fn(*args, **kwargs) @@ -179,6 +180,12 @@ def wrapper(*args, **kwargs): if fallback is not None and callable(fallback): result = fallback(e, *args, **kwargs) return result + # No fallback: remember the exception so we can re-raise it + # *after* the ``finally`` block has emitted the status log. + # A bare ``raise`` here would also work, but storing the + # reference makes the intent explicit and keeps the finally + # block solely responsible for logging. + exception_to_raise = e finally: elapsed_ms = (time.perf_counter() - start) * 1000.0 @@ -218,6 +225,13 @@ def wrapper(*args, **kwargs): logger.info(msg) + # Re-raise *after* the finally block has run so the status log + # is still emitted for failures without a fallback. Using + # ``raise `` (not bare ``raise``) here because the except + # block has already exited. + if exception_to_raise is not None: + raise exception_to_raise + return wrapper if func is None: @@ -227,6 +241,7 @@ def wrapper(*args, **kwargs): def timed(func=None, *, log=True, log_prefix=""): def decorator(fn): + @functools.wraps(fn) def wrapper(*args, **kwargs): start = time.perf_counter() result = fn(*args, **kwargs) diff --git a/tests/test_utils_timing.py b/tests/test_utils_timing.py index b4d5cb989..2871b81cc 100644 --- a/tests/test_utils_timing.py +++ b/tests/test_utils_timing.py @@ -292,6 +292,17 @@ def parens(): assert bare() == 1 assert parens() == 2 + def test_preserves_function_metadata(self): + """@timed must preserve __name__ / __doc__ via functools.wraps.""" + + @timed + def documented_func(): + """I have a docstring.""" + return 42 + + assert documented_func.__name__ == "documented_func" + assert documented_func.__doc__ == "I have a docstring." + # =========================================================================== # timed_with_status — regression tests @@ -313,17 +324,46 @@ def ok_func(): assert "ok_func" in logs[0] def test_failure_logging_no_fallback(self, caplog): + """Without a fallback the original exception must propagate. + + The [TIMER_WITH_STATUS] log line is emitted from ``finally`` so it + is still produced *before* the exception unwinds out of the + wrapper — caplog captures it either way. + """ + @timed_with_status def fail_func(): raise RuntimeError("bad") - with caplog.at_level(logging.INFO): + with caplog.at_level(logging.INFO), pytest.raises(RuntimeError, match="bad"): fail_func() logs = _collect_timer_with_status_logs(caplog) assert len(logs) == 1 assert "status: FAILED" in logs[0] assert "RuntimeError" in logs[0] + def test_failure_no_fallback_preserves_original_exception(self): + """Re-raise must keep the original exception identity / chain. + + Using a sentinel attribute on the raised exception is the simplest + way to assert the *same* object is propagated (no wrapping in a + new exception type, no ``raise ... from None``). + """ + marker = object() + + @timed_with_status + def fail_func(): + err = RuntimeError("identity") + err.marker = marker # type: ignore[attr-defined] + raise err + + with pytest.raises(RuntimeError) as excinfo: + fail_func() + assert getattr(excinfo.value, "marker", None) is marker + # No exception chaining was introduced by the decorator. + assert excinfo.value.__cause__ is None + assert excinfo.value.__suppress_context__ is False + def test_failure_with_fallback(self, caplog): @timed_with_status(fallback=lambda e, *a, **kw: "fallback_val") def fail_func():