From 0d9923fa242e0771d1f30363b73e5c63d0896bf9 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Mon, 9 Feb 2026 18:03:12 +0000 Subject: [PATCH 1/8] adds diff test code --- test/diff_pub.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 test/diff_pub.py diff --git a/test/diff_pub.py b/test/diff_pub.py new file mode 100644 index 0000000..40ff663 --- /dev/null +++ b/test/diff_pub.py @@ -0,0 +1,27 @@ +from ark.node import BaseNode +from itertools import count +from common import z_cfg + + +class PublisherNode(BaseNode): + + def __init__(self): + super().__init__("env", "pub", z_cfg, sim=True) + self.pub = self.create_publisher("diff_sim") + self.rate = self.create_rate(1) # 1 Hz + + def spin(self): + for c in count(): + msg = f"Hello World {c}" + self.pub.publish(msg.encode("utf-8")) + print(f"Published: {msg}") + self.rate.sleep() + + +if __name__ == "__main__": + try: + node = PublisherNode() + node.spin() + except KeyboardInterrupt: + print("Shutting down publisher node.") + node.close() From 5214e1c4dac385c52778ef4ea002aa955d63c2cd Mon Sep 17 00:00:00 2001 From: kamiradi Date: Tue, 10 Feb 2026 14:17:14 +0000 Subject: [PATCH 2/8] demo code that visualises translation and its derivative --- test/diff_pub.py | 27 ---------------------- test/diff_publisher.py | 35 +++++++++++++++++++++++++++++ test/plotter_subsriber.py | 47 +++++++++++++++++++++++++++++++++++++++ test/simstep.py | 18 +++++++++++++++ 4 files changed, 100 insertions(+), 27 deletions(-) delete mode 100644 test/diff_pub.py create mode 100644 test/diff_publisher.py create mode 100644 test/plotter_subsriber.py create mode 100644 test/simstep.py diff --git a/test/diff_pub.py b/test/diff_pub.py deleted file mode 100644 index 40ff663..0000000 --- a/test/diff_pub.py +++ /dev/null @@ -1,27 +0,0 @@ -from ark.node import BaseNode -from itertools import count -from common import z_cfg - - -class PublisherNode(BaseNode): - - def __init__(self): - super().__init__("env", "pub", z_cfg, sim=True) - self.pub = self.create_publisher("diff_sim") - self.rate = self.create_rate(1) # 1 Hz - - def spin(self): - for c in count(): - msg = f"Hello World {c}" - self.pub.publish(msg.encode("utf-8")) - print(f"Published: {msg}") - self.rate.sleep() - - -if __name__ == "__main__": - try: - node = PublisherNode() - node.spin() - except KeyboardInterrupt: - print("Shutting down publisher node.") - node.close() diff --git a/test/diff_publisher.py b/test/diff_publisher.py new file mode 100644 index 0000000..5bb1a04 --- /dev/null +++ b/test/diff_publisher.py @@ -0,0 +1,35 @@ +import math +import time +from ark.node import BaseNode +from ark_msgs import Translation, dTranslation +from common import z_cfg +# Lissajous parameters +A, B = 1.0, 1.0 +a, b = 3.0, 2.0 +delta = math.pi / 2 +HZ = 50 +DT = 1.0 / HZ +class DiffPublisherNode(BaseNode): + def __init__(self): + super().__init__("env", "diff_pub", z_cfg, sim=True) + self.pos_pub = self.create_publisher("position") + self.vel_pub = self.create_publisher("velocity") + self.rate = self.create_rate(HZ) + def spin(self): + t = 0.0 + while True: + x = A * math.sin(a * t + delta) + y = B * math.sin(b * t) + dx = A * a * math.cos(a * t + delta) + dy = B * b * math.cos(b * t) + self.pos_pub.publish(Translation(x=x, y=y, z=0.0)) + self.vel_pub.publish(dTranslation(x=dx, y=dy, z=0.0)) + t += DT + self.rate.sleep() +if __name__ == "__main__": + try: + node = DiffPublisherNode() + node.spin() + except KeyboardInterrupt: + print("Shutting down diff publisher.") + node.close() diff --git a/test/plotter_subsriber.py b/test/plotter_subsriber.py new file mode 100644 index 0000000..2477962 --- /dev/null +++ b/test/plotter_subsriber.py @@ -0,0 +1,47 @@ +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Translation, dTranslation +from common import z_cfg +class SubscriberPlotterNode(BaseNode): + def __init__(self): + super().__init__("env", "plotter", z_cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.vel_x, self.vel_y = [], [] + self.create_subscriber("position", self.on_position) + self.create_subscriber("velocity", self.on_velocity) + def on_position(self, msg: Translation): + self.pos_x.append(msg.x) + self.pos_y.append(msg.y) + def on_velocity(self, msg: dTranslation): + self.vel_x.append(msg.x) + self.vel_y.append(msg.y) +def main(): + node = SubscriberPlotterNode() + threading.Thread(target=node.spin, daemon=True).start() + fig, (ax_pos, ax_vel) = plt.subplots(1, 2, figsize=(10, 5)) + ax_pos.set_title("Position (Translation)") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_xlim(-1.5, 1.5) + ax_pos.set_ylim(-1.5, 1.5) + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_vel.set_title("Velocity (dTranslation)") + ax_vel.set_xlabel("dx") + ax_vel.set_ylabel("dy") + ax_vel.set_xlim(-5, 5) + ax_vel.set_ylim(-5, 5) + ax_vel.set_aspect("equal") + (line_vel,) = ax_vel.plot([], [], "r-") + def update(frame): + line_pos.set_data(node.pos_x, node.pos_y) + line_vel.set_data(node.vel_x, node.vel_y) + return line_pos, line_vel + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + plt.tight_layout() + plt.show() + node.close() +if __name__ == "__main__": + main() diff --git a/test/simstep.py b/test/simstep.py new file mode 100644 index 0000000..06da237 --- /dev/null +++ b/test/simstep.py @@ -0,0 +1,18 @@ +from ark.time.simtime import SimTime +from common import z_cfg +import json +import zenoh +import time + +def main(): + z_config = zenoh.Config.from_json5(json.dumps(z_cfg)) + with zenoh.open(z_config) as z: + sim_time = SimTime(z, "clock", 1000) + sim_time.reset() + while True: + current_time = time.time() + print(f"Simulated Time: {current_time:.2f} seconds") + sim_time.tick() + +if __name__ == "__main__": + main() From dc7fbc1231b0ba208b4bbb12b7d3e718530690c0 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Wed, 11 Feb 2026 18:02:27 +0000 Subject: [PATCH 3/8] [untested] basic gradient broadcasting via query --- test/autodiff.py | 54 +++++++++++++++++++++++++++ test/diff_publisher.py | 83 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 134 insertions(+), 3 deletions(-) create mode 100644 test/autodiff.py diff --git a/test/autodiff.py b/test/autodiff.py new file mode 100644 index 0000000..b66b6da --- /dev/null +++ b/test/autodiff.py @@ -0,0 +1,54 @@ +import numpy as np + +class Value: + def __init__(self, val, parents=(), backward=None, name=None): + self.val = np.asarray(val, dtype=float) + self.grad = np.zeros_like(self.val) + self._prev = parents + self._backward = backward or (lambda: None) + self.name = name + def backward(self, grad=None): + if grad is None: + grad = np.ones_like(self.val) + self.grad = self.grad + grad + topo = [] + visited = set() + def build(v): + if v not in visited: + visited.add(v) + for p in v._prev: + build(p) + topo.append(v) + build(self) + for v in reversed(topo): + v._backward() + def __add__(self, other): + out = Value(self.val + other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + out.grad + other.grad = other.grad + out.grad + out._backward = _backward + return out + def __sub__(self, other): + out = Value(self.val - other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + out.grad + other.grad = other.grad - out.grad + out._backward = _backward + return out + def __mul__(self, other): + out = Value(self.val * other.val, parents=(self, other)) + def _backward(): + self.grad = self.grad + other.val * out.grad + other.grad = other.grad + self.val * out.grad + out._backward = _backward + return out + def __neg__(self): + out = Value(-self.val, parents=(self,)) + def _backward(): + self.grad = self.grad - out.grad + out._backward = _backward + return out +def clear_grads(params): + for p in params: + p.grad = np.zeros_like(p.val) diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 5bb1a04..7d2e1b9 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -1,15 +1,17 @@ import math import time from ark.node import BaseNode -from ark_msgs import Translation, dTranslation +from ark_msgs import Translation, Value from common import z_cfg +import torch + # Lissajous parameters A, B = 1.0, 1.0 a, b = 3.0, 2.0 delta = math.pi / 2 HZ = 50 DT = 1.0 / HZ -class DiffPublisherNode(BaseNode): +class LissajousPublisherNode(BaseNode): def __init__(self): super().__init__("env", "diff_pub", z_cfg, sim=True) self.pos_pub = self.create_publisher("position") @@ -26,9 +28,84 @@ def spin(self): self.vel_pub.publish(dTranslation(x=dx, y=dy, z=0.0)) t += DT self.rate.sleep() + +class LinePublisherNode(BaseNode): + + def __init__(self): + super().__init__("env", "line_pub", z_cfg, sim=True) + self.pos_pub = self.create_publisher("position") + self.rate = self.create_rate(HZ) + self.v = torch.tensor(1.0, requires_grad=True) + self.m = torch.tensor(0.5, requires_grad=True) + self.c = torch.tensor(0.0, requires_grad=True) + self.latest = { + "x": 0.0, + "y": 0.0, + "v_x": 0.0, + "v_y": 0.0, + "m_x": 0.0, + "m_y": 0.0, + "c_x": 0.0, + "c_y": 0.0, + } + self.create_queryable("grad/v/x", self._on_grad_v_x) + self.create_queryable("grad/v/y", self._on_grad_v_y) + self.create_queryable("grad/m/x", self._on_grad_m_x) + self.create_queryable("grad/m/y", self._on_grad_m_y) + self.create_queryable("grad/c/x", self._on_grad_c_x) + self.create_queryable("grad/c/y", self._on_grad_c_y) + + def _on_grad_v_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["v_x"]) + + def _on_grad_v_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["v_y"]) + + def _on_grad_m_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["m_x"]) + + def _on_grad_m_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["m_y"]) + + def _on_grad_c_x(self, _req): + return Value(val=self.latest["x"], grad=self.latest["c_x"]) + + def _on_grad_c_y(self, _req): + return Value(val=self.latest["y"], grad=self.latest["c_y"]) + + def spin(self): + t = 0.0 + while True: + t_val = torch.tensor(t, requires_grad=False) + x = self.v * t_val + y = self.m * x + self.c + self.pos_pub.publish(Translation(x=float(x), y=float(y), z=0.0)) + if self.v.grad is not None: + self.v.grad.zero_() + if self.m.grad is not None: + self.m.grad.zero_() + if self.c.grad is not None: + self.c.grad.zero_() + x.backward(retain_graph=True) + self.latest["v_x"] = float(self.v.grad) + self.latest["m_x"] = float(self.m.grad) + self.latest["c_x"] = float(self.c.grad) + self.v.grad.zero_() + self.m.grad.zero_() + self.c.grad.zero_() + y.backward() + self.latest["v_y"] = float(self.v.grad) + self.latest["m_y"] = float(self.m.grad) + self.latest["c_y"] = float(self.c.grad) + self.latest["x"] = float(x) + self.latest["y"] = float(y) + t += DT + self.rate.sleep() + if __name__ == "__main__": try: - node = DiffPublisherNode() + # node = LissajousPublisherNode() + node = LinePublisherNode() node.spin() except KeyboardInterrupt: print("Shutting down diff publisher.") From 0602f55d3abce18a176e464c69718d3dbb9df1a2 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Thu, 12 Feb 2026 10:49:09 +0000 Subject: [PATCH 4/8] basic testing of passing gradients over messages --- src/ark/node.py | 2 +- test/ad_plotter_sub.py | 61 ++++++++++++++++++++++++++++++++++++++++++ test/diff_publisher.py | 13 +++++---- 3 files changed, 68 insertions(+), 8 deletions(-) create mode 100644 test/ad_plotter_sub.py diff --git a/src/ark/node.py b/src/ark/node.py index 6ba6c6e..a1f3103 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -80,7 +80,7 @@ def create_querier(self, channel, timeout=10.0) -> Querier: self._clock, channel, self._data_collector, - timeout, + # timeout, ) querier.core_registration() self._queriers[channel] = querier diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py new file mode 100644 index 0000000..e06bbdc --- /dev/null +++ b/test/ad_plotter_sub.py @@ -0,0 +1,61 @@ +import threading +import matplotlib.pyplot as plt +import matplotlib.animation as animation +from ark.node import BaseNode +from ark_msgs import Translation, Value +from common import z_cfg + +class AutodiffPlotterNode(BaseNode): + def __init__(self): + super().__init__("env", "autodiff_plotter", z_cfg, sim=True) + self.pos_x, self.pos_y = [], [] + self.grad_vx, self.grad_my = [], [] + self.create_subscriber("position", self.on_position) + self.grad_vx_querier = self.create_querier("grad/v/x") + self.grad_my_querier = self.create_querier("grad/m/y") + def on_position(self, msg: Translation): + self.pos_x.append(msg.x) + self.pos_y.append(msg.y) + def fetch_grads(self): + req = Translation(x=0.0, y=0.0, z=0.0) + try: + resp_vx = self.grad_vx_querier.query(req) + if isinstance(resp_vx, Value): + self.grad_vx.append(resp_vx.grad) + except Exception: + pass + try: + resp_my = self.grad_my_querier.query(req) + if isinstance(resp_my, Value): + self.grad_my.append(resp_my.grad) + except Exception: + pass +def main(): + node = AutodiffPlotterNode() + threading.Thread(target=node.spin, daemon=True).start() + fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) + ax_pos.set_title("Position (Translation)") + ax_pos.set_xlabel("x") + ax_pos.set_ylabel("y") + ax_pos.set_xlim(-5, 5) + ax_pos.set_ylim(-5, 5) + ax_pos.set_aspect("equal") + (line_pos,) = ax_pos.plot([], [], "b-") + ax_grad.set_title("Gradients") + ax_grad.set_xlabel("t") + ax_grad.set_ylabel("grad") + (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") + (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") + ax_grad.legend() + def update(frame): + node.fetch_grads() + line_pos.set_data(node.pos_x, node.pos_y) + line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) + line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) + return line_pos, line_grad_vx, line_grad_my + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) + plt.tight_layout() + plt.show() + node.close() +if __name__ == "__main__": + main() diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 7d2e1b9..ff102d5 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -79,7 +79,8 @@ def spin(self): t_val = torch.tensor(t, requires_grad=False) x = self.v * t_val y = self.m * x + self.c - self.pos_pub.publish(Translation(x=float(x), y=float(y), z=0.0)) + self.pos_pub.publish(Translation(x=float(x.detach()), + y=float(y.detach()), z=0.0)) if self.v.grad is not None: self.v.grad.zero_() if self.m.grad is not None: @@ -88,17 +89,15 @@ def spin(self): self.c.grad.zero_() x.backward(retain_graph=True) self.latest["v_x"] = float(self.v.grad) - self.latest["m_x"] = float(self.m.grad) - self.latest["c_x"] = float(self.c.grad) + # self.latest["m_x"] = float(self.m.grad) + # self.latest["c_x"] = float(self.c.grad) self.v.grad.zero_() - self.m.grad.zero_() - self.c.grad.zero_() y.backward() self.latest["v_y"] = float(self.v.grad) self.latest["m_y"] = float(self.m.grad) self.latest["c_y"] = float(self.c.grad) - self.latest["x"] = float(x) - self.latest["y"] = float(y) + self.latest["x"] = float(x.detach()) + self.latest["y"] = float(y.detach()) t += DT self.rate.sleep() From 06159f0bf52876bd21176cfe7eb2a30902d435ff Mon Sep 17 00:00:00 2001 From: kamiradi Date: Thu, 12 Feb 2026 12:37:22 +0000 Subject: [PATCH 5/8] queryable not recieving queries --- src/ark/comm/queriable.py | 8 +++++++- src/ark/node.py | 1 + test/ad_plotter_sub.py | 1 + test/diff_publisher.py | 19 +++++++++++++------ 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py index 294346a..06da42f 100644 --- a/src/ark/comm/queriable.py +++ b/src/ark/comm/queriable.py @@ -22,18 +22,22 @@ def __init__( super().__init__(node_name, session, clock, channel, data_collector) self._handler = handler self._queryable = self._session.declare_queryable(self._channel, self._on_query) + print(f"Declared queryable on channel: {self._channel}") def core_registration(self): print("..todo: register with ark core..") def _on_query(self, query: zenoh.Query) -> None: # If we were closed, ignore queries + print("Received query, processing...") if not self._active: + print("Received query on closed Queryable, ignoring") return try: # Zenoh query may or may not include a payload. # For your use-case, the request is always in query.value (bytes) + print("Parsing query") raw = bytes(query.value) if query.value is not None else b"" if not raw: return # nothing to do @@ -42,7 +46,8 @@ def _on_query(self, query: zenoh.Query) -> None: req_env.ParseFromString(raw) # Decode request protobuf - req_type = msgs.get(req_env.payload_msg_type) + # req_type = msgs.get(req_env.payload_msg_type) + req_type = msgs.get(req_env.msg_type) if req_type is None: # Unknown message type: ignore (or reply error later) return @@ -73,4 +78,5 @@ def _on_query(self, query: zenoh.Query) -> None: except Exception: # Keep it minimal: don't kill the zenoh callback thread # You can add logging here if desired + print("Error processing query:") return diff --git a/src/ark/node.py b/src/ark/node.py index a1f3103..205f5a9 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -84,6 +84,7 @@ def create_querier(self, channel, timeout=10.0) -> Querier: ) querier.core_registration() self._queriers[channel] = querier + # print session and channelinfo for debugging return querier def create_queryable(self, channel, handler) -> Queryable: diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index e06bbdc..104948c 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -20,6 +20,7 @@ def fetch_grads(self): req = Translation(x=0.0, y=0.0, z=0.0) try: resp_vx = self.grad_vx_querier.query(req) + print(f"Queried grad_vx: {resp_vx.grad}") if isinstance(resp_vx, Value): self.grad_vx.append(resp_vx.grad) except Exception: diff --git a/test/diff_publisher.py b/test/diff_publisher.py index ff102d5..788d56f 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -48,14 +48,17 @@ def __init__(self): "c_x": 0.0, "c_y": 0.0, } - self.create_queryable("grad/v/x", self._on_grad_v_x) - self.create_queryable("grad/v/y", self._on_grad_v_y) - self.create_queryable("grad/m/x", self._on_grad_m_x) - self.create_queryable("grad/m/y", self._on_grad_m_y) - self.create_queryable("grad/c/x", self._on_grad_c_x) - self.create_queryable("grad/c/y", self._on_grad_c_y) + # declare and store all the queryables for gradients + self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) + self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) + self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) + self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) + self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) + self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) + def _on_grad_v_x(self, _req): + print(f"Received query for grad_v_x, latest") return Value(val=self.latest["x"], grad=self.latest["v_x"]) def _on_grad_v_y(self, _req): @@ -65,6 +68,7 @@ def _on_grad_m_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["m_x"]) def _on_grad_m_y(self, _req): + print(f"Received query for grad_v_x, latest") return Value(val=self.latest["y"], grad=self.latest["m_y"]) def _on_grad_c_x(self, _req): @@ -89,12 +93,15 @@ def spin(self): self.c.grad.zero_() x.backward(retain_graph=True) self.latest["v_x"] = float(self.v.grad) + # print(f"v grad {self.v.grad.item()}") # self.latest["m_x"] = float(self.m.grad) # self.latest["c_x"] = float(self.c.grad) self.v.grad.zero_() y.backward() self.latest["v_y"] = float(self.v.grad) + # print(f"v grad {self.v.grad.item()}") self.latest["m_y"] = float(self.m.grad) + # print(f"m grad {self.m.grad}") self.latest["c_y"] = float(self.c.grad) self.latest["x"] = float(x.detach()) self.latest["y"] = float(y.detach()) From 5d05ab3206b1fc4189800da920136ac327f7381c Mon Sep 17 00:00:00 2001 From: kamiradi Date: Fri, 13 Feb 2026 14:16:55 +0000 Subject: [PATCH 6/8] simple gradient experiment --- src/ark/comm/queriable.py | 2 +- src/ark/comm/querier.py | 10 +++- src/ark/node.py | 6 ++- test/ad_plotter_sub.py | 96 ++++++++++++++++++++++++++++++++++++--- test/common.py | 18 +++++++- test/common_example.py | 83 +++++++++++++++++++++++++++++++++ test/diff_publisher.py | 71 ++++++++++++++++++++--------- 7 files changed, 252 insertions(+), 34 deletions(-) create mode 100644 test/common_example.py diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py index 06da42f..961c98b 100644 --- a/src/ark/comm/queriable.py +++ b/src/ark/comm/queriable.py @@ -69,7 +69,7 @@ def _on_query(self, query: zenoh.Query) -> None: self._seq_index += 1 resp_env = Envelope.pack(self._node_name, self._clock, resp_msg) - query.reply(resp_env.SerializeToString()) + query.reply(self._channel, resp_env.SerializeToString()) if self._data_collector: self._data_collector.append(req_env.SerializeToString()) diff --git a/src/ark/comm/querier.py b/src/ark/comm/querier.py index c6d4586..c10d216 100644 --- a/src/ark/comm/querier.py +++ b/src/ark/comm/querier.py @@ -11,12 +11,16 @@ def __init__( self, node_name: str, session: zenoh.Session, + query_target, clock, channel: str, data_collector: DataCollector | None, ): super().__init__(node_name, session, clock, channel, data_collector) - self._querier = self._session.declare_querier(self._channel) + self._querier = self._session.declare_querier(self._channel, + target=query_target) + print(f"Declared querier on channel: {self._channel}") + self._query_selector = zenoh.Selector(self._channel) def core_registration(self): print("..todo: register with ark core..") @@ -48,7 +52,9 @@ def query( else: raise TypeError("req must be a protobuf Message or bytes") - replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout) + print(f"Sending query on channel '{self._channel}' with timeout {timeout}s") + replies = self._querier.get(parameters=self._query_selector.parameters, payload=req_env.SerializeToString(), timeout=timeout) + print(f"Received {len(replies)} replies for query on channel '{self._channel}'") for reply in replies: if reply.ok is None: diff --git a/src/ark/node.py b/src/ark/node.py index 205f5a9..24a8df0 100644 --- a/src/ark/node.py +++ b/src/ark/node.py @@ -22,7 +22,8 @@ def __init__( sim: bool = False, collect_data: bool = False, ): - self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + # self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg)) + self._z_cfg = z_cfg self._session = zenoh.open(self._z_cfg) self._env_name = env_name self._node_name = node_name @@ -73,10 +74,11 @@ def create_subscriber(self, channel, callback) -> Subscriber: self._subs[channel] = sub return sub - def create_querier(self, channel, timeout=10.0) -> Querier: + def create_querier(self, channel, target, timeout=10.0) -> Querier: querier = Querier( self._node_name, self._session, + target, self._clock, channel, self._data_collector, diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 104948c..5a67adf 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -3,21 +3,38 @@ import matplotlib.animation as animation from ark.node import BaseNode from ark_msgs import Translation, Value -from common import z_cfg +# from common import connect_cfg, z_cfg +import argparse +import zenoh +import common_example as common + class AutodiffPlotterNode(BaseNode): - def __init__(self): - super().__init__("env", "autodiff_plotter", z_cfg, sim=True) + def __init__(self, cfg, target): + super().__init__("env", "autodiff_plotter", cfg, sim=True) self.pos_x, self.pos_y = [], [] self.grad_vx, self.grad_my = [], [] self.create_subscriber("position", self.on_position) - self.grad_vx_querier = self.create_querier("grad/v/x") - self.grad_my_querier = self.create_querier("grad/m/y") + # self.grad_vx_querier = self.create_querier("grad/v/x", target=target) + # self.grad_my_querier = self.create_querier("grad/m/y", target=target) + self.grad_vx_querier = self._session.declare_querier( + "grad/v/x", + target=target, + timeout=10.0, + ) + self.grad_my_querier = self._session.declare_querier( + "grad/m/y", + target=target, + timeout=10.0 + ) + def on_position(self, msg: Translation): self.pos_x.append(msg.x) self.pos_y.append(msg.y) + def fetch_grads(self): req = Translation(x=0.0, y=0.0, z=0.0) + print("fetching grads") try: resp_vx = self.grad_vx_querier.query(req) print(f"Queried grad_vx: {resp_vx.grad}") @@ -31,8 +48,71 @@ def fetch_grads(self): self.grad_my.append(resp_my.grad) except Exception: pass + + def fetch_grads_exp(self): + try: + resp_vx = self.grad_vx_querier.get() + for resp in resp_vx: + if resp.ok is None: + continue + v_value_str = bytes(resp.ok.payload).decode("utf-8") + v_value = float(v_value_str) + print(f"Queried grad_vx: {v_value}") + self.grad_vx.append(v_value) + except Exception: + pass + try: + resp_my = self.grad_my_querier.get() + for resp in resp_my: + if resp.ok is None: + continue + m_value_str = bytes(resp.ok.payload).decode("utf-8") + m_value = float(m_value_str) + print(f"Queried grad_my: {m_value}") + self.grad_my.append(m_value) + except Exception: + print("Failed to query grad_my") + pass def main(): - node = AutodiffPlotterNode() + parser = argparse.ArgumentParser(description="Autodiff Plotter Node") + common.add_config_arguments(parser) + parser.add_argument( + "--target", + "-t", + dest="target", + choices=["ALL", "BEST_MATCHING", "ALL_COMPLETE", "NONE"], + default="BEST_MATCHING", + type=str, + help="The target queryables of the query.", + ) + parser.add_argument( + "--timeout", + "-o", + dest="timeout", + default=10.0, + type=float, + help="The query timeout", + ) + parser.add_argument( + "--iter", dest="iter", type=int, help="How many gets to perform" + ) + parser.add_argument( + "--add-matching-listener", + default=False, + action="store_true", + help="Add matching listener", + ) + + args = parser.parse_args() + conf = common.get_config_from_args(args) + + target = { + "ALL": zenoh.QueryTarget.ALL, + "BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING, + "ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE, + }.get(args.target) + + node = AutodiffPlotterNode(conf, target) threading.Thread(target=node.spin, daemon=True).start() fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) ax_pos.set_title("Position (Translation)") @@ -45,11 +125,13 @@ def main(): ax_grad.set_title("Gradients") ax_grad.set_xlabel("t") ax_grad.set_ylabel("grad") + ax_grad.set_xlim(-5, 5) + ax_grad.set_ylim(-5, 5) (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") ax_grad.legend() def update(frame): - node.fetch_grads() + node.fetch_grads_exp() line_pos.set_data(node.pos_x, node.pos_y) line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) diff --git a/test/common.py b/test/common.py index b26b213..0a5d7bc 100644 --- a/test/common.py +++ b/test/common.py @@ -1 +1,17 @@ -z_cfg = {"mode": "peer", "connect": {"endpoints": ["udp/127.0.0.1:7447"]}} +listen_cfg = { + "mode": "peer", + "listen": { + "endpoints": ["tcp/0.0.0.0:7447"]}, +} +connect_cfg = { + "mode": "peer", + "connect": { + "endpoints": ["tcp/127.0.0.1:7447"] + } +} +z_cfg = { + "mode": "peer", + # "connect": { + # "endpoints":["udp/127.0.0.1:7447"] + # } +} diff --git a/test/common_example.py b/test/common_example.py new file mode 100644 index 0000000..0c1eea3 --- /dev/null +++ b/test/common_example.py @@ -0,0 +1,83 @@ +import argparse +import json + +import zenoh + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--mode", + "-m", + dest="mode", + choices=["peer", "client"], + type=str, + help="The zenoh session mode.", + ) + parser.add_argument( + "--connect", + "-e", + dest="connect", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to connect to.", + ) + parser.add_argument( + "--listen", + "-l", + dest="listen", + metavar="ENDPOINT", + action="append", + type=str, + help="Endpoints to listen on.", + ) + parser.add_argument( + "--config", + "-c", + dest="config", + metavar="FILE", + type=str, + help="A configuration file.", + ) + parser.add_argument( + "--no-multicast-scouting", + dest="no_multicast_scouting", + default=False, + action="store_true", + help="Disable multicast scouting.", + ) + parser.add_argument( + "--cfg", + dest="cfg", + metavar="CFG", + default=[], + action="append", + type=str, + help="Allows arbitrary configuration changes as column-separated KEY:VALUE pairs. Where KEY must be a valid config path and VALUE must be a valid JSON5 string that can be deserialized to the expected type for the KEY field. Example: --cfg='transport/unicast/max_links:2'.", + ) + + +def get_config_from_args(args) -> zenoh.Config: + conf = ( + zenoh.Config.from_file(args.config) + if args.config is not None + else zenoh.Config() + ) + if args.mode is not None: + conf.insert_json5("mode", json.dumps(args.mode)) + if args.connect is not None: + conf.insert_json5("connect/endpoints", json.dumps(args.connect)) + if args.listen is not None: + conf.insert_json5("listen/endpoints", json.dumps(args.listen)) + if args.no_multicast_scouting: + conf.insert_json5("scouting/multicast/enabled", json.dumps(False)) + + for c in args.cfg: + try: + [key, value] = c.split(":", 1) + except: + print(f"`--cfg` argument: expected KEY:VALUE pair, got {c}") + raise + conf.insert_json5(key, value) + + return conf diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 788d56f..d1dcf1a 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -2,7 +2,9 @@ import time from ark.node import BaseNode from ark_msgs import Translation, Value -from common import z_cfg +import argparse +# from common import listen_cfg, z_cfg +import common_example as common import torch # Lissajous parameters @@ -13,7 +15,7 @@ DT = 1.0 / HZ class LissajousPublisherNode(BaseNode): def __init__(self): - super().__init__("env", "diff_pub", z_cfg, sim=True) + super().__init__("env", "diff_pub", listen_cfg, sim=True) self.pos_pub = self.create_publisher("position") self.vel_pub = self.create_publisher("velocity") self.rate = self.create_rate(HZ) @@ -31,8 +33,8 @@ def spin(self): class LinePublisherNode(BaseNode): - def __init__(self): - super().__init__("env", "line_pub", z_cfg, sim=True) + def __init__(self, cfg): + super().__init__("env", "line_pub", cfg, sim=True) self.pos_pub = self.create_publisher("position") self.rate = self.create_rate(HZ) self.v = torch.tensor(1.0, requires_grad=True) @@ -49,17 +51,29 @@ def __init__(self): "c_y": 0.0, } # declare and store all the queryables for gradients - self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) - self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) - self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) - self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) - self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) - self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) + # self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) + # self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) + # self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) + # self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) + # self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) + # self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) + self.grad_v_queryable = self._session.declare_queryable("grad/v/x", + self._on_grad_v_x, + complete=False) + self.grad_m_queryable = self._session.declare_queryable("grad/m/y", + self._on_grad_m_y, + complete=False) + # def _on_grad_v_x(self, _req): + # print(f"Received query for grad_v_x, latest") + # return Value(val=self.latest["x"], grad=self.latest["v_x"]) def _on_grad_v_x(self, _req): - print(f"Received query for grad_v_x, latest") - return Value(val=self.latest["x"], grad=self.latest["v_x"]) + v_value = self.latest["v_x"] + v_value_str = str(v_value) + payload = v_value_str.encode("utf-8") + _req.reply("grad/v/x", payload) + pass def _on_grad_v_y(self, _req): return Value(val=self.latest["y"], grad=self.latest["v_y"]) @@ -68,8 +82,11 @@ def _on_grad_m_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["m_x"]) def _on_grad_m_y(self, _req): - print(f"Received query for grad_v_x, latest") - return Value(val=self.latest["y"], grad=self.latest["m_y"]) + m_value = self.latest["m_y"] + m_value_str = str(m_value) + payload = m_value_str.encode("utf-8") + _req.reply("grad/m/y", payload) + pass def _on_grad_c_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["c_x"]) @@ -93,25 +110,37 @@ def spin(self): self.c.grad.zero_() x.backward(retain_graph=True) self.latest["v_x"] = float(self.v.grad) - # print(f"v grad {self.v.grad.item()}") - # self.latest["m_x"] = float(self.m.grad) - # self.latest["c_x"] = float(self.c.grad) + print(f"v grad {self.v.grad.item()}") self.v.grad.zero_() y.backward() self.latest["v_y"] = float(self.v.grad) - # print(f"v grad {self.v.grad.item()}") self.latest["m_y"] = float(self.m.grad) - # print(f"m grad {self.m.grad}") + print(f"m grad {self.m.grad}") self.latest["c_y"] = float(self.c.grad) self.latest["x"] = float(x.detach()) self.latest["y"] = float(y.detach()) + # with self.grad_v_queryable.recv() as query: + # print(f"Received query for grad_v_x, latest") t += DT self.rate.sleep() if __name__ == "__main__": try: - # node = LissajousPublisherNode() - node = LinePublisherNode() + parser = argparse.ArgumentParser( + prog="z_queryable", description="zenoh queryable example" + ) + common.add_config_arguments(parser) + parser.add_argument( + "--complete", + dest="complete", + default=False, + action="store_true", + help="Declare the queryable as complete w.r.t. the key expression.", + ) + args = parser.parse_args() + conf = common.get_config_from_args(args) + + node = LinePublisherNode(conf) node.spin() except KeyboardInterrupt: print("Shutting down diff publisher.") From 0537c41fea7bd4dbd7440789c492a74a75edfc71 Mon Sep 17 00:00:00 2001 From: kamiradi Date: Fri, 13 Feb 2026 18:12:54 +0000 Subject: [PATCH 7/8] gradient querying working with ark querier and queriablew --- src/ark/comm/queriable.py | 20 +++++++++---- src/ark/comm/querier.py | 17 +++++------ test/ad_plotter_sub.py | 62 +++++++++++++-------------------------- test/diff_publisher.py | 50 ++++++++++++------------------- 4 files changed, 61 insertions(+), 88 deletions(-) diff --git a/src/ark/comm/queriable.py b/src/ark/comm/queriable.py index 961c98b..2e55a75 100644 --- a/src/ark/comm/queriable.py +++ b/src/ark/comm/queriable.py @@ -21,7 +21,9 @@ def __init__( ): super().__init__(node_name, session, clock, channel, data_collector) self._handler = handler - self._queryable = self._session.declare_queryable(self._channel, self._on_query) + self._queryable = self._session.declare_queryable(self._channel, + self._on_query, + complete=False) print(f"Declared queryable on channel: {self._channel}") def core_registration(self): @@ -29,7 +31,6 @@ def core_registration(self): def _on_query(self, query: zenoh.Query) -> None: # If we were closed, ignore queries - print("Received query, processing...") if not self._active: print("Received query on closed Queryable, ignoring") return @@ -37,9 +38,9 @@ def _on_query(self, query: zenoh.Query) -> None: try: # Zenoh query may or may not include a payload. # For your use-case, the request is always in query.value (bytes) - print("Parsing query") - raw = bytes(query.value) if query.value is not None else b"" + raw = bytes(query.payload) if query.payload is not None else b"" if not raw: + print("Received query with no payload, ignoring") return # nothing to do req_env = Envelope() @@ -50,6 +51,7 @@ def _on_query(self, query: zenoh.Query) -> None: req_type = msgs.get(req_env.msg_type) if req_type is None: # Unknown message type: ignore (or reply error later) + print(f"Unknown message type '{req_env.msg_type}' in query, ignoring") return req_msg = req_type() @@ -65,11 +67,13 @@ def _on_query(self, query: zenoh.Query) -> None: resp_env.sent_seq_index = self._seq_index resp_env.src_node_name = self._node_name resp_env.channel = self._channel + resp_env.msg_type = resp_msg.DESCRIPTOR.full_name + resp_env.payload = resp_msg.SerializeToString() self._seq_index += 1 - resp_env = Envelope.pack(self._node_name, self._clock, resp_msg) - query.reply(self._channel, resp_env.SerializeToString()) + with query: + query.reply(query.key_expr, resp_env.SerializeToString()) if self._data_collector: self._data_collector.append(req_env.SerializeToString()) @@ -79,4 +83,8 @@ def _on_query(self, query: zenoh.Query) -> None: # Keep it minimal: don't kill the zenoh callback thread # You can add logging here if desired print("Error processing query:") + # write the traceback to stdout for debugging + import traceback + traceback.print_exc() + return diff --git a/src/ark/comm/querier.py b/src/ark/comm/querier.py index c10d216..88a6e8a 100644 --- a/src/ark/comm/querier.py +++ b/src/ark/comm/querier.py @@ -3,6 +3,7 @@ from google.protobuf.message import Message from ark.data.data_collector import DataCollector from ark.comm.end_point import EndPoint +from ark_msgs.registry import msgs class Querier(EndPoint): @@ -52,20 +53,21 @@ def query( else: raise TypeError("req must be a protobuf Message or bytes") - print(f"Sending query on channel '{self._channel}' with timeout {timeout}s") - replies = self._querier.get(parameters=self._query_selector.parameters, payload=req_env.SerializeToString(), timeout=timeout) - print(f"Received {len(replies)} replies for query on channel '{self._channel}'") + replies = self._querier.get(payload=req_env.SerializeToString()) for reply in replies: if reply.ok is None: continue resp_env = Envelope() - resp_env.ParseFromString(bytes(reply.ok)) + resp_env.ParseFromString(bytes(reply.ok.payload)) resp_env.dst_node_name = self._node_name resp_env.recv_timestamp = self._clock.now() - resp = resp_env.extract_message() + try: + resp = resp_env.extract_message() + except Exception as e: + continue self._seq_index += 1 @@ -75,11 +77,6 @@ def query( return resp - else: - raise TimeoutError( - f"No OK reply received for query on '{self._channel}' within {timeout}s" - ) - def close(self): super().close() self._querier.undeclare() diff --git a/test/ad_plotter_sub.py b/test/ad_plotter_sub.py index 5a67adf..314f23a 100644 --- a/test/ad_plotter_sub.py +++ b/test/ad_plotter_sub.py @@ -3,6 +3,7 @@ import matplotlib.animation as animation from ark.node import BaseNode from ark_msgs import Translation, Value + # from common import connect_cfg, z_cfg import argparse import zenoh @@ -15,65 +16,36 @@ def __init__(self, cfg, target): self.pos_x, self.pos_y = [], [] self.grad_vx, self.grad_my = [], [] self.create_subscriber("position", self.on_position) - # self.grad_vx_querier = self.create_querier("grad/v/x", target=target) - # self.grad_my_querier = self.create_querier("grad/m/y", target=target) - self.grad_vx_querier = self._session.declare_querier( - "grad/v/x", - target=target, - timeout=10.0, - ) - self.grad_my_querier = self._session.declare_querier( - "grad/m/y", - target=target, - timeout=10.0 - ) + self.grad_vx_querier = self.create_querier("grad/v/x", target=target) + self.grad_my_querier = self.create_querier("grad/m/y", target=target) def on_position(self, msg: Translation): self.pos_x.append(msg.x) self.pos_y.append(msg.y) def fetch_grads(self): - req = Translation(x=0.0, y=0.0, z=0.0) - print("fetching grads") + req = Value() try: resp_vx = self.grad_vx_querier.query(req) - print(f"Queried grad_vx: {resp_vx.grad}") if isinstance(resp_vx, Value): + print(f"Received grad_vx: {resp_vx.grad}") self.grad_vx.append(resp_vx.grad) except Exception: pass try: resp_my = self.grad_my_querier.query(req) if isinstance(resp_my, Value): + print(f"Received grad_my: {resp_my.grad}") self.grad_my.append(resp_my.grad) except Exception: pass - def fetch_grads_exp(self): - try: - resp_vx = self.grad_vx_querier.get() - for resp in resp_vx: - if resp.ok is None: - continue - v_value_str = bytes(resp.ok.payload).decode("utf-8") - v_value = float(v_value_str) - print(f"Queried grad_vx: {v_value}") - self.grad_vx.append(v_value) - except Exception: - pass - try: - resp_my = self.grad_my_querier.get() - for resp in resp_my: - if resp.ok is None: - continue - m_value_str = bytes(resp.ok.payload).decode("utf-8") - m_value = float(m_value_str) - print(f"Queried grad_my: {m_value}") - self.grad_my.append(m_value) - except Exception: - print("Failed to query grad_my") - pass + def main(): + + # These are a few zenoh config related arguments that were taken from the + # examples, keeping them there until we have a better way to manage configs + # across examples parser = argparse.ArgumentParser(description="Autodiff Plotter Node") common.add_config_arguments(parser) parser.add_argument( @@ -102,18 +74,22 @@ def main(): action="store_true", help="Add matching listener", ) - + args = parser.parse_args() conf = common.get_config_from_args(args) + # These were required for the querier and queryable to find each other. target = { "ALL": zenoh.QueryTarget.ALL, "BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING, "ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE, }.get(args.target) + # Main subcription and querying loop node = AutodiffPlotterNode(conf, target) threading.Thread(target=node.spin, daemon=True).start() + + # Plotting trajectory and gradients fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5)) ax_pos.set_title("Position (Translation)") ax_pos.set_xlabel("x") @@ -130,15 +106,19 @@ def main(): (line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv") (line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm") ax_grad.legend() + def update(frame): - node.fetch_grads_exp() + node.fetch_grads() line_pos.set_data(node.pos_x, node.pos_y) line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx) line_grad_my.set_data(range(len(node.grad_my)), node.grad_my) return line_pos, line_grad_vx, line_grad_my + ani = animation.FuncAnimation(fig, update, interval=50, blit=True) plt.tight_layout() plt.show() node.close() + + if __name__ == "__main__": main() diff --git a/test/diff_publisher.py b/test/diff_publisher.py index d1dcf1a..153ff92 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -3,6 +3,7 @@ from ark.node import BaseNode from ark_msgs import Translation, Value import argparse + # from common import listen_cfg, z_cfg import common_example as common import torch @@ -13,12 +14,15 @@ delta = math.pi / 2 HZ = 50 DT = 1.0 / HZ + + class LissajousPublisherNode(BaseNode): def __init__(self): super().__init__("env", "diff_pub", listen_cfg, sim=True) self.pos_pub = self.create_publisher("position") self.vel_pub = self.create_publisher("velocity") self.rate = self.create_rate(HZ) + def spin(self): t = 0.0 while True: @@ -31,6 +35,7 @@ def spin(self): t += DT self.rate.sleep() + class LinePublisherNode(BaseNode): def __init__(self, cfg): @@ -51,29 +56,15 @@ def __init__(self, cfg): "c_y": 0.0, } # declare and store all the queryables for gradients - # self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) - # self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) - # self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) - # self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) - # self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) - # self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) - self.grad_v_queryable = self._session.declare_queryable("grad/v/x", - self._on_grad_v_x, - complete=False) - self.grad_m_queryable = self._session.declare_queryable("grad/m/y", - self._on_grad_m_y, - complete=False) - - # def _on_grad_v_x(self, _req): - # print(f"Received query for grad_v_x, latest") - # return Value(val=self.latest["x"], grad=self.latest["v_x"]) + self.grad_v_x_q = self.create_queryable("grad/v/x", self._on_grad_v_x) + self.grad_v_y_q = self.create_queryable("grad/v/y", self._on_grad_v_y) + self.grad_m_x_q = self.create_queryable("grad/m/x", self._on_grad_m_x) + self.grad_m_y_q = self.create_queryable("grad/m/y", self._on_grad_m_y) + self.grad_c_x_q = self.create_queryable("grad/c/x", self._on_grad_c_x) + self.grad_c_y_q = self.create_queryable("grad/c/y", self._on_grad_c_y) def _on_grad_v_x(self, _req): - v_value = self.latest["v_x"] - v_value_str = str(v_value) - payload = v_value_str.encode("utf-8") - _req.reply("grad/v/x", payload) - pass + return Value(val=self.latest["x"], grad=self.latest["v_x"]) def _on_grad_v_y(self, _req): return Value(val=self.latest["y"], grad=self.latest["v_y"]) @@ -82,10 +73,7 @@ def _on_grad_m_x(self, _req): return Value(val=self.latest["x"], grad=self.latest["m_x"]) def _on_grad_m_y(self, _req): - m_value = self.latest["m_y"] - m_value_str = str(m_value) - payload = m_value_str.encode("utf-8") - _req.reply("grad/m/y", payload) + return Value(val=self.latest["y"], grad=self.latest["m_y"]) pass def _on_grad_c_x(self, _req): @@ -100,8 +88,9 @@ def spin(self): t_val = torch.tensor(t, requires_grad=False) x = self.v * t_val y = self.m * x + self.c - self.pos_pub.publish(Translation(x=float(x.detach()), - y=float(y.detach()), z=0.0)) + self.pos_pub.publish( + Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) + ) if self.v.grad is not None: self.v.grad.zero_() if self.m.grad is not None: @@ -110,20 +99,19 @@ def spin(self): self.c.grad.zero_() x.backward(retain_graph=True) self.latest["v_x"] = float(self.v.grad) - print(f"v grad {self.v.grad.item()}") + print(f"Grad v_x {self.v.grad.item()}") self.v.grad.zero_() y.backward() self.latest["v_y"] = float(self.v.grad) self.latest["m_y"] = float(self.m.grad) - print(f"m grad {self.m.grad}") + print(f"Grad m_y {self.m.grad}") self.latest["c_y"] = float(self.c.grad) self.latest["x"] = float(x.detach()) self.latest["y"] = float(y.detach()) - # with self.grad_v_queryable.recv() as query: - # print(f"Received query for grad_v_x, latest") t += DT self.rate.sleep() + if __name__ == "__main__": try: parser = argparse.ArgumentParser( From d76495e5732b19cc9ff5fbd4a398913edc5eeb6e Mon Sep 17 00:00:00 2001 From: kamiradi Date: Fri, 13 Feb 2026 19:11:52 +0000 Subject: [PATCH 8/8] adds readme to run gradient experiment --- test/diff_publisher.py | 7 +++++++ test/gradient_exp.md | 47 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 test/gradient_exp.md diff --git a/test/diff_publisher.py b/test/diff_publisher.py index 153ff92..aa74ee8 100644 --- a/test/diff_publisher.py +++ b/test/diff_publisher.py @@ -86,11 +86,18 @@ def spin(self): t = 0.0 while True: t_val = torch.tensor(t, requires_grad=False) + + # Computation graph + # line equation: y = m * x + c, where x = v * t x = self.v * t_val y = self.m * x + self.c + + # publish position self.pos_pub.publish( Translation(x=float(x.detach()), y=float(y.detach()), z=0.0) ) + + # compute gradients if self.v.grad is not None: self.v.grad.zero_() if self.m.grad is not None: diff --git a/test/gradient_exp.md b/test/gradient_exp.md new file mode 100644 index 0000000..54387b0 --- /dev/null +++ b/test/gradient_exp.md @@ -0,0 +1,47 @@ +# Gradient Experiment + +Demonstrates differentiable simulation using ark framework. A `LinePublisherNode` publishes position on a line (`y = m*x + c`, `x = v*t`) along with autograd gradients (dx/dv, dy/dm, dy/dc), and an `AutodiffPlotterNode` subscribes to position and queries gradients in real time. + +## Prerequisites + +- Install ark framework and dependencies (`zenoh`, `torch`, `matplotlib`, `ark_msgs`) +- Run all commands from the `test/` directory + +## Running the Experiment + +Open three separate terminals. All commands are run from the `test/` directory. + +### Shell 1 — Sim Clock + +Drives simulated time for all sim-enabled nodes. + +```bash +cd test +python simstep.py +``` + +### Shell 2 — Diff Publisher + +Publishes position (`Translation`) and serves gradient queryables (`grad/v/x`, `grad/m/y`, etc.). + +```bash +cd test +python diff_publisher.py +``` + +### Shell 3 — Autodiff Plotter + +Subscribes to position and queries gradients, then plots both in real time. + +```bash +cd test +python ad_plotter_sub.py +``` + +## What to Expect + +- **Shell 1** prints the simulated time advancing each tick. +- **Shell 2** prints computed gradients (`Grad v_x`, `Grad m_y`) each step. +- **Shell 3** opens a matplotlib window with two plots: + - **Left**: Position trajectory (x vs y). + - **Right**: Gradients over time (dx/dv in green, dy/dm in magenta).