Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions libbs/api/decompiler_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,26 @@ def __init__(self,

_l.info(f"DecompilerClient connected to {socket_path}")

def _create_and_connect_socket(self) -> socket.socket:
"""Create and connect a socket handling both AF_UNIX and AF_INET fallbacks."""
if hasattr(socket, "AF_UNIX"):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect(self.socket_path)
else:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
with open(self.socket_path, 'r') as f:
port = int(f.read().strip())
sock.connect(('127.0.0.1', port))
return sock

def _connect(self):
"""Establish connection to the server"""
try:
_l.debug(f"Attempting to connect to AF_UNIX socket at {self.socket_path}")
_l.debug(f"Attempting to connect to server at {self.socket_path}")

self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._socket.settimeout(self.timeout)
self._socket.connect(self.socket_path)
self._socket = self._create_and_connect_socket()

_l.debug("Socket connection established")

Expand Down Expand Up @@ -589,9 +601,7 @@ def _start_event_listener(self) -> None:

# Create a separate socket connection for receiving events
try:
self._event_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._event_socket.settimeout(self.timeout)
self._event_socket.connect(self.socket_path)
self._event_socket = self._create_and_connect_socket()

# Send subscription request to server
SocketProtocol.send_message(self._event_socket, {"type": "subscribe_events"})
Expand Down
29 changes: 17 additions & 12 deletions libbs/api/decompiler_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,18 +419,23 @@ def start(self):

_l.info(f"Starting DecompilerServer on {self.socket_path}")

# Create AF_UNIX socket
self._server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

# Set timeout so accept() doesn't block forever
self._server_socket.settimeout(1.0)

# Remove socket file if it exists
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)

# Bind and listen
self._server_socket.bind(self.socket_path)
# Create socket (AF_UNIX if available, else AF_INET)
if hasattr(socket, "AF_UNIX"):
self._server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._server_socket.settimeout(1.0)
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
self._server_socket.bind(self.socket_path)
else:
self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._server_socket.settimeout(1.0)
self._server_socket.bind(('127.0.0.1', 0))
port = self._server_socket.getsockname()[1]
try:
with open(self.socket_path, 'w') as f:
f.write(str(port))
except Exception as e:
_l.error(f"Failed to write port to {self.socket_path}: {e}")
self._server_socket.listen(5)

# Set running flag before starting thread
Expand Down
47 changes: 47 additions & 0 deletions tests/test_client_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import contextlib
import socket
import tempfile
import threading
import time
Expand All @@ -18,6 +20,19 @@

FAUXWARE_PATH = TEST_BINARIES_DIR / "fauxware"


@contextlib.contextmanager
def simulate_no_af_unix():
if hasattr(socket, "AF_UNIX"):
af_unix_val = socket.AF_UNIX
delattr(socket, "AF_UNIX")
try:
yield
finally:
setattr(socket, "AF_UNIX", af_unix_val)
else:
yield

class TestClientServer(unittest.TestCase):
"""Test the new AF_UNIX socket-based DecompilerClient and DecompilerServer"""

Expand Down Expand Up @@ -67,6 +82,38 @@ def test_server_startup_and_client_connection(self):
self.assertIsNotNone(self.client.binary_path)
self.assertIsNotNone(self.client.binary_hash)
self.assertTrue(self.client.decompiler_available)

def test_inet_fallback(self):
"""Test the AF_INET fallback mechanism when AF_UNIX is missing"""
with simulate_no_af_unix():
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as proj_dir:
self.server = DecompilerServer(
force_decompiler=GHIDRA_DECOMPILER,
headless=True,
binary_path=FAUXWARE_PATH,
project_location=proj_dir,
project_name="test_fauxware_inet"
)
self.server.start()

# Give server time to start
time.sleep(2)

# Verify that it binds to a port and writes to the socket path
self.assertTrue(os.path.exists(self.server.socket_path))
with open(self.server.socket_path, 'r') as f:
port_str = f.read().strip()
self.assertTrue(port_str.isdigit())

# Connect client
self.client = DecompilerClient(socket_path=self.server.socket_path)

# Verify connection works
self.assertTrue(self.client.is_connected())
self.assertTrue(self.client.ping())

self.client.shutdown()
self.server.stop()

def test_artifact_collections_match_local(self):
"""Test that client artifact collections behave like local interface"""
Expand Down