diff --git a/rocketpy/simulation/events.py b/rocketpy/simulation/events.py index 13b775725..24e650dd0 100644 --- a/rocketpy/simulation/events.py +++ b/rocketpy/simulation/events.py @@ -1,28 +1,34 @@ +import inspect +from typing import get_type_hints + + class Event: + """A class representing an event in the simulation.""" + # TODO: should "sensors" arg of the trigger function be a dictionary instead - # of a list? It would be more intuitive to access the sensors by name + # of a list? It would be more intuitive to access the sensors by name def __init__(self, trigger, action, name, event_context=None): """Initializes an Event object. Parameters ---------- trigger : function - A function that must return a boolean value. The event will be - triggered when this function returns True. The function should be + A function that must return a boolean value. The event will be + triggered when this function returns True. The function should be defined with the following signature: trigger(**kwargs), where kwargs is a dictionary containing the keys: - `"time"` (float): The current simulation time in seconds. - `"state"` (list): The state vector of the simulation, structured as `[x, y, z, vx, vy, vz, e0, e1, e2, e3, wx, wy, wz]`. - - `"state_dot"` (list): The time derivative of the state vector, + - `"state_dot"` (list): The time derivative of the state vector, structured as `[vx, vy, vz, ax, ay, az, e0_dot, e1_dot, e2_dot, e3_dot, wx_dot, wy_dot, wz_dot]`. - `"sampling_rate"` (float or None): The sampling rate of the event, in seconds. If None, the event will be checked for triggering at every time step of the simulation. If a float - value is provided, the event will only be checked for + value is provided, the event will only be checked for triggering at that specific time interval. - - `"sensors"` (list): A list of sensors that are attached to the + - `"sensors"` (list): A list of sensors that are attached to the rocket. The most recent measurements of the sensors are provided with the ``sensor.measurement`` attribute. The sensors are listed in the same order as they are added to the rocket. @@ -34,20 +40,20 @@ def __init__(self, trigger, action, name, event_context=None): - `"phase_index"` (int): The index of the current flight phase. - `"node_index"` (int): The index of the current node in the current flight phase. - - Any additional custom key-value pairs provided via the + - Any additional custom key-value pairs provided via the `event_context` parameter (see below). - + action : function A function that will be executed when the event is triggered. The - function should be defined with the following signature: + function should be defined with the following signature: action(**kwargs), where kwargs is a dictionary containing the same keys as the trigger function. The action function can also modify the state of the simulation by returning a dictionary with the keys: - `"state"` (list): A new state vector to replace the current state vector. The structure of the state vector is the same as the one provided in the trigger function. - - `"disable_event"` (bool): If True, the event will not be - checked for triggering again after being triggered, making + - `"disable_event"` (bool): If True, the event will not be + checked for triggering again after being triggered, making it a one-time event. Defaults to True. - `"new_events"` (list): A list of new Event objects to be added to the simulation when the event is triggered. This can be @@ -55,50 +61,48 @@ def __init__(self, trigger, action, name, event_context=None): triggered, such as a parachute deployment event that spawns a new event to check for the parachute deployment after a certain time delay. - - `"remove_events"` (list): A list of Event objects to be - removed from the simulation when the event is triggered. This - can be used to create events that remove other events when - they are triggered, such as a parachute deployment event that + - `"remove_events"` (list): A list of Event objects to be + removed from the simulation when the event is triggered. This + can be used to create events that remove other events when + they are triggered, such as a parachute deployment event that removes the apogee event when it is triggered. - - Any other key-value pairs defined in `event_context` will - also be included. These allow you to maintain custom state or - counters across multiple trigger and action calls. Use cases + - Any other key-value pairs defined in `event_context` will + also be included. These allow you to maintain custom state or + counters across multiple trigger and action calls. Use cases include: tracking the number of times an event has been triggered - (e.g., `{"trigger_count": 0}`), recording the time of the last - trigger (e.g., `{"last_trigger_time": None}`), or any other + (e.g., `{"trigger_count": 0}`), recording the time of the last + trigger (e.g., `{"last_trigger_time": None}`), or any other custom data your trigger/action functions need to share state. - - Example: If you initialize the event with - `event_context={"trigger_count": 0}`, your trigger and action - functions will receive `trigger_count=0` in their kwargs dict. - You can then update this value in the action function by - including it in the returned dictionary (e.g., - `{"trigger_count": 1}`), and it will be passed to subsequent + + Example: If you initialize the event with + `event_context={"trigger_count": 0}`, your trigger and action + functions will receive `trigger_count=0` in their kwargs dict. + You can then update this value in the action function by + including it in the returned dictionary (e.g., + `{"trigger_count": 1}`), and it will be passed to subsequent trigger/action calls. name : str A name for the event, used for identification purposes. event_context : dict, optional - A dictionary of custom key-value pairs that will be passed to the - trigger and action functions. This allows you to initialize and - maintain custom state that persists across multiple trigger/action - calls. For example, `event_context={"trigger_count": 0, - "last_trigger_time": None}` can be used to track event state. - When the action function returns a dictionary with updated values - (e.g., `{"trigger_count": 1}`), those values persist and are - passed to subsequent calls. Defaults to an empty dictionary if not + A dictionary of custom key-value pairs that will be passed to the + trigger and action functions. This allows you to initialize and + maintain custom state that persists across multiple trigger/action + calls. For example, `event_context={"trigger_count": 0, + "last_trigger_time": None}` can be used to track event state. + When the action function returns a dictionary with updated values + (e.g., `{"trigger_count": 1}`), those values persist and are + passed to subsequent calls. Defaults to an empty dictionary if not provided. """ + self.name = name self.trigger = self.__verify_trigger(trigger) self.action = self.__verify_action(action) - self.name = name self.event_context = event_context if event_context is not None else {} - # TODO: implement tracking for whether this event is currently enabled # or disabled. The disable_event flag from the action return value should # control whether this event continues to be checked for triggering. - # TODO: check_trigger does note receive enough arguments to substitute parachute events def __verify_trigger(self, trigger): """Verifies that the trigger function is valid. @@ -115,15 +119,30 @@ def __verify_trigger(self, trigger): Raises ------ ValueError - If the trigger function does not have the correct signature or does not return a boolean value. + If the trigger function does not have the correct signature or does not return a boolean value + (at least if not declared or annotated). """ - # TODO: implement inspection of trigger function to verify: - # 1. It accepts **kwargs (accepts arbitrary keyword arguments) - # 2. Return type annotation is bool or can be tested to return bool - # 3. Consider allowing signature to be flexible (accepts **kwargs) - # to accommodate user-defined custom event_context keys + # verify if the trigger function accepts only **kwargs arguments + # also avoids functions with no arguments, since they can't be used as triggers + s = inspect.signature(trigger) + if ( + any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()) + or len(s.parameters) == 0 + ): + raise ValueError( + f"The Trigger function of the {self.name} event must accept only keyword arguments. def {trigger.__name__}(**kwargs) -> bool:" + ) + # Verify if the return type annotation is bool. + # Since is not possible to know for sure if the user is actually returning a bool value, + # we enforce bool annotation to motivate users to actually return bool values. + return_annotation = get_type_hints(trigger).get("return", None) + if return_annotation is not bool: + raise ValueError( + f"The Trigger function of the {self.name} event must return a boolean value and must be annotated with '-> bool' for type checking.\n" + f"def {trigger.__name__}(**kwargs) -> bool:" + ) return trigger - + def __verify_action(self, action): """Verifies that the action function is valid. @@ -140,17 +159,30 @@ def __verify_action(self, action): Raises ------ ValueError - If the action function does not have the correct signature. + If the action function does not have the correct signature or does not return a valid type. """ - # TODO: implement inspection of action function to verify: - # 1. It accepts **kwargs (accepts arbitrary keyword arguments) - # 2. It can optionally return None or a dict with any of these keys: - # - \"state\": list of floats - # - \"disable_event\": bool - # - \"new_events\": list of Event objects - # - \"remove_events\": list of Event objects - # - Any custom keys to update event_context - # 3. Raise ValueError if signature doesn't match expectations + # verify if the action function accepts only **kwargs arguments + s = inspect.signature(action) + if ( + any(p.kind != inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()) + or len(s.parameters) == 0 + ): + raise ValueError( + f"The Action function of the {self.name} event must accept only keyword arguments. def {action.__name__}(**kwargs) -> None | dict:" + ) + # verify if the return type annotation is None or dict + # Since is not possible to know for sure if the user is actually returning None or a dict, + # we enforce None or dict annotation to motivate users to actually return None or dict. + return_annotation = get_type_hints(action).get("return", int) + if ( + (return_annotation is not type(None)) + and (return_annotation is not dict) + and (return_annotation is not bool) + ): + raise ValueError( + f"The Action function of the {self.name} event must return None or a dictionary and must be annotated with '-> None' or '-> dict' for type checking.\n" + f"def {action.__name__}(**kwargs) -> None | dict:" + ) return action def __repr__(self): @@ -191,21 +223,20 @@ def __call__(self, *args, **kwds): pass - # TODO: add a parameter to the Event class that specify whether the event should -# be triggered only once, or if it can be triggered multiple times. Also, add a +# be triggered only once, or if it can be triggered multiple times. Also, add a # way to stop the event from continuously triggering on command inside the action -# function, such as a "disable" method that can be called inside the action +# function, such as a "disable" method that can be called inside the action # function to prevent the event from being triggered again. # TODO: add a parameter to the Event class that specify whether the event should # be a discrete event, meaning that it should only be checked for triggering at # specific time intervals (e.g. every 0.1 seconds) instead of at every time step -# of the simulation. This would be useful for parachute events. This should be +# of the simulation. This would be useful for parachute events. This should be # done by adding a "sampling_rate" parameter to the Event class, that is none by # default (meaning that the event is checked at every time step), but if it is -# set to a float value, the event will only be checked for triggering at that -# specific time interval. The flight class should be able to differentiate +# set to a float value, the event will only be checked for triggering at that +# specific time interval. The flight class should be able to differentiate # between the discrete and continuous events (we will handle this later) @@ -219,4 +250,4 @@ def __call__(self, *args, **kwds): # - Respect the disable_event flag and sampling_rate to control when events # are checked for triggering # - Handle the sampling_rate logic: only check events at the specified intervals, -# not at every simulation time step \ No newline at end of file +# not at every simulation time step diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 9e78d169b..2e38a6378 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -619,7 +619,8 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements self.ode_solver = ode_solver # Events - def out_of_rail_trigger(state): + def out_of_rail_trigger(**kwargs) -> bool: + state = kwargs["state"] return ( state[0] ** 2 + state[1] ** 2 + (state[2] - self.env.elevation) ** 2 >= self.effective_1rl**2 @@ -1031,8 +1032,10 @@ def __check_simulation_events(self, phase, phase_index, node_index): # TODO: make all these 3 events be handled with the Events class # Check for first out of rail event if len(self.out_of_rail_state) == 1: - if self.out_of_rail_event.trigger(self.y_sol): - return self.out_of_rail_event.action(phase, phase_index, node_index) + if self.out_of_rail_event.trigger(state=self.y_sol): + return self.out_of_rail_event.action( + phase=phase, phase_index=phase_index, node_index=node_index + ) # Check for apogee event # TODO: negative vz doesn't really mean apogee. Improve this. @@ -1045,10 +1048,10 @@ def __check_simulation_events(self, phase, phase_index, node_index): return False - def __handle_out_of_rail_event(self, phase, phase_index, node_index): + def __handle_out_of_rail_event(self, **kwargs) -> bool: """Handle the out of rail event. - Parameters + keyword arguments are passed by the Event class when the trigger function is called. ---------- phase : FlightPhase The current flight phase. @@ -1062,6 +1065,9 @@ def __handle_out_of_rail_event(self, phase, phase_index, node_index): bool True to indicate the simulation should break. """ + phase = kwargs["phase"] + phase_index = kwargs["phase_index"] + node_index = kwargs["node_index"] # Check exactly when it went out using root finding # Disconsider elevation self.solution[-2][3] -= self.env.elevation diff --git a/tests/unit/simulation/test_events.py b/tests/unit/simulation/test_events.py new file mode 100644 index 000000000..9e569eeb8 --- /dev/null +++ b/tests/unit/simulation/test_events.py @@ -0,0 +1,165 @@ +import pytest + +from rocketpy.simulation.events import Event + + +def test_verify_trigger_accepts_only_kwargs(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> None: + return None + + event = Event(trigger=trigger, action=action, name="test") + assert event.trigger is trigger + + +def test_verify_trigger_evaluation_of_number_of_parameters(): + def trigger(**kwargs) -> bool: + a = kwargs["a"] + b = kwargs["b"] + c = kwargs["c"] + return a + b + c == 6 + + def action(**kwargs) -> None: + return None + + kwargs_test = {"a": 1, "b": 2, "c": 3} + assert trigger(**kwargs_test) + + event = Event(trigger=trigger, action=action, name="test") + assert event.trigger is trigger + + +def test_verify_trigger_rejects_missing_kwargs(): + def trigger(a, b) -> bool: + return True + + def action(**kwargs) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Trigger function of the test event must accept only keyword arguments. def trigger\(\*\*kwargs\) -> bool:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_trigger_rejects_args_with_kwargs(): + def trigger(a, b, **kwargs) -> bool: + return True + + def action(**kwargs) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Trigger function of the test event must accept only keyword arguments. def trigger\(\*\*kwargs\) -> bool:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_trigger_rejects_triggers_with_no_parameters(): + def trigger() -> bool: + return True + + def action(**kwargs) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Trigger function of the test event must accept only keyword arguments. def trigger\(\*\*kwargs\) -> bool:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_trigger_rejects_triggers_without_bool_return_annotation(): + def trigger(**kwargs): + return True + + def action(**kwargs) -> None: + return None + + with pytest.raises( + ValueError, + match="The Trigger function of the test event must return a boolean value and must be annotated with '-> bool' for type checking.\n" + + r"def trigger\(\*\*kwargs\) -> bool\:", + ): + Event(trigger=trigger, action=action, name="test") + + +# The following tests verify if action functions were correctly implemented + + +def test_verify_action_accepts_only_kwargs(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> None: + return None + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action + + +def test_verify_action_rejects_missing_kwargs(): + def trigger(**kwargs) -> bool: + return True + + def action(a, b) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Action function of the test event must accept only keyword arguments. def action\(\*\*kwargs\) -> None \| dict:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_action_rejects_args_with_kwargs(): + def trigger(**kwargs) -> bool: + return True + + def action(a, b, **kwargs) -> None: + return None + + with pytest.raises( + ValueError, + match=r"The Action function of the test event must accept only keyword arguments. def action\(\*\*kwargs\) -> None \| dict:", + ): + Event(trigger=trigger, action=action, name="test") + + +def test_verify_action_accepts_dict_return_type(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> dict: + return {"key": "value"} + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action + + +def test_verify_action_accepts_none_return_type(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> None: + return None + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action + + +# this was also allowed because some actions functions already return bool, they need to be updated +# then this test can be removed and the check for bool return type can be removed from the events.py file +def test_verify_action_accepts_bool_return_type(): + def trigger(**kwargs) -> bool: + return True + + def action(**kwargs) -> bool: + return True + + event = Event(trigger=trigger, action=action, name="test") + assert event.action is action