Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions tensorboard/plugins/core/core_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ def define_flags(self, parser):
help="""\
What host to listen to (default: localhost). To serve to the entire local
network on both IPv4 and IPv6, see `--bind_all`, with which this option is
mutually exclusive.
mutually exclusive. May also be set to `unix://<path>` (e.g.
`unix:///tmp/tb.sock`) to listen on a Unix domain socket.
""",
)

Expand Down Expand Up @@ -390,7 +391,8 @@ def define_flags(self, parser):
Enables the SO_REUSEPORT option on the socket opened by TensorBoard's HTTP
server, for platforms that support it. This is useful in cases when a parent
process has obtained the port already and wants to delegate access to the
port to TensorBoard as a subprocess.(default: %(default)s).\
port to TensorBoard as a subprocess. Ignored when `--host` is set to a
`unix://` socket. (default: %(default)s).\
""",
)

Expand Down Expand Up @@ -707,6 +709,14 @@ def fix_flags(self, flags):
)
elif flags.host is not None and flags.bind_all:
raise FlagsError("Must not specify both --host and --bind_all.")
elif (
flags.host is not None
and flags.host.startswith("unix://")
and flags.port is not None
):
raise FlagsError(
"--host=unix://... must not be combined with --port."
)
elif (
flags.load_fast == "true" and flags.detect_file_replacement is True
):
Expand Down
18 changes: 18 additions & 0 deletions tensorboard/plugins/core/core_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
logdir="",
logdir_spec="",
path_prefix="",
port=None,
reuse_port=False,
version_tb=False,
):
Expand All @@ -72,6 +73,7 @@ def __init__(
self.logdir = logdir
self.logdir_spec = logdir_spec
self.path_prefix = path_prefix
self.port = port
self.reuse_port = reuse_port
self.version_tb = version_tb

Expand Down Expand Up @@ -132,6 +134,22 @@ def testPathPrefix_mustStartWithSlash(self):
self.assertIn("must start with slash", msg)
self.assertIn(repr("noslash"), msg)

def testHostUnixSocket_alone_isAccepted(self):
loader = core_plugin.CorePluginLoader()
for value in ("unix:///tmp/tb.sock", "unix://tb.sock"):
loader.fix_flags(FakeFlags(logdir="/tmp", host=value))

def testHostUnixSocket_conflictsWithPort(self):
loader = core_plugin.CorePluginLoader()
flag = FakeFlags(
logdir="/tmp",
host="unix:///tmp/tb.sock",
port=6006,
)
with self.assertRaises(base_plugin.FlagsError) as cm:
loader.fix_flags(flag)
self.assertIn("unix://", str(cm.exception))


class CorePluginTest(tf.test.TestCase):
def setUp(self):
Expand Down
35 changes: 29 additions & 6 deletions tensorboard/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __init__(
) from e
assets_zip_provider = assets.get_default_assets_zip_provider()
if server_class is None:
server_class = create_port_scanning_werkzeug_server
server_class = _default_server_class
if subcommands is None:
subcommands = []
self.plugin_loaders = [
Expand Down Expand Up @@ -333,7 +333,8 @@ def _register_info(self, server):
info = manager.TensorBoardInfo(
version=version.VERSION,
start_time=int(time.time()),
port=server_url.port,
# For Unix sockets, server_url.port is None.
port=server_url.port or 0,
pid=os.getpid(),
path_prefix=self.flags.path_prefix,
logdir=self.flags.logdir or self.flags.logdir_spec,
Expand Down Expand Up @@ -476,6 +477,15 @@ def _make_server(self):
return self.server_class(app, self.flags)


def _default_server_class(wsgi_app, flags):
"""Default server factory."""

# Skip port scanning for Unix-socket servers.
if flags.host is not None and flags.host.startswith("unix://"):
return WerkzeugServer(wsgi_app, flags)
return create_port_scanning_werkzeug_server(wsgi_app, flags)


def _should_use_data_server(flags):
if flags.logdir_spec and not flags.logdir:
logger.info(
Expand Down Expand Up @@ -696,9 +706,13 @@ def __init__(self, wsgi_app, flags):
self._flags = flags
host = flags.host
port = flags.port
self._unix_socket = host is not None and host.startswith("unix://")

self._auto_wildcard = flags.bind_all
if self._auto_wildcard:
self._auto_wildcard = flags.bind_all and not self._unix_socket
if self._unix_socket:
# Werkzeug accepts host="unix://<path>" directly, port is ignored.
port = 0
elif self._auto_wildcard:
# Serve on all interfaces, and attempt to serve both IPv4 and IPv6
# traffic through one socket.
host = self._get_wildcard_address(port)
Expand All @@ -715,12 +729,14 @@ def is_port_in_use(port):
return s.connect_ex(("localhost", port)) == 0

try:
if is_port_in_use(port):
if not self._unix_socket and is_port_in_use(port):
raise TensorBoardPortInUseError(
"TensorBoard could not bind to port %d, it was already in use"
% port
)
super().__init__(host, port, wsgi_app, _WSGIRequestHandler)
if self._unix_socket:
os.chmod(self.server_address, 0o700)
except socket.error as e:
if hasattr(errno, "EACCES") and e.errno == errno.EACCES:
raise TensorBoardServerException(
Expand Down Expand Up @@ -801,7 +817,7 @@ def _get_wildcard_address(self, port):

def server_bind(self):
"""Override to set custom options on the socket."""
if self._flags.reuse_port:
if self._flags.reuse_port and not self._unix_socket:
try:
socket.SO_REUSEPORT
except AttributeError:
Expand Down Expand Up @@ -856,6 +872,13 @@ def handle_error(self, request, client_address):

def get_url(self):
if not self._url:
if self._unix_socket:
self._url = "%s%s/" % (
self._host,
self._flags.path_prefix.rstrip("/"),
)
return self._url

if self._auto_wildcard:
display_host = socket.getfqdn()
# Confirm that the connection is open, otherwise change to `localhost`
Expand Down
21 changes: 21 additions & 0 deletions tensorboard/program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

import argparse
import io
import os
import socket
import sys
import tempfile
from unittest import mock

from tensorboard import program
Expand Down Expand Up @@ -113,6 +116,7 @@ def make_flags(self, **kwargs):
flags = argparse.Namespace()
kwargs.setdefault("host", None)
kwargs.setdefault("bind_all", kwargs["host"] is None)
kwargs.setdefault("port", None)
kwargs.setdefault("reuse_port", False)
for k, v in kwargs.items():
setattr(flags, k, v)
Expand Down Expand Up @@ -168,6 +172,23 @@ def testSpecifiedHost(self):
"Neither IPv4 (127.0.0.1) nor IPv6 (::1) could be bound.",
)

def testUnixSocketHost(self):
if not hasattr(socket, "AF_UNIX"):
self.skipTest("AF_UNIX not supported on this platform")
with tempfile.TemporaryDirectory() as tmpdir:
sock_path = os.path.join(tmpdir, "tb.sock")
server = program.WerkzeugServer(
self._StubApplication(),
self.make_flags(host="unix://" + sock_path, path_prefix=""),
)
try:
self.assertTrue(os.path.exists(sock_path))
url = server.get_url()
self.assertStartsWith(url, "unix://")
self.assertIn("tb.sock", url)
finally:
server.server_close()


class SubcommandTest(tb_test.TestCase):
def setUp(self):
Expand Down