diff --git a/snap7/client.py b/snap7/client.py
index 40bdb707..234e3eb3 100644
--- a/snap7/client.py
+++ b/snap7/client.py
@@ -5,7 +5,9 @@
 """
 
 import logging
+import random
 import struct
+import threading
 import time
 from typing import List, Any, Optional, Tuple, Union, Callable, cast
 from datetime import datetime
@@ -58,12 +60,33 @@ class Client(ClientMixin):
 
     MAX_VARS = 20  # Max variables per multi-read/multi-write request
 
-    def __init__(self, lib_location: Optional[str] = None, **kwargs: Any):
+    def __init__(
+        self,
+        lib_location: Optional[str] = None,
+        *,
+        auto_reconnect: bool = False,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
+        backoff_factor: float = 2.0,
+        max_delay: float = 30.0,
+        heartbeat_interval: float = 0,
+        on_disconnect: Optional[Callable[[], None]] = None,
+        on_reconnect: Optional[Callable[[], None]] = None,
+        **kwargs: Any,
+    ):
         """
         Initialize S7 client.
 
         Args:
             lib_location: Ignored. Kept for backwards compatibility.
+            auto_reconnect: Enable automatic reconnection on connection loss.
+            max_retries: Maximum number of reconnection attempts.
+            retry_delay: Initial delay between reconnection attempts in seconds.
+            backoff_factor: Multiplier for exponential backoff between retries.
+            max_delay: Maximum delay between reconnection attempts in seconds.
+            heartbeat_interval: Interval in seconds for heartbeat probes (0=disabled).
+            on_disconnect: Optional callback invoked when connection is lost.
+            on_reconnect: Optional callback invoked after successful reconnection.
             **kwargs: Ignored. Kept for backwards compatibility.
         """
         self.connection: Optional[ISOTCPConnection] = None
@@ -107,8 +130,37 @@ def __init__(self, lib_location: Optional[str] = None, **kwargs: Any):
         self._last_error = 0
         self._exec_time = 0
 
+        # Auto-reconnection settings
+        self._auto_reconnect = auto_reconnect
+        self._max_retries = max_retries
+        self._retry_delay = retry_delay
+        self._backoff_factor = backoff_factor
+        self._max_delay = max_delay
+        self._on_disconnect = on_disconnect
+        self._on_reconnect = on_reconnect
+
+        # Heartbeat settings
+        self._heartbeat_interval = heartbeat_interval
+        self._heartbeat_thread: Optional[threading.Thread] = None
+        self._heartbeat_stop_event = threading.Event()
+        self._is_alive = False
+
+        # Lock for thread safety during reconnection
+        self._reconnect_lock = threading.Lock()
+
         logger.info("S7Client initialized (pure Python implementation)")
 
+    @property
+    def is_alive(self) -> bool:
+        """Whether the connection is alive according to the last heartbeat probe.
+
+        Returns True if heartbeat is disabled but the client is connected,
+        or if the last heartbeat probe succeeded.
+        """
+        if self._heartbeat_interval <= 0:
+            return self.connected
+        return self._is_alive
+
     def _get_connection(self) -> ISOTCPConnection:
         """Get connection, raising if not connected."""
         if self.connection is None:
@@ -150,6 +202,169 @@ def _send_receive(self, request: bytes, max_stale_retries: int = 3) -> dict[str,
 
         raise S7ProtocolError("Failed to receive valid response")  # Should not reach here
 
+    def _send_receive_with_reconnect(self, request_builder: Callable[[], bytes], max_stale_retries: int = 3) -> dict[str, Any]:
+        """Send a request with automatic reconnection on connection loss.
+
+        If auto_reconnect is disabled, behaves identically to _send_receive.
+        When enabled, catches connection errors, reconnects, rebuilds the request
+        (since the protocol sequence counter may have changed), and retries.
+
+        Args:
+            request_builder: Callable that builds the request bytes. Called again
+                after reconnection to get a fresh request with updated sequence.
+            max_stale_retries: Max times to retry receive on stale packets.
+
+        Returns:
+            Parsed S7 response dict.
+        """
+        try:
+            return self._send_receive(request_builder(), max_stale_retries)
+        except (S7ConnectionError, OSError) as e:
+            if not self._auto_reconnect:
+                raise
+            logger.warning(f"Connection lost during operation: {e}")
+            self._do_reconnect()
+            return self._send_receive(request_builder(), max_stale_retries)
+
+    def _do_reconnect(self) -> None:
+        """Perform reconnection with exponential backoff and jitter.
+
+        The attempt is abandoned as soon as the client is explicitly
+        disconnected, so a background reconnect cannot re-open a connection
+        that the caller has just closed.
+
+        Raises:
+            S7ConnectionError: If all reconnection attempts fail or the
+                client is disconnected while reconnecting.
+        """
+        with self._reconnect_lock:
+            # Check if another thread already reconnected
+            if self.connected and self.connection is not None:
+                try:
+                    if self.connection.check_connection():
+                        return
+                except Exception:
+                    pass  # probe failed; fall through to a full reconnect
+
+            self._is_alive = False
+            if self._on_disconnect is not None:
+                try:
+                    self._on_disconnect()
+                except Exception:
+                    logger.debug("on_disconnect callback raised an exception", exc_info=True)
+
+            delay = self._retry_delay
+            last_error: Optional[Exception] = None
+
+            for attempt in range(1, self._max_retries + 1):
+                # disconnect() sets the stop event; never resurrect a closed client
+                if self._heartbeat_stop_event.is_set():
+                    raise S7ConnectionError("Reconnection aborted: client was disconnected")
+
+                logger.info(f"Reconnection attempt {attempt}/{self._max_retries}")
+
+                # Clean up old connection
+                try:
+                    if self.connection is not None:
+                        self.connection.disconnect()
+                        self.connection = None
+                except Exception:
+                    pass  # best-effort cleanup of a connection that is already broken
+                self.connected = False
+
+                try:
+                    # Re-establish connection using stored parameters
+                    self.connection = ISOTCPConnection(
+                        host=self.host, port=self.port, local_tsap=self.local_tsap, remote_tsap=self.remote_tsap
+                    )
+                    self.connection.connect()
+
+                    # Re-create protocol to reset sequence counters
+                    self.protocol = S7Protocol()
+                    self._setup_communication()
+
+                    self.connected = True
+                    self._is_alive = True
+                    logger.info(f"Reconnected to {self.host}:{self.port}")
+
+                    if self._on_reconnect is not None:
+                        try:
+                            self._on_reconnect()
+                        except Exception:
+                            logger.debug("on_reconnect callback raised an exception", exc_info=True)
+                    return
+                except Exception as e:
+                    last_error = e
+                    logger.warning(f"Reconnection attempt {attempt} failed: {e}")
+
+                    if attempt < self._max_retries:
+                        # Exponential backoff with jitter
+                        jitter = random.uniform(0, delay * 0.1)
+                        sleep_time = min(delay + jitter, self._max_delay)
+                        logger.debug(f"Waiting {sleep_time:.2f}s before next attempt")
+                        # Interruptible sleep: returns True immediately once disconnect() sets the event
+                        if self._heartbeat_stop_event.wait(timeout=sleep_time):
+                            raise S7ConnectionError("Reconnection aborted: client was disconnected") from e
+                        delay = min(delay * self._backoff_factor, self._max_delay)
+
+            raise S7ConnectionError(f"Reconnection failed after {self._max_retries} attempts: {last_error}") from last_error
+
+    def _start_heartbeat(self) -> None:
+        """Start the heartbeat background thread."""
+        # Always re-arm the stop event, even with heartbeat disabled: it doubles
+        # as the "explicitly disconnected" flag that _do_reconnect() checks.
+        self._heartbeat_stop_event.clear()
+        if self._heartbeat_interval <= 0:
+            return
+
+        self._is_alive = True
+        self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True, name="s7-heartbeat")
+        self._heartbeat_thread.start()
+        logger.debug(f"Heartbeat started with interval {self._heartbeat_interval}s")
+
+    def _stop_heartbeat(self) -> None:
+        """Stop the heartbeat background thread."""
+        self._heartbeat_stop_event.set()
+        thread = self._heartbeat_thread
+        if thread is not None:
+            # Thread.join() on the current thread raises RuntimeError; that happens
+            # when disconnect() is called from a callback run by the heartbeat thread.
+            if thread is not threading.current_thread():
+                thread.join(timeout=self._heartbeat_interval + 2)
+            self._heartbeat_thread = None
+            logger.debug("Heartbeat stopped")
+
+    def _heartbeat_loop(self) -> None:
+        """Background loop that periodically probes the PLC connection."""
+        while not self._heartbeat_stop_event.is_set():
+            if self._heartbeat_stop_event.wait(timeout=self._heartbeat_interval):
+                break  # Stop event was set
+
+            if not self.connected:
+                self._is_alive = False
+                if self._auto_reconnect:
+                    try:
+                        self._do_reconnect()
+                    except S7ConnectionError:
+                        logger.warning("Heartbeat reconnection failed")
+                continue
+
+            try:
+                with self._reconnect_lock:
+                    if self.connected and self.connection is not None:
+                        self.get_cpu_state()
+                        self._is_alive = True
+            except Exception as e:
+                logger.warning(f"Heartbeat probe failed: {e}")
+                self._is_alive = False
+                self.connected = False
+
+                if self._auto_reconnect:
+                    try:
+                        self._do_reconnect()
+                    except S7ConnectionError:
+                        logger.warning("Heartbeat reconnection failed")
+
     def connect(self, address: str, rack: int, slot: int, tcp_port: int = 102) -> "Client":
         """
         Connect to S7 PLC.
@@ -187,9 +402,13 @@ def connect(self, address: str, rack: int, slot: int, tcp_port: int = 102) -> "C
             self._setup_communication()
 
             self.connected = True
+            self._is_alive = True
             self._exec_time = int((time.time() - start_time) * 1000)
             logger.info(f"Connected to {address}:{tcp_port} rack {rack} slot {slot}")
 
+            # Start heartbeat if configured
+            self._start_heartbeat()
+
         except Exception as e:
             self.disconnect()
             if isinstance(e, S7Error):
@@ -205,11 +424,15 @@ def disconnect(self) -> int:
         Returns:
             0 on success
         """
+        # Stop heartbeat first
+        self._stop_heartbeat()
+
         if self.connection:
             self.connection.disconnect()
             self.connection = None
 
         self.connected = False
+        self._is_alive = False
         logger.info(f"Disconnected from {self.host}:{self.port}")
         return 0
 
@@ -359,11 +582,13 @@ def read_area(self, area: Area, db_number: int, start: int, size: int, word_len:
 
         max_chunk = self._max_read_size()
         if size <= max_chunk:
-            # Single request
-            request = self.protocol.build_read_request(
-                area=s7_area, db_number=db_number, start=start, word_len=s7_word_len, count=size
-            )
-            response = self._send_receive(request)
+            # Single request - use reconnect-aware send/receive
+            def build_request() -> bytes:
+                return self.protocol.build_read_request(
+                    area=s7_area, db_number=db_number, start=start, word_len=s7_word_len, count=size
+                )
+
+            response = self._send_receive_with_reconnect(build_request)
             values = self.protocol.extract_read_data(response, s7_word_len, size)
             self._exec_time = int((time.time() - start_time) * 1000)
             return bytearray(values)
@@ -374,10 +599,14 @@ def read_area(self, area: Area, db_number: int, start: int, size: int, word_len:
         remaining = size
         while remaining > 0:
             chunk_size = min(remaining, max_chunk)
-            request = self.protocol.build_read_request(
-                area=s7_area, db_number=db_number, start=start + offset, word_len=s7_word_len, count=chunk_size
-            )
-            response = self._send_receive(request)
+            chunk_offset = offset
+
+            def build_chunk_request(o: int = chunk_offset, cs: int = chunk_size) -> bytes:
+                return self.protocol.build_read_request(
+                    area=s7_area, db_number=db_number, start=start + o, word_len=s7_word_len, count=cs
+                )
+
+            response = self._send_receive_with_reconnect(build_chunk_request)
             values = self.protocol.extract_read_data(response, s7_word_len, chunk_size)
             result.extend(values)
             offset += chunk_size
@@ -421,10 +650,12 @@ def write_area(self, area: Area, db_number: int, start: int, data: bytearray, wo
         max_chunk = self._max_write_size()
         if len(data) <= max_chunk:
             # Single request
-            request = self.protocol.build_write_request(
-                area=s7_area, db_number=db_number, start=start, word_len=s7_word_len, data=bytes(data)
-            )
-            response = self._send_receive(request)
+            def build_request() -> bytes:
+                return self.protocol.build_write_request(
+                    area=s7_area, db_number=db_number, start=start, word_len=s7_word_len, data=bytes(data)
+                )
+
+            response = self._send_receive_with_reconnect(build_request)
             self.protocol.check_write_response(response)
             self._exec_time = int((time.time() - start_time) * 1000)
             return 0
@@ -435,10 +666,14 @@ def write_area(self, area: Area, db_number: int, start: int, data: bytearray, wo
         while remaining > 0:
             chunk_size = min(remaining, max_chunk)
             chunk_data = data[offset : offset + chunk_size]
-            request = self.protocol.build_write_request(
-                area=s7_area, db_number=db_number, start=start + offset, word_len=s7_word_len, data=bytes(chunk_data)
-            )
-            response = self._send_receive(request)
+            chunk_offset = offset
+
+            def build_chunk_request(o: int = chunk_offset, cd: bytes = bytes(chunk_data)) -> bytes:
+                return self.protocol.build_write_request(
+                    area=s7_area, db_number=db_number, start=start + o, word_len=s7_word_len, data=cd
+                )
+
+            response = self._send_receive_with_reconnect(build_chunk_request)
             self.protocol.check_write_response(response)
             offset += chunk_size
             remaining -= chunk_size
diff --git a/tests/test_reconnect.py b/tests/test_reconnect.py
new file mode 100644
index 00000000..f36293e0
--- /dev/null
+++ b/tests/test_reconnect.py
@@ -0,0 +1,447 @@
+"""Tests for connection heartbeat and automatic reconnection features."""
+
+import logging
+import threading
+import time
+import unittest
+
+import pytest
+
+from snap7.client import Client
+from snap7.error import S7ConnectionError
+from snap7.server import Server
+from snap7.type import SrvArea
+
+logging.basicConfig(level=logging.WARNING)
+
+ip = "127.0.0.1"
+tcpport = 1103  # Use different port to avoid conflict with test_client.py
+db_number = 1
+rack = 1
+slot = 1
+
+
+@pytest.mark.client
+class TestAutoReconnectDefaults(unittest.TestCase):
+    """Test that default behavior is unchanged when features are disabled."""
+
+    server: Server
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.server = Server()
+        cls.server.register_area(SrvArea.DB, 0, bytearray(100))
+        cls.server.register_area(SrvArea.DB, 1, bytearray(100))
+        cls.server.register_area(SrvArea.MK, 0, bytearray(100))
+        cls.server.start(tcp_port=tcpport)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        if cls.server:
+            cls.server.stop()
+            cls.server.destroy()
+
+    def test_default_auto_reconnect_disabled(self) -> None:
+        """Default client has auto_reconnect=False."""
+        client = Client()
+        assert client._auto_reconnect is False
+        assert client._heartbeat_interval == 0
+
+    def test_default_client_works_normally(self) -> None:
+        """Default client connects and operates without new features interfering."""
+        client = Client()
+        client.connect(ip, rack, slot, tcpport)
+        try:
+            data = client.db_read(db_number, 0, 4)
+            assert len(data) == 4
+        finally:
+            client.disconnect()
+
+    def test_is_alive_without_heartbeat(self) -> None:
+        """is_alive reflects connection state when heartbeat is disabled."""
+        client = Client()
+        assert client.is_alive is False
+
+        client.connect(ip, rack, slot, tcpport)
+        try:
+            assert client.is_alive is True
+        finally:
+            client.disconnect()
+
+        assert client.is_alive is False
+
+    def test_auto_reconnect_params_stored(self) -> None:
+        """Verify that auto-reconnect parameters are stored on the client."""
+        client = Client(
+            auto_reconnect=True,
+            max_retries=5,
+            retry_delay=0.5,
+            backoff_factor=3.0,
+            max_delay=60.0,
+        )
+        assert client._auto_reconnect is True
+        assert client._max_retries == 5
+        assert client._retry_delay == 0.5
+        assert client._backoff_factor == 3.0
+        assert client._max_delay == 60.0
+
+
+@pytest.mark.client
+class TestAutoReconnect(unittest.TestCase):
+    """Test automatic reconnection on connection loss."""
+
+    server: Server
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.server = Server()
+        cls.server.register_area(SrvArea.DB, 0, bytearray(100))
+        cls.server.register_area(SrvArea.DB, 1, bytearray(100))
+        cls.server.register_area(SrvArea.MK, 0, bytearray(100))
+        cls.server.start(tcp_port=tcpport)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        if cls.server:
+            cls.server.stop()
+            cls.server.destroy()
+
+    def test_reconnect_on_read_failure(self) -> None:
+        """Client reconnects transparently when a db_read fails due to connection loss."""
+        client = Client(auto_reconnect=True, max_retries=3, retry_delay=0.1)
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            # Verify initial read works
+            data = client.db_read(db_number, 0, 4)
+            assert len(data) == 4
+
+            # Simulate connection loss by closing the socket
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+            client.connected = False
+
+            # The next read should trigger reconnection and succeed
+            data = client.db_read(db_number, 0, 4)
+            assert len(data) == 4
+            assert client.connected is True
+        finally:
+            client.disconnect()
+
+    def test_reconnect_on_write_failure(self) -> None:
+        """Client reconnects transparently when a db_write fails due to connection loss."""
+        client = Client(auto_reconnect=True, max_retries=3, retry_delay=0.1)
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            # Verify initial write works
+            client.db_write(db_number, 0, bytearray([1, 2, 3, 4]))
+
+            # Simulate connection loss
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+            client.connected = False
+
+            # The next write should trigger reconnection and succeed
+            client.db_write(db_number, 0, bytearray([5, 6, 7, 8]))
+            assert client.connected is True
+
+            # Verify the data was written after reconnection
+            data = client.db_read(db_number, 0, 4)
+            assert data == bytearray([5, 6, 7, 8])
+        finally:
+            client.disconnect()
+
+    def test_no_reconnect_when_disabled(self) -> None:
+        """Without auto_reconnect, connection errors propagate immediately."""
+        client = Client(auto_reconnect=False)
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            # Simulate connection loss
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+            client.connected = False
+
+            with pytest.raises(S7ConnectionError):
+                client.db_read(db_number, 0, 4)
+        finally:
+            client.disconnect()
+
+    def test_reconnect_callbacks(self) -> None:
+        """on_disconnect and on_reconnect callbacks are invoked."""
+        disconnect_called = threading.Event()
+        reconnect_called = threading.Event()
+
+        def on_disconnect() -> None:
+            disconnect_called.set()
+
+        def on_reconnect() -> None:
+            reconnect_called.set()
+
+        client = Client(
+            auto_reconnect=True,
+            max_retries=3,
+            retry_delay=0.1,
+            on_disconnect=on_disconnect,
+            on_reconnect=on_reconnect,
+        )
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            # Simulate connection loss
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+            client.connected = False
+
+            # Trigger reconnection via a read
+            data = client.db_read(db_number, 0, 4)
+            assert len(data) == 4
+
+            assert disconnect_called.is_set(), "on_disconnect was not called"
+            assert reconnect_called.is_set(), "on_reconnect was not called"
+        finally:
+            client.disconnect()
+
+    def test_reconnect_max_retries_exhausted(self) -> None:
+        """S7ConnectionError is raised after max_retries are exhausted."""
+        client = Client(auto_reconnect=True, max_retries=2, retry_delay=0.05)
+        client.connect(ip, rack, slot, tcpport)
+
+        # Stop the server so reconnection will fail
+        self.__class__.server.stop()
+
+        try:
+            # Simulate connection loss
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+            client.connected = False
+
+            with pytest.raises(S7ConnectionError, match="Reconnection failed"):
+                client.db_read(db_number, 0, 4)
+        finally:
+            client.disconnect()
+            # Restart server for other tests
+            self.__class__.server = Server()
+            self.__class__.server.register_area(SrvArea.DB, 0, bytearray(100))
+            self.__class__.server.register_area(SrvArea.DB, 1, bytearray(100))
+            self.__class__.server.register_area(SrvArea.MK, 0, bytearray(100))
+            self.__class__.server.start(tcp_port=tcpport)
+
+    def test_connection_params_preserved_after_reconnect(self) -> None:
+        """Host, port, rack, slot are preserved and reused during reconnection."""
+        client = Client(auto_reconnect=True, max_retries=3, retry_delay=0.1)
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            original_host = client.host
+            original_port = client.port
+            original_rack = client.rack
+            original_slot = client.slot
+
+            # Simulate connection loss and trigger reconnect
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+            client.connected = False
+            client.db_read(db_number, 0, 4)
+
+            # Verify connection params are preserved
+            assert client.host == original_host
+            assert client.port == original_port
+            assert client.rack == original_rack
+            assert client.slot == original_slot
+        finally:
+            client.disconnect()
+
+
+@pytest.mark.client
+class TestHeartbeat(unittest.TestCase):
+    """Test heartbeat/watchdog functionality."""
+
+    server: Server
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.server = Server()
+        cls.server.register_area(SrvArea.DB, 0, bytearray(100))
+        cls.server.register_area(SrvArea.DB, 1, bytearray(100))
+        cls.server.register_area(SrvArea.MK, 0, bytearray(100))
+        cls.server.start(tcp_port=tcpport)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        if cls.server:
+            cls.server.stop()
+            cls.server.destroy()
+
+    def test_heartbeat_disabled_by_default(self) -> None:
+        """Heartbeat thread does not start when interval=0."""
+        client = Client()
+        client.connect(ip, rack, slot, tcpport)
+        try:
+            assert client._heartbeat_thread is None
+        finally:
+            client.disconnect()
+
+    def test_heartbeat_starts_and_stops(self) -> None:
+        """Heartbeat thread starts on connect and stops on disconnect."""
+        client = Client(heartbeat_interval=0.5)
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            assert client._heartbeat_thread is not None
+            assert client._heartbeat_thread.is_alive()
+            assert client._heartbeat_thread.daemon is True
+            assert client.is_alive is True
+        finally:
+            client.disconnect()
+
+        # After disconnect, thread should stop
+        assert client._heartbeat_thread is None
+        assert client.is_alive is False
+
+    def test_heartbeat_detects_alive_connection(self) -> None:
+        """Heartbeat correctly reports connection as alive."""
+        client = Client(heartbeat_interval=0.3)
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            # Wait for at least one heartbeat cycle
+            time.sleep(0.5)
+            assert client.is_alive is True
+        finally:
+            client.disconnect()
+
+    def test_heartbeat_detects_dead_connection(self) -> None:
+        """Heartbeat sets is_alive=False when connection is lost."""
+        client = Client(heartbeat_interval=0.3, auto_reconnect=False)
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            assert client.is_alive is True
+
+            # Kill the connection without going through disconnect()
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+
+            # Wait for heartbeat to detect the failure
+            time.sleep(1.0)
+            assert client.is_alive is False
+        finally:
+            client.disconnect()
+
+    def test_heartbeat_triggers_reconnect(self) -> None:
+        """When heartbeat fails and auto_reconnect is enabled, it reconnects."""
+        reconnect_called = threading.Event()
+
+        def on_reconnect() -> None:
+            reconnect_called.set()
+
+        client = Client(
+            heartbeat_interval=0.3,
+            auto_reconnect=True,
+            max_retries=3,
+            retry_delay=0.1,
+            on_reconnect=on_reconnect,
+        )
+        client.connect(ip, rack, slot, tcpport)
+
+        try:
+            # Kill the connection
+            if client.connection and client.connection.socket:
+                client.connection.socket.close()
+
+            # Wait for heartbeat to detect and trigger reconnect
+            reconnect_called.wait(timeout=3.0)
+            assert reconnect_called.is_set(), "Heartbeat did not trigger reconnection"
+
+            # Give some time for the reconnect to complete fully
+            time.sleep(0.5)
+            assert client.is_alive is True
+            assert client.connected is True
+
+            # Verify connection works after heartbeat-triggered reconnect
+            data = client.db_read(db_number, 0, 4)
+            assert len(data) == 4
+        finally:
+            client.disconnect()
+
+    def test_context_manager_stops_heartbeat(self) -> None:
+        """Heartbeat is properly stopped when using context manager."""
+        with Client(heartbeat_interval=0.3) as client:
+            client.connect(ip, rack, slot, tcpport)
+            assert client._heartbeat_thread is not None
+
+        # After context exit, heartbeat should be stopped
+        assert client._heartbeat_thread is None
+
+
+@pytest.mark.client
+class TestBackwardCompatibility(unittest.TestCase):
+    """Ensure the new features don't break backward compatibility."""
+
+    server: Server
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.server = Server()
+        cls.server.register_area(SrvArea.DB, 0, bytearray(100))
+        cls.server.register_area(SrvArea.DB, 1, bytearray(100))
+        cls.server.register_area(SrvArea.PA, 0, bytearray(100))
+        cls.server.register_area(SrvArea.PE, 0, bytearray(100))
+        cls.server.register_area(SrvArea.MK, 0, bytearray(100))
+        cls.server.start(tcp_port=tcpport)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        if cls.server:
+            cls.server.stop()
+            cls.server.destroy()
+
+    def test_old_init_signature_still_works(self) -> None:
+        """Client() and Client(lib_location=None) still work."""
+        c1 = Client()
+        assert c1._auto_reconnect is False
+
+        c2 = Client(lib_location=None)
+        assert c2._auto_reconnect is False
+
+        c3 = Client(lib_location="/some/path")
+        assert c3._auto_reconnect is False
+
+    def test_read_write_without_reconnect(self) -> None:
+        """Standard read/write operations work without reconnect enabled."""
+        client = Client()
+        client.connect(ip, rack, slot, tcpport)
+        try:
+            # Write
+            client.db_write(db_number, 0, bytearray([10, 20, 30, 40]))
+            # Read
+            data = client.db_read(db_number, 0, 4)
+            assert data == bytearray([10, 20, 30, 40])
+        finally:
+            client.disconnect()
+
+    def test_get_connected(self) -> None:
+        """get_connected still works correctly."""
+        client = Client()
+        assert client.get_connected() is False
+
+        client.connect(ip, rack, slot, tcpport)
+        try:
+            assert client.get_connected() is True
+        finally:
+            client.disconnect()
+
+        assert client.get_connected() is False
+
+    def test_mb_read_write(self) -> None:
+        """Marker area read/write works with reconnect-aware code path."""
+        client = Client(auto_reconnect=True, max_retries=1, retry_delay=0.1)
+        client.connect(ip, rack, slot, tcpport)
+        try:
+            client.mb_write(0, 4, bytearray([0xAA, 0xBB, 0xCC, 0xDD]))
+            data = client.mb_read(0, 4)
+            assert data == bytearray([0xAA, 0xBB, 0xCC, 0xDD])
+        finally:
+            client.disconnect()