diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py index 294346a..2e55a75 100644 --- a/src/ark/comm/queriable.py +++ b/src/ark/comm/queriable.py @@ -21,7 +21,10 @@ def __init__( ): super().__init__(node_name, session, clock, channel, data_collector) self._handler = handler - self._queryable = self._session.declare_queryable(self._channel, self._on_query) + self._queryable = self._session.declare_queryable(self._channel, + self._on_query, + complete=False) + print(f"Declared queryable on channel: {self._channel}") def core_registration(self): print("..todo: register with ark core..") @@ -29,22 +32,26 @@ def core_registration(self): def _on_query(self, query: zenoh.Query) -> None: # If we were closed, ignore queries if not self._active: + print("Received query on closed Queryable, ignoring") return try: # Zenoh query may or may not include a payload. # For your use-case, the request is always in query.value (bytes) - raw = bytes(query.value) if query.value is not None else b"" + raw = bytes(query.payload) if query.payload is not None else b"" if not raw: + print("Received query with no payload, ignoring") return # nothing to do req_env = Envelope() req_env.ParseFromString(raw) # Decode request protobuf - req_type = msgs.get(req_env.payload_msg_type) + # req_type = msgs.get(req_env.payload_msg_type) + req_type = msgs.get(req_env.msg_type) if req_type is None: # Unknown message type: ignore (or reply error later) + print(f"Unknown message type '{req_env.msg_type}' in query, ignoring") return req_msg = req_type() @@ -60,11 +67,13 @@ def _on_query(self, query: zenoh.Query) -> None: resp_env.sent_seq_index = self._seq_index resp_env.src_node_name = self._node_name resp_env.channel = self._channel + resp_env.msg_type = resp_msg.DESCRIPTOR.full_name + resp_env.payload = resp_msg.SerializeToString() self._seq_index += 1 - resp_env = Envelope.pack(self._node_name, self._clock, resp_msg) - query.reply(resp_env.SerializeToString()) + with query: + query.reply(query.key_expr, resp_env.SerializeToString()) if self._data_collector: self._data_collector.append(req_env.SerializeToString()) @@ -73,4 +82,9 @@ def _on_query(self, query: zenoh.Query) -> None: except Exception: # Keep it minimal: don't kill the zenoh callback thread # You can add logging here if desired + print("Error processing query:") + # write the traceback to stdout for debugging + import traceback + traceback.print_exc() + return diff --git a/src/ark/comm/querier.py b/src/ark/comm/querier.py index c6d4586..88a6e8a 100644 --- a/src/ark/comm/querier.py +++ b/src/ark/comm/querier.py @@ -3,6 +3,7 @@ from google.protobuf.message import Message from ark.data.data_collector import DataCollector from ark.comm.end_point import EndPoint +from ark_msgs.registry import msgs class Querier(EndPoint): @@ -11,12 +12,16 @@ def __init__( self, node_name: str, session: zenoh.Session, + query_target, clock, channel: str, data_collector: DataCollector | None, ): super().__init__(node_name, session, clock, channel, data_collector) - self._querier = self._session.declare_querier(self._channel) + self._querier = self._session.declare_querier(self._channel, + target=query_target) + print(f"Declared querier on channel: {self._channel}") + self._query_selector = zenoh.Selector(self._channel) def core_registration(self): print("..todo: register with ark core..") @@ -48,18 +53,21 @@ def query( else: raise TypeError("req must be a protobuf Message or bytes") - replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout) + replies = self._querier.get(payload=req_env.SerializeToString()) for reply in replies: if reply.ok is None: continue resp_env = Envelope() - resp_env.ParseFromString(bytes(reply.ok)) + resp_env.ParseFromString(bytes(reply.ok.payload)) resp_env.dst_node_name = self._node_name resp_env.recv_timestamp = self._clock.now() - resp = resp_env.extract_message() + try: + resp = resp_env.extract_message() + except Exception as e: + continue self._seq_index += 1 @@ -69,11 +77,6 @@ def query( return resp - else: - raise TimeoutError( - f"No OK reply received for query on '{self._channel}' within {timeout}s" - ) - def close(self): super().close() self._querier.undeclare() diff --git a/src/ark/node.py b/src/ark/node.py index 6ba6c6e..24a8df0 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -22,7 +22,8 @@ def __init__( sim: bool = False, collect_data: bool = False, ): - self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + # self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + self._z_cfg = z_cfg self._session = zenoh.open(self._z_cfg) self._env_name = env_name self._node_name = node_name @@ -73,17 +74,19 @@ def create_subscriber(self, channel, callback) -> Subscriber: self._subs[channel] = sub return sub - def create_querier(self, channel, timeout=10.0) -> Querier: + def create_querier(self, channel, target, timeout=10.0) -> Querier: querier = Querier( self._node_name, self._session, + target, self._clock, channel, self._data_collector, - timeout, + # timeout, ) querier.core_registration() self._queriers[channel] = querier + # print session and channelinfo for debugging return querier def create_queryable(self, channel, handler) -> Queryable: diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py new file mode 100644 index 0000000..314f23a --- /dev/null +++ b/test/ad_plotter_sub.py @@ -0,0 +1,124 @@ +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Translation, Value + +# from common import connect_cfg, z_cfg +import argparse +import zenoh +import common_example as common + + +class AutodiffPlotterNode(BaseNode): + def __init__(self, cfg, target): + super().__init__("env", "autodiff_plotter", cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.grad_vx, self.grad_my = [], [] + self.create_subscriber("position", self.on_position) + self.grad_vx_querier = self.create_querier("grad/v/x", target=target) + self.grad_my_querier = self.create_querier("grad/m/y", target=target) + + def on_position(self, msg: Translation): + self.pos_x.append(msg.x) + self.pos_y.append(msg.y) + + def fetch_grads(self): + req = Value() + try: + resp_vx = self.grad_vx_querier.query(req) + if isinstance(resp_vx, Value): + print(f"Received grad_vx: {resp_vx.grad}") + self.grad_vx.append(resp_vx.grad) + except Exception: + pass + try: + resp_my = self.grad_my_querier.query(req) + if isinstance(resp_my, Value): + print(f"Received grad_my: {resp_my.grad}") + self.grad_my.append(resp_my.grad) + except Exception: + pass + + +def main(): + + # These are a few zenoh config related arguments that were taken from the + # examples, keeping them there until we have a better way to manage configs + # across examples + parser = argparse.ArgumentParser(description="Autodiff Plotter Node") + common.add_config_arguments(parser) + parser.add_argument( + "--target", + "-t", + dest="target", + choices=["ALL", "BEST_MATCHING", "ALL_COMPLETE", "NONE"], + default="BEST_MATCHING", + type=str, + help="The target queryables of the query.", + ) + parser.add_argument( + "--timeout", + "-o", + dest="timeout", + default=10.0, + type=float, + help="The query timeout", + ) + parser.add_argument( + "--iter", dest="iter", type=int, help="How many gets to perform" + ) + parser.add_argument( + "--add-matching-listener", + default=False, + action="store_true", + help="Add matching listener", + ) + + args = parser.parse_args() + conf = common.get_config_from_args(args) + + # These were required for the querier and queryable to find each other. + target = { + "ALL": zenoh.QueryTarget.ALL, + "BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING, + "ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE, + }.get(args.target) + + # Main subcription and querying loop + node = AutodiffPlotterNode(conf, target) + threading.Thread(target=node.spin, daemon=True).start() + + # Plotting trajectory and gradients + fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) + ax_pos.set_title("Position (Translation)") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_xlim(-5, 5) + ax_pos.set_ylim(-5, 5) + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_grad.set_title("Gradients") + ax_grad.set_xlabel("t") + ax_grad.set_ylabel("grad") + ax_grad.set_xlim(-5, 5) + ax_grad.set_ylim(-5, 5) + (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") + (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") + ax_grad.legend() + + def update(frame): + node.fetch_grads() + line_pos.set_data(node.pos_x, node.pos_y) + line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) + line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) + return line_pos, line_grad_vx, line_grad_my + + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + plt.tight_layout() + plt.show() + node.close() + + +if __name__ == "__main__": + main() diff --git a/test/autodiff.py b/test/autodiff.py new file mode 100644 index 0000000..b66b6da --- /dev/null +++ b/test/autodiff.py @@ -0,0 +1,54 @@ +import numpy as np + +class Value: + def __init__(self, val, parents=(), backward=None, name=None): + self.val = np.asarray(val, dtype=float) + self.grad = np.zeros_like(self.val) + self._prev = parents + self._backward = backward or (lambda: None) + self.name = name + def backward(self, grad=None): + if grad is None: + grad = np.ones_like(self.val) + self.grad = self.grad + grad + topo = [] + visited = set() + def build(v): + if v not in visited: + visited.add(v) + for p in v._prev: + build(p) + topo.append(v) + build(self) + for v in reversed(topo): + v._backward() + def __add__(self, other): + out = Value(self.val + other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + out.grad + other.grad = other.grad + out.grad + out._backward = _backward + return out + def __sub__(self, other): + out = Value(self.val - other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + out.grad + other.grad = other.grad - out.grad + out._backward = _backward + return out + def __mul__(self, other): + out = Value(self.val * other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + other.val * out.grad + other.grad = other.grad + self.val * out.grad + out._backward = _backward + return out + def __neg__(self): + out = Value(-self.val, parents=(self,)) + def _backward(): + self.grad = self.grad - out.grad + out._backward = _backward + return out +def clear_grads(params): + for p in params: + p.grad = np.zeros_like(p.val) diff --git a/test/common.py b/test/common.py index b26b213..0a5d7bc 100644 --- a/test/common.py +++ b/test/common.py @@ -1 +1,17 @@ -z_cfg = {"mode": "peer", "connect": {"endpoints": ["udp/127.0.0.1:7447"]}} +listen_cfg = { + "mode": "peer", + "listen": { + "endpoints": ["tcp/0.0.0.0:7447"]}, +} +connect_cfg = { + "mode": "peer", + "connect": { + "endpoints": ["tcp/127.0.0.1:7447"] + } +} +z_cfg = { + "mode": "peer", + # "connect": { + # "endpoints":["udp/127.0.0.1:7447"] + # } +} diff --git a/test/common_example.py b/test/common_example.py new file mode 100644 index 0000000..0c1eea3 --- /dev/null +++ b/test/common_example.py @@ -0,0 +1,83 @@ +import argparse +import json + +import zenoh + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--mode", + "-m", + dest="mode", + choices=["peer", "client"], + type=str, + help="The zenoh session mode.", + ) + parser.add_argument( + "--connect", + "-e", + dest="connect", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to connect to.", + ) + parser.add_argument( + "--listen", + "-l", + dest="listen", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to listen on.", + ) + parser.add_argument( + "--config", + "-c", + dest="config", + metavar="FILE", + type=str, + help="A configuration file.", + ) + parser.add_argument( + "--no-multicast-scouting", + dest="no_multicast_scouting", + default=False, + action="store_true", + help="Disable multicast scouting.", + ) + parser.add_argument( + "--cfg", + dest="cfg", + metavar="CFG", + default=[], + action="append", + type=str, + help="Allows arbitrary configuration changes as column-separated KEY:VALUE pairs. Where KEY must be a valid config path and VALUE must be a valid JSON5 string that can be deserialized to the expected type for the KEY field. Example: --cfg='transport/unicast/max_links:2'.", + ) + + +def get_config_from_args(args) -> zenoh.Config: + conf = ( + zenoh.Config.from_file(args.config) + if args.config is not None + else zenoh.Config() + ) + if args.mode is not None: + conf.insert_json5("mode", json.dumps(args.mode)) + if args.connect is not None: + conf.insert_json5("connect/endpoints", json.dumps(args.connect)) + if args.listen is not None: + conf.insert_json5("listen/endpoints", json.dumps(args.listen)) + if args.no_multicast_scouting: + conf.insert_json5("scouting/multicast/enabled", json.dumps(False)) + + for c in args.cfg: + try: + [key, value] = c.split(":", 1) + except: + print(f"`--cfg` argument: expected KEY:VALUE pair, got {c}") + raise + conf.insert_json5(key, value) + + return conf diff --git a/test/diff_publisher.py b/test/diff_publisher.py new file mode 100644 index 0000000..aa74ee8 --- /dev/null +++ b/test/diff_publisher.py @@ -0,0 +1,142 @@ +import math +import time +from ark.node import BaseNode +from ark_msgs import Translation, Value +import argparse + +# from common import listen_cfg, z_cfg +import common_example as common +import torch + +# Lissajous parameters +A, B = 1.0, 1.0 +a, b = 3.0, 2.0 +delta = math.pi / 2 +HZ = 50 +DT = 1.0 / HZ + + +class LissajousPublisherNode(BaseNode): + def __init__(self): + super().__init__("env", "diff_pub", listen_cfg, sim=True) + self.pos_pub = self.create_publisher("position") + self.vel_pub = self.create_publisher("velocity") + self.rate = self.create_rate(HZ) + + def spin(self): + t = 0.0 + while True: + x = A * math.sin(a * t + delta) + y = B * math.sin(b * t) + dx = A * a * math.cos(a * t + delta) + dy = B * b * math.cos(b * t) + self.pos_pub.publish(Translation(x=x, y=y, z=0.0)) + self.vel_pub.publish(dTranslation(x=dx, y=dy, z=0.0)) + t += DT + self.rate.sleep() + + +class LinePublisherNode(BaseNode): + + def __init__(self, cfg): + super().__init__("env", "line_pub", cfg, sim=True) + self.pos_pub = self.create_publisher("position") + self.rate = self.create_rate(HZ) + self.v = torch.tensor(1.0, requires_grad=True) + self.m = torch.tensor(0.5, requires_grad=True) + self.c = torch.tensor(0.0, requires_grad=True) + self.latest = { + "x": 0.0, + "y": 0.0, + "v_x": 0.0, + "v_y": 0.0, + "m_x": 0.0, + "m_y": 0.0, + "c_x": 0.0, + "c_y": 0.0, + } + # declare and store all the queryables for gradients + self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) + self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) + self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) + self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) + self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) + self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) + + def _on_grad_v_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["v_x"]) + + def _on_grad_v_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["v_y"]) + + def _on_grad_m_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["m_x"]) + + def _on_grad_m_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["m_y"]) + pass + + def _on_grad_c_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["c_x"]) + + def _on_grad_c_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["c_y"]) + + def spin(self): + t = 0.0 + while True: + t_val = torch.tensor(t, requires_grad=False) + + # Computation graph + # line equation: y = m * x + c, where x = v * t + x = self.v * t_val + y = self.m * x + self.c + + # publish position + self.pos_pub.publish( + Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) + ) + + # compute gradients + if self.v.grad is not None: + self.v.grad.zero_() + if self.m.grad is not None: + self.m.grad.zero_() + if self.c.grad is not None: + self.c.grad.zero_() + x.backward(retain_graph=True) + self.latest["v_x"] = float(self.v.grad) + print(f"Grad v_x {self.v.grad.item()}") + self.v.grad.zero_() + y.backward() + self.latest["v_y"] = float(self.v.grad) + self.latest["m_y"] = float(self.m.grad) + print(f"Grad m_y {self.m.grad}") + self.latest["c_y"] = float(self.c.grad) + self.latest["x"] = float(x.detach()) + self.latest["y"] = float(y.detach()) + t += DT + self.rate.sleep() + + +if __name__ == "__main__": + try: + parser = argparse.ArgumentParser( + prog="z_queryable", description="zenoh queryable example" + ) + common.add_config_arguments(parser) + parser.add_argument( + "--complete", + dest="complete", + default=False, + action="store_true", + help="Declare the queryable as complete w.r.t. the key expression.", + ) + args = parser.parse_args() + conf = common.get_config_from_args(args) + + node = LinePublisherNode(conf) + node.spin() + except KeyboardInterrupt: + print("Shutting down diff publisher.") + node.close() diff --git a/test/gradient_exp.md b/test/gradient_exp.md new file mode 100644 index 0000000..54387b0 --- /dev/null +++ b/test/gradient_exp.md @@ -0,0 +1,47 @@ +# Gradient Experiment + +Demonstrates differentiable simulation using ark framework. A `LinePublisherNode` publishes position on a line (`y = m*x + c`, `x = v*t`) along with autograd gradients (dx/dv, dy/dm, dy/dc), and an `AutodiffPlotterNode` subscribes to position and queries gradients in real time. + +## Prerequisites + +- Install ark framework and dependencies (`zenoh`, `torch`, `matplotlib`, `ark_msgs`) +- Run all commands from the `test/` directory + +## Running the Experiment + +Open three separate terminals. All commands are run from the `test/` directory. + +### Shell 1 — Sim Clock + +Drives simulated time for all sim-enabled nodes. + +```bash +cd test +python simstep.py +``` + +### Shell 2 — Diff Publisher + +Publishes position (`Translation`) and serves gradient queryables (`grad/v/x`, `grad/m/y`, etc.). + +```bash +cd test +python diff_publisher.py +``` + +### Shell 3 — Autodiff Plotter + +Subscribes to position and queries gradients, then plots both in real time. + +```bash +cd test +python ad_plotter_sub.py +``` + +## What to Expect + +- **Shell 1** prints the simulated time advancing each tick. +- **Shell 2** prints computed gradients (`Grad v_x`, `Grad m_y`) each step. +- **Shell 3** opens a matplotlib window with two plots: + - **Left**: Position trajectory (x vs y). + - **Right**: Gradients over time (dx/dv in green, dy/dm in magenta). diff --git a/test/plotter_subsriber.py b/test/plotter_subsriber.py new file mode 100644 index 0000000..2477962 --- /dev/null +++ b/test/plotter_subsriber.py @@ -0,0 +1,47 @@ +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Translation, dTranslation +from common import z_cfg +class SubscriberPlotterNode(BaseNode): + def __init__(self): + super().__init__("env", "plotter", z_cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.vel_x, self.vel_y = [], [] + self.create_subscriber("position", self.on_position) + self.create_subscriber("velocity", self.on_velocity) + def on_position(self, msg: Translation): + self.pos_x.append(msg.x) + self.pos_y.append(msg.y) + def on_velocity(self, msg: dTranslation): + self.vel_x.append(msg.x) + self.vel_y.append(msg.y) +def main(): + node = SubscriberPlotterNode() + threading.Thread(target=node.spin, daemon=True).start() + fig, (ax_pos, ax_vel) = plt.subplots(1, 2, figsize=(10, 5)) + ax_pos.set_title("Position (Translation)") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_xlim(-1.5, 1.5) + ax_pos.set_ylim(-1.5, 1.5) + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_vel.set_title("Velocity (dTranslation)") + ax_vel.set_xlabel("dx") + ax_vel.set_ylabel("dy") + ax_vel.set_xlim(-5, 5) + ax_vel.set_ylim(-5, 5) + ax_vel.set_aspect("equal") + (line_vel,) = ax_vel.plot([], [], "r-") + def update(frame): + line_pos.set_data(node.pos_x, node.pos_y) + line_vel.set_data(node.vel_x, node.vel_y) + return line_pos, line_vel + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + plt.tight_layout() + plt.show() + node.close() +if __name__ == "__main__": + main() diff --git a/test/simstep.py b/test/simstep.py new file mode 100644 index 0000000..06da237 --- /dev/null +++ b/test/simstep.py @@ -0,0 +1,18 @@ +from ark.time.simtime import SimTime +from common import z_cfg +import json +import zenoh +import time + +def main(): + z_config = zenoh.Config.from_json5(json.dumps(z_cfg)) + with zenoh.open(z_config) as z: + sim_time = SimTime(z, "clock", 1000) + sim_time.reset() + while True: + current_time = time.time() + print(f"Simulated Time: {current_time:.2f} seconds") + sim_time.tick() + +if __name__ == "__main__": + main()