diff --git a/libbs/api/decompiler_client.py b/libbs/api/decompiler_client.py index 1fb2ed5..60f17c9 100644 --- a/libbs/api/decompiler_client.py +++ b/libbs/api/decompiler_client.py @@ -368,14 +368,26 @@ def __init__(self, _l.info(f"DecompilerClient connected to {socket_path}") + def _create_and_connect_socket(self) -> socket.socket: + """Create and connect a socket handling both AF_UNIX and AF_INET fallbacks.""" + if hasattr(socket, "AF_UNIX"): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + sock.connect(self.socket_path) + else: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + with open(self.socket_path, 'r') as f: + port = int(f.read().strip()) + sock.connect(('127.0.0.1', port)) + return sock + def _connect(self): """Establish connection to the server""" try: - _l.debug(f"Attempting to connect to AF_UNIX socket at {self.socket_path}") + _l.debug(f"Attempting to connect to server at {self.socket_path}") - self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self._socket.settimeout(self.timeout) - self._socket.connect(self.socket_path) + self._socket = self._create_and_connect_socket() _l.debug("Socket connection established") @@ -589,9 +601,7 @@ def _start_event_listener(self) -> None: # Create a separate socket connection for receiving events try: - self._event_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self._event_socket.settimeout(self.timeout) - self._event_socket.connect(self.socket_path) + self._event_socket = self._create_and_connect_socket() # Send subscription request to server SocketProtocol.send_message(self._event_socket, {"type": "subscribe_events"}) diff --git a/libbs/api/decompiler_server.py b/libbs/api/decompiler_server.py index 9d1482f..e1c1d0c 100644 --- a/libbs/api/decompiler_server.py +++ b/libbs/api/decompiler_server.py @@ -419,18 +419,23 @@ def start(self): _l.info(f"Starting DecompilerServer on {self.socket_path}") - # Create AF_UNIX socket - self._server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - - # Set timeout so accept() doesn't block forever - self._server_socket.settimeout(1.0) - - # Remove socket file if it exists - if os.path.exists(self.socket_path): - os.unlink(self.socket_path) - - # Bind and listen - self._server_socket.bind(self.socket_path) + # Create socket (AF_UNIX if available, else AF_INET) + if hasattr(socket, "AF_UNIX"): + self._server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._server_socket.settimeout(1.0) + if os.path.exists(self.socket_path): + os.unlink(self.socket_path) + self._server_socket.bind(self.socket_path) + else: + self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_socket.settimeout(1.0) + self._server_socket.bind(('127.0.0.1', 0)) + port = self._server_socket.getsockname()[1] + try: + with open(self.socket_path, 'w') as f: + f.write(str(port)) + except Exception as e: + _l.error(f"Failed to write port to {self.socket_path}: {e}") self._server_socket.listen(5) # Set running flag before starting thread diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 97f4d13..49b82a3 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1,4 +1,6 @@ import os +import contextlib +import socket import tempfile import threading import time @@ -18,6 +20,19 @@ FAUXWARE_PATH = TEST_BINARIES_DIR / "fauxware" + +@contextlib.contextmanager +def simulate_no_af_unix(): + if hasattr(socket, "AF_UNIX"): + af_unix_val = socket.AF_UNIX + delattr(socket, "AF_UNIX") + try: + yield + finally: + setattr(socket, "AF_UNIX", af_unix_val) + else: + yield + class TestClientServer(unittest.TestCase): """Test the new AF_UNIX socket-based DecompilerClient and DecompilerServer""" @@ -67,6 +82,38 @@ def test_server_startup_and_client_connection(self): self.assertIsNotNone(self.client.binary_path) self.assertIsNotNone(self.client.binary_hash) self.assertTrue(self.client.decompiler_available) + + def test_inet_fallback(self): + """Test the AF_INET fallback mechanism when AF_UNIX is missing""" + with simulate_no_af_unix(): + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as proj_dir: + self.server = DecompilerServer( + force_decompiler=GHIDRA_DECOMPILER, + headless=True, + binary_path=FAUXWARE_PATH, + project_location=proj_dir, + project_name="test_fauxware_inet" + ) + self.server.start() + + # Give server time to start + time.sleep(2) + + # Verify that it binds to a port and writes to the socket path + self.assertTrue(os.path.exists(self.server.socket_path)) + with open(self.server.socket_path, 'r') as f: + port_str = f.read().strip() + self.assertTrue(port_str.isdigit()) + + # Connect client + self.client = DecompilerClient(socket_path=self.server.socket_path) + + # Verify connection works + self.assertTrue(self.client.is_connected()) + self.assertTrue(self.client.ping()) + + self.client.shutdown() + self.server.stop() def test_artifact_collections_match_local(self): """Test that client artifact collections behave like local interface"""