diff --git a/python/cuopt_server/cuopt_server/tests/utils/utils.py b/python/cuopt_server/cuopt_server/tests/utils/utils.py index 022f5aac8e..eaefdb2a25 100644 --- a/python/cuopt_server/cuopt_server/tests/utils/utils.py +++ b/python/cuopt_server/cuopt_server/tests/utils/utils.py @@ -267,6 +267,8 @@ def cuopt_service_sync( # Fixture and client to allow full cuopt service # to run as a separate process for multiple tests cuoptmain = None +# True after server has passed initial healthcheck; used to fail-fast on later crashes +_server_was_up = False # Use module name instead of file path to ensure we use the installed package server_script = "-m" server_module = "cuopt_server.cuopt_service" @@ -296,7 +298,21 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) +def _exit_if_server_gone(exc): + """Exit pytest immediately when server was up but is now unreachable (e.g. crashed).""" + global _server_was_up + if _server_was_up: + pytest.exit( + "cuOpt server stopped responding (connection error). " + "Skipping remaining server tests to reduce log noise. " + "Check server startup and dependencies (e.g. cudf/GPU).", + returncode=1, + ) + raise exc + + def spinup_wait(): + global _server_was_up client = RequestClient() count = 0 result = None @@ -309,7 +325,13 @@ def spinup_wait(): break except Exception: time.sleep(1) - assert result.status_code == 200 + if result is None or result.status_code != 200: + pytest.exit( + "cuOpt server failed to pass healthcheck after 30s. " + "Skipping all server tests. Check server logs for startup errors (e.g. cudf/GPU).", + returncode=1, + ) + _server_was_up = True @pytest.fixture(scope="session") @@ -347,9 +369,12 @@ def poll_for_completion(self, reqId, delete=True): cnt = 0 headers = {"Accept": "application/json"} while True: - res = requests.get( - self.url + f"/cuopt/solution/{reqId}", headers=headers - ) + try: + res = requests.get( + self.url + f"/cuopt/solution/{reqId}", headers=headers + ) + except (requests.ConnectionError, requests.ConnectTimeout) as e: + _exit_if_server_gone(e) if "response" in res.json() or "error" in res.json(): break time.sleep(1) @@ -361,6 +386,8 @@ def poll_for_completion(self, reqId, delete=True): requests.delete( self.url + f"/cuopt/solution/{reqId}", headers=headers ) + except (requests.ConnectionError, requests.ConnectTimeout) as e: + _exit_if_server_gone(e) except Exception: pass return res @@ -375,13 +402,16 @@ def post( block=True, delete=True, ): - res = requests.post( - self.url + endpoint, - params=params, - headers=headers, - json=json, - data=data, - ) + try: + res = requests.post( + self.url + endpoint, + params=params, + headers=headers, + json=json, + data=data, + ) + except (requests.ConnectionError, requests.ConnectTimeout) as e: + _exit_if_server_gone(e) # cuopt/cuot is already blocking, don't ever poll if endpoint == "/cuopt/cuopt": @@ -397,9 +427,20 @@ def post( return self.poll_for_completion(res.json()["reqId"], delete) def get(self, endpoint, headers=None, json=None): - return requests.get(self.url + endpoint, headers=headers, json=json) + try: + return requests.get( + self.url + endpoint, headers=headers, json=json + ) + except (requests.ConnectionError, requests.ConnectTimeout) as e: + _exit_if_server_gone(e) def delete(self, endpoint, headers=None, json=None, params=None): - return requests.delete( - self.url + endpoint, params=params, headers=headers, json=json - ) + try: + return requests.delete( + self.url + endpoint, + params=params, + headers=headers, + json=json, + ) + except (requests.ConnectionError, requests.ConnectTimeout) as e: + _exit_if_server_gone(e)