Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/ark/comm/queriable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,37 @@ def __init__(
):
super().__init__(node_name, session, clock, channel, data_collector)
self._handler = handler
self._queryable = self._session.declare_queryable(self._channel, self._on_query)
self._queryable = self._session.declare_queryable(self._channel,
self._on_query,
complete=False)
print(f"Declared queryable on channel: {self._channel}")

def core_registration(self):
print("..todo: register with ark core..")

def _on_query(self, query: zenoh.Query) -> None:
# If we were closed, ignore queries
if not self._active:
print("Received query on closed Queryable, ignoring")
return

try:
# Zenoh query may or may not include a payload.
# For your use-case, the request is always in query.value (bytes)
raw = bytes(query.value) if query.value is not None else b""
raw = bytes(query.payload) if query.payload is not None else b""
if not raw:
print("Received query with no payload, ignoring")
return # nothing to do

req_env = Envelope()
req_env.ParseFromString(raw)

# Decode request protobuf
req_type = msgs.get(req_env.payload_msg_type)
# req_type = msgs.get(req_env.payload_msg_type)
req_type = msgs.get(req_env.msg_type)
if req_type is None:
# Unknown message type: ignore (or reply error later)
print(f"Unknown message type '{req_env.msg_type}' in query, ignoring")
return

req_msg = req_type()
Expand All @@ -60,11 +67,13 @@ def _on_query(self, query: zenoh.Query) -> None:
resp_env.sent_seq_index = self._seq_index
resp_env.src_node_name = self._node_name
resp_env.channel = self._channel
resp_env.msg_type = resp_msg.DESCRIPTOR.full_name
resp_env.payload = resp_msg.SerializeToString()

self._seq_index += 1

resp_env = Envelope.pack(self._node_name, self._clock, resp_msg)
query.reply(resp_env.SerializeToString())
with query:
query.reply(query.key_expr, resp_env.SerializeToString())

if self._data_collector:
self._data_collector.append(req_env.SerializeToString())
Expand All @@ -73,4 +82,9 @@ def _on_query(self, query: zenoh.Query) -> None:
except Exception:
# Keep it minimal: don't kill the zenoh callback thread
# You can add logging here if desired
print("Error processing query:")
# write the traceback to stdout for debugging
import traceback
traceback.print_exc()

return
21 changes: 12 additions & 9 deletions src/ark/comm/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from google.protobuf.message import Message
from ark.data.data_collector import DataCollector
from ark.comm.end_point import EndPoint
from ark_msgs.registry import msgs


class Querier(EndPoint):
Expand All @@ -11,12 +12,16 @@ def __init__(
self,
node_name: str,
session: zenoh.Session,
query_target,
clock,
channel: str,
data_collector: DataCollector | None,
):
super().__init__(node_name, session, clock, channel, data_collector)
self._querier = self._session.declare_querier(self._channel)
self._querier = self._session.declare_querier(self._channel,
target=query_target)
print(f"Declared querier on channel: {self._channel}")
self._query_selector = zenoh.Selector(self._channel)

def core_registration(self):
print("..todo: register with ark core..")
Expand Down Expand Up @@ -48,18 +53,21 @@ def query(
else:
raise TypeError("req must be a protobuf Message or bytes")

replies = self._querier.get(value=req_env.SerializeToString(), timeout=timeout)
replies = self._querier.get(payload=req_env.SerializeToString())

for reply in replies:
if reply.ok is None:
continue

resp_env = Envelope()
resp_env.ParseFromString(bytes(reply.ok))
resp_env.ParseFromString(bytes(reply.ok.payload))
resp_env.dst_node_name = self._node_name
resp_env.recv_timestamp = self._clock.now()

resp = resp_env.extract_message()
try:
resp = resp_env.extract_message()
except Exception as e:
continue

self._seq_index += 1

Expand All @@ -69,11 +77,6 @@ def query(

return resp

else:
raise TimeoutError(
f"No OK reply received for query on '{self._channel}' within {timeout}s"
)

def close(self):
super().close()
self._querier.undeclare()
9 changes: 6 additions & 3 deletions src/ark/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(
sim: bool = False,
collect_data: bool = False,
):
self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg))
# self._z_cfg = zenoh.Config.from_json5(json.dumps(z_cfg))
self._z_cfg = z_cfg
self._session = zenoh.open(self._z_cfg)
self._env_name = env_name
self._node_name = node_name
Expand Down Expand Up @@ -73,17 +74,19 @@ def create_subscriber(self, channel, callback) -> Subscriber:
self._subs[channel] = sub
return sub

def create_querier(self, channel, timeout=10.0) -> Querier:
def create_querier(self, channel, target, timeout=10.0) -> Querier:
querier = Querier(
self._node_name,
self._session,
target,
self._clock,
channel,
self._data_collector,
timeout,
# timeout,
)
querier.core_registration()
self._queriers[channel] = querier
# print session and channelinfo for debugging
return querier

def create_queryable(self, channel, handler) -> Queryable:
Expand Down
124 changes: 124 additions & 0 deletions test/ad_plotter_sub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import threading
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from ark.node import BaseNode
from ark_msgs import Translation, Value

# from common import connect_cfg, z_cfg
import argparse
import zenoh
import common_example as common


class AutodiffPlotterNode(BaseNode):
def __init__(self, cfg, target):
super().__init__("env", "autodiff_plotter", cfg, sim=True)
self.pos_x, self.pos_y = [], []
self.grad_vx, self.grad_my = [], []
self.create_subscriber("position", self.on_position)
self.grad_vx_querier = self.create_querier("grad/v/x", target=target)
self.grad_my_querier = self.create_querier("grad/m/y", target=target)

def on_position(self, msg: Translation):
self.pos_x.append(msg.x)
self.pos_y.append(msg.y)

def fetch_grads(self):
req = Value()
try:
resp_vx = self.grad_vx_querier.query(req)
if isinstance(resp_vx, Value):
print(f"Received grad_vx: {resp_vx.grad}")
self.grad_vx.append(resp_vx.grad)
except Exception:
pass
try:
resp_my = self.grad_my_querier.query(req)
if isinstance(resp_my, Value):
print(f"Received grad_my: {resp_my.grad}")
self.grad_my.append(resp_my.grad)
except Exception:
pass


def main():

# These are a few zenoh config related arguments that were taken from the
# examples, keeping them there until we have a better way to manage configs
# across examples
parser = argparse.ArgumentParser(description="Autodiff Plotter Node")
common.add_config_arguments(parser)
parser.add_argument(
"--target",
"-t",
dest="target",
choices=["ALL", "BEST_MATCHING", "ALL_COMPLETE", "NONE"],
default="BEST_MATCHING",
type=str,
help="The target queryables of the query.",
)
parser.add_argument(
"--timeout",
"-o",
dest="timeout",
default=10.0,
type=float,
help="The query timeout",
)
parser.add_argument(
"--iter", dest="iter", type=int, help="How many gets to perform"
)
parser.add_argument(
"--add-matching-listener",
default=False,
action="store_true",
help="Add matching listener",
)

args = parser.parse_args()
conf = common.get_config_from_args(args)

# These were required for the querier and queryable to find each other.
target = {
"ALL": zenoh.QueryTarget.ALL,
"BEST_MATCHING": zenoh.QueryTarget.BEST_MATCHING,
"ALL_COMPLETE": zenoh.QueryTarget.ALL_COMPLETE,
}.get(args.target)

# Main subcription and querying loop
node = AutodiffPlotterNode(conf, target)
threading.Thread(target=node.spin, daemon=True).start()

# Plotting trajectory and gradients
fig, (ax_pos, ax_grad) = plt.subplots(1, 2, figsize=(12, 5))
ax_pos.set_title("Position (Translation)")
ax_pos.set_xlabel("x")
ax_pos.set_ylabel("y")
ax_pos.set_xlim(-5, 5)
ax_pos.set_ylim(-5, 5)
ax_pos.set_aspect("equal")
(line_pos,) = ax_pos.plot([], [], "b-")
ax_grad.set_title("Gradients")
ax_grad.set_xlabel("t")
ax_grad.set_ylabel("grad")
ax_grad.set_xlim(-5, 5)
ax_grad.set_ylim(-5, 5)
(line_grad_vx,) = ax_grad.plot([], [], "g-", label="dx/dv")
(line_grad_my,) = ax_grad.plot([], [], "m-", label="dy/dm")
ax_grad.legend()

def update(frame):
node.fetch_grads()
line_pos.set_data(node.pos_x, node.pos_y)
line_grad_vx.set_data(range(len(node.grad_vx)), node.grad_vx)
line_grad_my.set_data(range(len(node.grad_my)), node.grad_my)
return line_pos, line_grad_vx, line_grad_my

ani = animation.FuncAnimation(fig, update, interval=50, blit=True)
plt.tight_layout()
plt.show()
node.close()


if __name__ == "__main__":
main()
54 changes: 54 additions & 0 deletions test/autodiff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np

class Value:
def __init__(self, val, parents=(), backward=None, name=None):
self.val = np.asarray(val, dtype=float)
self.grad = np.zeros_like(self.val)
self._prev = parents
self._backward = backward or (lambda: None)
self.name = name
def backward(self, grad=None):
if grad is None:
grad = np.ones_like(self.val)
self.grad = self.grad + grad
topo = []
visited = set()
def build(v):
if v not in visited:
visited.add(v)
for p in v._prev:
build(p)
topo.append(v)
build(self)
for v in reversed(topo):
v._backward()
def __add__(self, other):
out = Value(self.val + other.val, parents=(self, other))
def _backward():
self.grad = self.grad + out.grad
other.grad = other.grad + out.grad
out._backward = _backward
return out
def __sub__(self, other):
out = Value(self.val - other.val, parents=(self, other))
def _backward():
self.grad = self.grad + out.grad
other.grad = other.grad - out.grad
out._backward = _backward
return out
def __mul__(self, other):
out = Value(self.val * other.val, parents=(self, other))
def _backward():
self.grad = self.grad + other.val * out.grad
other.grad = other.grad + self.val * out.grad
out._backward = _backward
return out
def __neg__(self):
out = Value(-self.val, parents=(self,))
def _backward():
self.grad = self.grad - out.grad
out._backward = _backward
return out
def clear_grads(params):
for p in params:
p.grad = np.zeros_like(p.val)
18 changes: 17 additions & 1 deletion test/common.py
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
z_cfg = {"mode": "peer", "connect": {"endpoints": ["udp/127.0.0.1:7447"]}}
listen_cfg = {
"mode": "peer",
"listen": {
"endpoints": ["tcp/0.0.0.0:7447"]},
}
connect_cfg = {
"mode": "peer",
"connect": {
"endpoints": ["tcp/127.0.0.1:7447"]
}
}
z_cfg = {
"mode": "peer",
# "connect": {
# "endpoints":["udp/127.0.0.1:7447"]
# }
}
Loading