Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 5 additions & 25 deletions guest/sshd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,32 +274,12 @@ func (s *Server) handleConnection(netConn net.Conn) {
}

// handleGlobalRequests processes connection-level SSH requests.
// It handles agent forwarding requests when enabled and discards
// all other global requests.
func (s *Server) handleGlobalRequests(reqs <-chan *ssh.Request, conn *ssh.ServerConn) {
// It rejects all global requests; session-specific requests like
// agent forwarding are handled in handleSession.
func (s *Server) handleGlobalRequests(reqs <-chan *ssh.Request, _ *ssh.ServerConn) {
for req := range reqs {
switch req.Type {
case "auth-agent-req@openssh.com":
if s.cfg.AgentForwarding {
s.setAgentForwarding(conn, true)
s.logger.Info("agent forwarding enabled",
"remote", conn.RemoteAddr(),
)
if req.WantReply {
_ = req.Reply(true, nil)
}
} else {
s.logger.Debug("agent forwarding rejected (disabled)",
"remote", conn.RemoteAddr(),
)
if req.WantReply {
_ = req.Reply(false, nil)
}
}
default:
if req.WantReply {
_ = req.Reply(false, nil)
}
if req.WantReply {
_ = req.Reply(false, nil)
}
}
}
Expand Down
40 changes: 28 additions & 12 deletions guest/sshd/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

// generateTestKeyPair creates an ECDSA P-256 key pair for testing and
Expand Down Expand Up @@ -327,10 +328,13 @@ func TestAgentForwardingDisabled(t *testing.T) {

client := dialSSH(t, addr, signer)

// Request agent forwarding — should be rejected.
ok, _, err := client.SendRequest("auth-agent-req@openssh.com", true, nil)
session, err := client.NewSession()
require.NoError(t, err)
assert.False(t, ok, "agent forwarding should be rejected when disabled")
defer func() { _ = session.Close() }()

// Request agent forwarding via the real API — should be rejected.
err = agent.RequestAgentForwarding(session)
assert.Error(t, err, "agent forwarding should be rejected when disabled")
}

func TestAgentForwardingEnabled(t *testing.T) {
Expand All @@ -352,10 +356,23 @@ func TestAgentForwardingEnabled(t *testing.T) {

client := dialSSH(t, addr, signer)

// Request agent forwarding — should be accepted.
ok, _, err := client.SendRequest("auth-agent-req@openssh.com", true, nil)
session, err := client.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()

// Request agent forwarding via the real API — should be accepted.
err = agent.RequestAgentForwarding(session)
require.NoError(t, err, "agent forwarding should be accepted when enabled")

// Verify the flag was set by running a command on a second session.
session2, err := client.NewSession()
require.NoError(t, err)
assert.True(t, ok, "agent forwarding should be accepted when enabled")
defer func() { _ = session2.Close() }()

output, err := session2.CombinedOutput("echo ${SSH_AUTH_SOCK:-unset}")
require.NoError(t, err)
result := strings.TrimSpace(string(output))
assert.Contains(t, result, "/tmp/ssh-", "agent socket should be set on connection after forwarding request")
}

func TestAgentSocketCreated(t *testing.T) {
Expand All @@ -377,16 +394,15 @@ func TestAgentSocketCreated(t *testing.T) {

client := dialSSH(t, addr, signer)

// Request agent forwarding.
ok, _, err := client.SendRequest("auth-agent-req@openssh.com", true, nil)
require.NoError(t, err)
require.True(t, ok)

// Run a command that checks if SSH_AUTH_SOCK is set.
// Request agent forwarding and run a command on the same session,
// which is the real client flow: auth-agent-req arrives before exec.
session, err := client.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()

err = agent.RequestAgentForwarding(session)
require.NoError(t, err)

output, err := session.CombinedOutput("echo $SSH_AUTH_SOCK")
require.NoError(t, err)

Expand Down
14 changes: 14 additions & 0 deletions guest/sshd/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ func (s *Server) handleSession(ch ssh.Channel, requests <-chan *ssh.Request, con
replyRequest(req, false)
}

case "auth-agent-req@openssh.com":
if s.cfg.AgentForwarding {
s.setAgentForwarding(conn, true)
s.logger.Info("agent forwarding enabled",
"remote", conn.RemoteAddr(),
)
replyRequest(req, true)
} else {
s.logger.Debug("agent forwarding rejected (disabled)",
"remote", conn.RemoteAddr(),
)
replyRequest(req, false)
}

case "exec":
var payload struct {
Command string
Expand Down
Loading