Skip to content

Commit 3e0beab

Browse files
authored
Merge pull request #493 from dwskoog/map_async
Make async functions mappable
2 parents 8c73290 + ea57a59 commit 3e0beab

File tree

5 files changed

+156
-2
lines changed

5 files changed

+156
-2
lines changed

docs/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Stream
2222
filter
2323
flatten
2424
map
25+
map_async
2526
partition
2627
rate_limit
2728
scatter

docs/source/async.rst

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ This would also work with async-await syntax in Python 3
7272

7373
.. code-block:: python
7474
75+
import asyncio
7576
from streamz import Stream
76-
from tornado.ioloop import IOLoop
7777
7878
async def f():
7979
source = Stream(asynchronous=True) # tell the stream we're working asynchronously
@@ -82,7 +82,28 @@ This would also work with async-await syntax in Python 3
8282
for x in range(10):
8383
await source.emit(x)
8484
85-
IOLoop().run_sync(f)
85+
asyncio.run(f())
86+
87+
When working asynchronously, we can also map asynchronous functions.
88+
89+
.. code-block:: python
90+
91+
async def increment_async(x):
    """ Return ``x + 1`` after a short asynchronous pause

    The 0.1 second sleep stands in for real awaitable work, making this a
    "long-running" increment function.
    """
    pause_seconds = 0.1
    await asyncio.sleep(pause_seconds)
    incremented = x + 1
    return incremented
98+
99+
async def f_inc():
    # asynchronous=True tells the stream we're working asynchronously
    source = Stream(asynchronous=True)

    # Build the pipeline step by step: async map -> rate limit -> sink
    incremented = source.map_async(increment_async)
    throttled = incremented.rate_limit(0.500)
    throttled.sink(write)

    for value in range(10):
        await source.emit(value)
105+
106+
asyncio.run(f_inc())
86107
87108
88109
Event Loop on a Separate Thread

streamz/core.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,86 @@ def update(self, x, who=None, metadata=None):
718718
return self._emit(result, metadata=metadata)
719719

720720

721+
@Stream.register_api()
class map_async(Stream):
    """ Apply an async function to every element in the stream, preserving order
    even when evaluating multiple inputs in parallel.

    Each incoming element is turned into a job by calling
    ``func(x, *args, **kwargs)``.  Jobs are placed on a bounded FIFO queue and
    a single background worker awaits them in queue order, so results are
    emitted downstream in the order the jobs were enqueued.

    Parameters
    ----------
    func: async callable
        Called as ``func(x, *args, **kwargs)``; must return an awaitable.
    *args :
        The arguments to pass to the function.
    parallelism:
        The maximum number of parallel Tasks for evaluating func, default value is 1.
        This is used as the ``maxsize`` of the internal work queue.
    **kwargs:
        Keyword arguments to pass to func.  ``stream_name`` is removed from
        ``kwargs`` and forwarded to ``Stream.__init__`` instead of ``func``.

    Notes
    -----
    If a job or a downstream emit raises, the exception is logged and
    re-raised inside the background worker, which ends the worker loop.

    Examples
    --------
    >>> async def mult(x, factor=1):
    ...     return factor*x
    >>> async def run():
    ...     source = Stream(asynchronous=True)
    ...     source.map_async(mult, factor=2).sink(print)
    ...     for i in range(5):
    ...         await source.emit(i)
    >>> asyncio.run(run())
    0
    2
    4
    6
    8
    """
    def __init__(self, upstream, func, *args, parallelism=1, **kwargs):
        self.func = func
        stream_name = kwargs.pop('stream_name', None)
        self.kwargs = kwargs
        self.args = args
        # Bounded FIFO of (task, metadata) pairs; its capacity is what
        # throttles how many jobs get started.
        self.work_queue = asyncio.Queue(maxsize=parallelism)
        # Set by the worker each time it dequeues a job, so producers blocked
        # on a full queue can sleep instead of polling the event loop.
        self._slot_freed = asyncio.Event()

        Stream.__init__(self, upstream, stream_name=stream_name, ensure_io_loop=True)
        # The worker needs self.loop, so it is started after Stream.__init__.
        self.work_task = self._create_task(self.work_callback())

    def update(self, x, who=None, metadata=None):
        """ Schedule ``x`` for processing; return the Task that enqueues its job """
        return self._create_task(self._insert_job(x, metadata))

    def _create_task(self, coro):
        """ Return ``coro`` unchanged if it is already a future, else wrap it in a Task """
        # Something that is already a future (e.g. the result of calling a
        # Tornado ``gen.coroutine``) is already scheduled and is returned as is.
        if gen.is_future(coro):
            return coro
        return self.loop.asyncio_loop.create_task(coro)

    async def work_callback(self):
        """ Worker loop: await queued jobs in FIFO order and emit their results """
        while True:
            try:
                task, metadata = await self.work_queue.get()
                self.work_queue.task_done()
                # A queue slot just opened: wake any producer waiting in
                # _wait_for_work_slot.
                self._slot_freed.set()
                result = await task
            except Exception as e:
                logger.exception(e)
                raise
            else:
                results = self._emit(result, metadata=metadata)
                if results:
                    # Wait for downstream to finish with this result before
                    # moving on to the next job.
                    await asyncio.gather(*results)
                self._release_refs(metadata)

    async def _wait_for_work_slot(self):
        """ Suspend until the work queue has room for another job """
        while self.work_queue.full():
            # Clearing and then waiting happen without yielding in between,
            # and the worker can only dequeue while we are suspended, so a
            # wake-up cannot be missed.
            self._slot_freed.clear()
            await self._slot_freed.wait()

    async def _insert_job(self, x, metadata):
        """ Start ``func`` on ``x`` once a slot is free and enqueue the running job """
        try:
            # Wait *before* calling func so a job is not started until the
            # queue can hold it.
            await self._wait_for_work_slot()
            coro = self.func(x, *self.args, **self.kwargs)
            task = self._create_task(coro)
            await self.work_queue.put((task, metadata))
            self._retain_refs(metadata)
        except Exception as e:
            logger.exception(e)
            raise
799+
800+
721801
@Stream.register_api()
722802
class starmap(Stream):
723803
""" Apply a function to every element in the stream, splayed out

streamz/dataframe/tests/test_dataframes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dask.dataframe.utils import assert_eq
99
import numpy as np
1010
import pandas as pd
11+
from flaky import flaky
1112
from tornado import gen
1213

1314
from streamz import Stream
@@ -570,6 +571,7 @@ def test_cumulative_aggregations(op, getter, stream):
570571
assert_eq(pd.concat(L), expected)
571572

572573

574+
@flaky(max_runs=3, min_passes=1)
573575
@gen_test()
574576
def test_gc():
575577
sdf = sd.Random(freq='5ms', interval='100ms')

streamz/tests/test_core.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,56 @@ def add(x=0, y=0):
126126
assert L[0] == 11
127127

128128

129+
@gen_test()
def test_map_async_tornado():
    # Chains a Tornado coroutine and a native coroutine through map_async and
    # drives the stream from a Tornado generator-style test.
    @gen.coroutine
    def add_tor(x=0, y=0):
        return x + y

    async def add_native(x=0, y=0):
        await asyncio.sleep(0.1)
        return x + y

    source = Stream(asynchronous=True)
    L = source.map_async(add_tor, y=1).map_async(add_native, parallelism=2, y=2).buffer(1).sink_to_list()

    start = time()
    yield source.emit(0)
    yield source.emit(1)
    yield source.emit(2)

    def fail_func():
        # Re-assert on timeout so the failure message shows the actual list.
        assert L == [3, 4, 5]

    yield await_for(lambda: L == [3, 4, 5], 1, fail_func=fail_func)
    elapsed = time() - start
    # Each add_native call sleeps 0.1s.  Finishing all three in under 0.2s
    # shows the calls overlapped; a tighter wall-clock bound is flaky on a
    # loaded machine.
    assert 0.096 <= elapsed < 0.2
152+
153+
154+
@pytest.mark.asyncio
async def test_map_async():
    # Chains a Tornado coroutine and a native coroutine through map_async and
    # drives the stream from a native asyncio test.
    @gen.coroutine
    def add_tor(x=0, y=0):
        return x + y

    async def add_native(x=0, y=0):
        await asyncio.sleep(0.1)
        return x + y

    source = Stream(asynchronous=True)
    L = source.map_async(add_tor, y=1).map_async(add_native, parallelism=2, y=2).sink_to_list()

    start = time()
    await source.emit(0)
    await source.emit(1)
    await source.emit(2)

    def fail_func():
        # Re-assert on timeout so the failure message shows the actual list.
        assert L == [3, 4, 5]

    await await_for(lambda: L == [3, 4, 5], 1, fail_func=fail_func)
    elapsed = time() - start
    # Each add_native call sleeps 0.1s.  Finishing all three in under 0.2s
    # shows the calls overlapped; a tighter wall-clock bound is flaky on a
    # loaded machine.
    assert 0.096 <= elapsed < 0.2
177+
178+
129179
def test_map_args():
130180
source = Stream()
131181
L = source.map(operator.add, 10).sink_to_list()

0 commit comments

Comments
 (0)