diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h index d7b9f0e5..e22e8b86 100644 --- a/include/mp/proxy-io.h +++ b/include/mp/proxy-io.h @@ -477,6 +477,11 @@ class Connection //! ThreadMap.makeThread) used to service requests to clients. ::capnp::CapabilityServerSet m_threads; + //! Thread pool populated by ThreadMap.makePool(). When a request arrives + //! with no context.thread set, PassField round-robins across these threads. + std::vector m_thread_pool; + size_t m_thread_pool_index{0}; + //! Canceler for canceling promises that we want to discard when the //! connection is destroyed. This is used to interrupt method calls that are //! still executing at time of disconnection. diff --git a/include/mp/proxy.capnp b/include/mp/proxy.capnp index 386f8f7a..8549d2c2 100644 --- a/include/mp/proxy.capnp +++ b/include/mp/proxy.capnp @@ -45,6 +45,10 @@ interface ThreadMap $count(0) { # execute on. Clients create and name threads and pass the thread handle as # a call parameter. makeThread @0 (name :Text) -> (result :Thread); + # Pre-allocate a pool of server threads for implicit dispatch. When a + # request arrives with no context.thread set, the server dispatches it + # through this pool via a shared work queue. + makePool @1 (name :Text, count :UInt32) -> (); } interface Thread { diff --git a/include/mp/type-context.h b/include/mp/type-context.h index 46952f49..219cafff 100644 --- a/include/mp/type-context.h +++ b/include/mp/type-context.h @@ -201,20 +201,38 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn& const auto& params = server_context.call_context.getParams(); Context::Reader context_arg = Accessor::get(params); auto thread_client = context_arg.getThread(); - auto result = server.m_context.connection->m_threads.getLocalServer(thread_client) - .then([&loop, invoke = kj::mv(invoke), req](const kj::Maybe& perhaps) mutable { - // Assuming the thread object is found, pass it a pointer to the - // `invoke` lambda above which will invoke the function on that - // thread. + auto* connection = server.m_context.connection; + auto result = connection->m_threads.getLocalServer(thread_client) + .then([&loop, invoke = kj::mv(invoke), req, connection](const kj::Maybe& perhaps) mutable { + // If the client specified a thread, dispatch to it directly. KJ_IF_MAYBE (thread_server, perhaps) { auto& thread = static_cast&>(*thread_server); MP_LOG(loop, Log::Debug) << "IPC server post request #" << req << " {" << thread.m_thread_context.thread_name << "}"; return thread.template post(std::move(invoke)); } else { - MP_LOG(loop, Log::Error) - << "IPC server error request #" << req << ", missing thread to execute request"; - throw std::runtime_error("invalid thread handle"); + // No thread specified — fall back to the connection's thread + // pool (populated by ThreadMap.makePool). Error if no pool. + auto& pool = connection->m_thread_pool; + if (pool.empty()) { + MP_LOG(loop, Log::Error) + << "IPC server error request #" << req << ", no thread specified and no pool configured"; + throw std::runtime_error("no thread specified and no pool configured"); + } + size_t idx = connection->m_thread_pool_index++ % pool.size(); + return connection->m_threads.getLocalServer(pool[idx]) + .then([&loop, invoke = kj::mv(invoke), req](const kj::Maybe& pool_perhaps) mutable { + KJ_IF_MAYBE (pt, pool_perhaps) { + auto& pool_thread = static_cast&>(*pt); + MP_LOG(loop, Log::Debug) + << "IPC server post request #" << req << " {" << pool_thread.m_thread_context.thread_name << "}"; + return pool_thread.template post(std::move(invoke)); + } else { + MP_LOG(loop, Log::Error) + << "IPC server error request #" << req << ", pool thread not found"; + throw std::runtime_error("pool thread not found"); + } + }); } }); // Use connection m_canceler object to cancel the result promise if the diff --git a/include/mp/type-threadmap.h b/include/mp/type-threadmap.h index 3005d9de..c38c2ac4 100644 --- a/include/mp/type-threadmap.h +++ b/include/mp/type-threadmap.h @@ -14,6 +14,7 @@ struct ProxyServer final : public virtual ThreadMap::Server public: ProxyServer(Connection& connection); kj::Promise makeThread(MakeThreadContext context) override; + kj::Promise makePool(MakePoolContext context) override; Connection& m_connection; }; diff --git a/src/mp/proxy.cpp b/src/mp/proxy.cpp index 963050c3..b941f958 100644 --- a/src/mp/proxy.cpp +++ b/src/mp/proxy.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include namespace mp { @@ -415,6 +417,31 @@ kj::Promise ProxyServer::getName(GetNameContext context) ProxyServer::ProxyServer(Connection& connection) : m_connection(connection) {} +kj::Promise ProxyServer::makePool(MakePoolContext context) +{ + if (!m_connection.m_thread_pool.empty()) { + throw std::runtime_error("makePool called on connection with existing pool"); + } + EventLoop& loop{*m_connection.m_loop}; + const auto& params = context.getParams(); + const std::string pool_name = params.getName(); + const uint32_t count = params.getCount(); + for (uint32_t i = 0; i < count; ++i) { + const std::string thread_name = pool_name + "/pool/" + std::to_string(i); + std::promise thread_context; + std::thread thread([&loop, &thread_context, thread_name]() { + g_thread_context.thread_name = ThreadName(loop.m_exe_name) + " (from " + thread_name + ")"; + g_thread_context.waiter = std::make_unique(); + Lock lock(g_thread_context.waiter->m_mutex); + thread_context.set_value(&g_thread_context); + g_thread_context.waiter->wait(lock, [] { return !g_thread_context.waiter; }); + }); + auto thread_server = kj::heap>(m_connection, *thread_context.get_future().get(), std::move(thread)); + m_connection.m_thread_pool.push_back(m_connection.m_threads.add(kj::mv(thread_server))); + } + return kj::READY_NOW; +} + kj::Promise ProxyServer::makeThread(MakeThreadContext context) { EventLoop& loop{*m_connection.m_loop}; diff --git a/test/mp/test/test.cpp b/test/mp/test/test.cpp index d91edb40..edf07446 100644 --- a/test/mp/test/test.cpp +++ b/test/mp/test/test.cpp @@ -18,8 +18,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -481,5 +483,76 @@ KJ_TEST("Make simultaneous IPC calls on single remote thread") KJ_EXPECT(expected == 400); } +KJ_TEST("Call async IPC method dispatched to pool thread") +{ + TestSetup setup; + ProxyClient* foo = setup.client.get(); + + // Set up the thread map exchange so the client has the server's ThreadMap, + // then call makePool to pre-allocate two server threads. + foo->initThreadMap(); + setup.server->m_impl->m_int_fn = [](int n) { return n * 2; }; + + ThreadContext& tc{g_thread_context}; + std::atomic running{3}; + std::promise pool_ready; + foo->m_context.loop->sync([&] { + auto pool_req = foo->m_context.connection->m_thread_map.makePoolRequest(); + pool_req.setName("test"); + pool_req.setCount(2); + foo->m_context.loop->m_task_set->add( + pool_req.send().then([&](auto&&) { pool_ready.set_value(); })); + }); + pool_ready.get_future().get(); + + // Send three callIntFnAsync requests with no context.thread set. + // The server should dispatch each to a pool thread. + auto client{foo->m_client}; + foo->m_context.loop->sync([&] { + for (size_t i = 0; i < running; ++i) { + auto request{client.callIntFnAsyncRequest()}; + request.initContext(); // context present but thread unset + request.setArg(static_cast(i + 1)); + foo->m_context.loop->m_task_set->add(request.send().then( + [&running, &tc, i](auto&& results) { + assert(results.getResult() == static_cast((i + 1) * 2)); + running -= 1; + tc.waiter->m_cv.notify_all(); + })); + } + }); + { + Lock lock(tc.waiter->m_mutex); + tc.waiter->wait(lock, [&running] { return running == 0; }); + } +} + +KJ_TEST("Call async IPC method without thread or pool errors correctly") +{ + TestSetup setup; + ProxyClient* foo = setup.client.get(); + setup.server->m_impl->m_fn = [] {}; + + // Send a callFnAsync request with no context.thread and no pool configured. + // The server should throw the "no thread specified and no pool configured" error. + std::promise done; + bool error_thrown{false}; + foo->m_context.loop->sync([&] { + auto request{foo->m_client.callFnAsyncRequest()}; + request.initContext(); + foo->m_context.loop->m_task_set->add( + request.send().then( + [&](auto&&) { done.set_value(); }, + [&](kj::Exception&& e) { + error_thrown = true; + KJ_EXPECT(std::string_view{e.getDescription().cStr()}.find( + "no thread specified and no pool configured") != std::string_view::npos); + done.set_value(); + })); + }); + done.get_future().get(); + KJ_EXPECT(error_thrown); +} + } // namespace test } // namespace mp