Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/ypywidgets/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import comm
from pycrdt import (
Awareness,
Doc,
Text,
TransactionEvent,
Expand All @@ -10,6 +11,7 @@
create_sync_message,
create_update_message,
handle_sync_message,
read_message,
)

from .widget import Widget
Expand Down Expand Up @@ -48,10 +50,15 @@ def __init__(
) -> None:
self._ydoc = ydoc
self._comm = comm
self._awareness = Awareness(ydoc)
msg = create_sync_message(ydoc)
self._comm.send(buffers=[msg])
self._comm.on_msg(self._receive)

@property
def awareness(self) -> Awareness:
return self._awareness

def _receive(self, msg):
message = bytes(msg["buffers"][0])
match message[0]:
Expand All @@ -61,6 +68,10 @@ def _receive(self, msg):
self._comm.send(buffers=[reply])
if message[1] == YSyncMessageType.SYNC_STEP2:
self._ydoc.observe(self._send)
case YMessageType.AWARENESS:
# Same as pycrdt.websocket.yroom: strip Y message kind, decode body.
update = read_message(message[1:])
self._awareness.apply_awareness_update(update, None)

def _send(self, event: TransactionEvent):
update = event.update
Expand All @@ -86,7 +97,11 @@ def __init__(
create_ydoc=not ydoc,
)
self._comm = create_widget_comm(comm_data, comm_metadata, comm_id)
CommProvider(self.ydoc, self._comm)
self._comm_provider = CommProvider(self.ydoc, self._comm)

@property
def awareness(self) -> Awareness:
return self._comm_provider.awareness

def _repr_mimebundle_(self, *args, **kwargs): # pragma: nocover
plaintext = repr(self)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_comm_awareness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

import pytest
from pycrdt import Awareness, Doc, YMessageType, create_awareness_message
from ypywidgets.comm import CommWidget

pytestmark = pytest.mark.anyio


async def test_comm_provider_applies_awareness_message(synced_widgets, context):
async with context:
local_widget = await synced_widgets.get_local_widget()
remote_awareness = Awareness(Doc())
remote_awareness.set_local_state({"role": "remote"})
payload = remote_awareness.encode_awareness_update([remote_awareness.client_id])
message = create_awareness_message(payload)

assert message[0] == YMessageType.AWARENESS

local_widget._comm_provider._receive({"buffers": [message]})

remote_state = local_widget.awareness.states.get(remote_awareness.client_id)
assert remote_state is not None
assert remote_state.get("role") == "remote"


async def test_comm_widget_exposes_provider_awareness():
widget = CommWidget()
assert widget.awareness is widget._comm_provider.awareness


async def test_comm_widget_awareness_observe_and_unobserve():
widget = CommWidget()

events: list[str] = []
sub_id = widget.awareness.observe(lambda topic, _: events.append(topic))

widget.awareness.set_local_state({"ping": 1})
assert events

widget.awareness.unobserve(sub_id)
events.clear()
widget.awareness.set_local_state({"ping": 2})
assert events == []