diff --git a/ingress/icmp_darwin.go b/ingress/icmp_darwin.go index 31972ac53ff..8374aa1240a 100644 --- a/ingress/icmp_darwin.go +++ b/ingress/icmp_darwin.go @@ -210,7 +210,7 @@ func (ip *icmpProxy) Serve(ctx context.Context) error { if err != nil { return err } - reply, err := parseReply(from, buf[:n]) + reply, err := parseReply(from, buf[:n], receivedTTL{}) if err != nil { ip.logger.Debug().Err(err).Str("dst", from.String()).Msg("Failed to parse ICMP reply, continue to parse as full packet") // In unit test, we found out when the listener listens on 0.0.0.0, the socket reads the full packet after @@ -231,24 +231,36 @@ func (ip *icmpProxy) Serve(ctx context.Context) error { } } +func enableReceiveTTL(conn *icmp.PacketConn, listenIP netip.Addr) error { + return nil +} + func (ip *icmpProxy) handleFullPacket(ctx context.Context, decoder *packet.ICMPDecoder, rawPacket []byte) error { - icmpPacket, err := decoder.Decode(packet.RawPacket{Data: rawPacket}) + reply, err := parseFullPacketReply(decoder, rawPacket) if err != nil { return err } + if err := ip.sendReply(ctx, reply); err != nil { + return err + } + return nil +} + +func parseFullPacketReply(decoder *packet.ICMPDecoder, rawPacket []byte) (*echoReply, error) { + icmpPacket, err := decoder.Decode(packet.RawPacket{Data: rawPacket}) + if err != nil { + return nil, err + } echo, err := getICMPEcho(icmpPacket.Message) if err != nil { - return err + return nil, err } - reply := echoReply{ + return &echoReply{ from: icmpPacket.Src, msg: icmpPacket.Message, echo: echo, - } - if ip.sendReply(ctx, &reply); err != nil { - return err - } - return nil + ttl: receivedTTLFromIPHeader(icmpPacket.TTL), + }, nil } func (ip *icmpProxy) sendReply(ctx context.Context, reply *echoReply) error { @@ -265,10 +277,16 @@ func (ip *icmpProxy) sendReply(ctx context.Context, reply *echoReply) error { _, span := icmpFlow.responder.ReplySpan(ctx, ip.logger) defer icmpFlow.responder.ExportSpan() - if err := icmpFlow.returnToSrc(reply); err != nil { + sent, err := icmpFlow.returnToSrc(reply) + if err != nil { tracing.EndWithErrorStatus(span, err) return err } + if !sent { + ip.logger.Debug().Str("dst", reply.from.String()).Msg("Drop ICMP echo reply because TTL expired") + tracing.End(span) + return nil + } observeICMPReply(ip.logger, span, reply.from.String(), reply.echo.ID, reply.echo.Seq) span.SetAttributes(attribute.Int("originalEchoID", icmpFlow.originalEchoID)) tracing.End(span) diff --git a/ingress/icmp_darwin_test.go b/ingress/icmp_darwin_test.go index 6cfacb9d0c2..b1547b2534b 100644 --- a/ingress/icmp_darwin_test.go +++ b/ingress/icmp_darwin_test.go @@ -7,11 +7,43 @@ import ( "net/netip" "testing" + "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "github.com/cloudflare/cloudflared/packet" ) +func TestParseFullPacketReplyUsesIPTTL(t *testing.T) { + t.Parallel() + + pk := &packet.ICMP{ + IP: &packet.IP{ + Src: localhostIP, + Dst: localhostIP, + Protocol: layers.IPProtocolICMPv4, + TTL: 37, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 12345, + Seq: 6789, + Data: []byte(t.Name()), + }, + }, + } + rawPacket, err := packet.NewEncoder().Encode(pk) + require.NoError(t, err) + + reply, err := parseFullPacketReply(packet.NewICMPDecoder(), rawPacket.Data) + require.NoError(t, err) + require.True(t, reply.ttl.ok) + require.Equal(t, uint8(37), reply.ttl.value) +} + func TestSingleEchoIDTracker(t *testing.T) { tracker := newEchoIDTracker() key := flow3Tuple{ diff --git a/ingress/icmp_linux.go b/ingress/icmp_linux.go index 0b263a8f59c..8287bd4708a 100644 --- a/ingress/icmp_linux.go +++ b/ingress/icmp_linux.go @@ -20,6 +20,9 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" "go.opentelemetry.io/otel/attribute" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/tracing" @@ -177,6 +180,55 @@ func (ip *icmpProxy) listenResponse(ctx context.Context, flow *icmpEchoFlow) { } } +func enableReceiveTTL(conn *icmp.PacketConn, listenIP netip.Addr) error { + if listenIP.Is4() { + ipv4Conn := conn.IPv4PacketConn() + if ipv4Conn == nil { + return nil + } + if err := ipv4Conn.SetControlMessage(ipv4.FlagTTL, true); err != nil { + return fmt.Errorf("failed to enable IPv4 TTL control message: %w", err) + } + return nil + } + + ipv6Conn := conn.IPv6PacketConn() + if ipv6Conn == nil { + return nil + } + if err := ipv6Conn.SetControlMessage(ipv6.FlagHopLimit, true); err != nil { + return fmt.Errorf("failed to enable IPv6 hop limit control message: %w", err) + } + return nil +} + +func readICMPReply(conn *icmp.PacketConn, buf []byte) (int, net.Addr, receivedTTL, error) { + if ipv4Conn := conn.IPv4PacketConn(); ipv4Conn != nil { + n, cm, from, err := ipv4Conn.ReadFrom(buf) + if err != nil { + return 0, nil, receivedTTL{}, err + } + if cm == nil { + return n, from, receivedTTL{}, nil + } + return n, from, receivedTTLFromControlMessage(cm.TTL), nil + } + + if ipv6Conn := conn.IPv6PacketConn(); ipv6Conn != nil { + n, cm, from, err := ipv6Conn.ReadFrom(buf) + if err != nil { + return 0, nil, receivedTTL{}, err + } + if cm == nil { + return n, from, receivedTTL{}, nil + } + return n, from, receivedTTLFromControlMessage(cm.HopLimit), nil + } + + n, from, err := conn.ReadFrom(buf) + return n, from, receivedTTL{}, err +} + // Listens for ICMP response and handles error logging func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (done bool) { _, span := flow.responder.ReplySpan(ctx, ip.logger) @@ -186,7 +238,7 @@ func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf attribute.Int("originalEchoID", flow.originalEchoID), ) - n, from, err := flow.originConn.ReadFrom(buf) + n, from, ttl, err := readICMPReply(flow.originConn, buf) if err != nil { if flow.IsClosed() { tracing.EndWithErrorStatus(span, fmt.Errorf("flow was closed")) @@ -196,7 +248,7 @@ func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf tracing.EndWithErrorStatus(span, err) return true } - reply, err := parseReply(from, buf[:n]) + reply, err := parseReply(from, buf[:n], ttl) if err != nil { ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to parse ICMP reply") tracing.EndWithErrorStatus(span, err) @@ -209,11 +261,17 @@ func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf return false } - if err := flow.returnToSrc(reply); err != nil { + sent, err := flow.returnToSrc(reply) + if err != nil { ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply") tracing.EndWithErrorStatus(span, err) return false } + if !sent { + ip.logger.Debug().Str("dst", from.String()).Msg("Drop ICMP echo reply because TTL expired") + tracing.End(span) + return false + } observeICMPReply(ip.logger, span, from.String(), reply.echo.ID, reply.echo.Seq) tracing.End(span) diff --git a/ingress/icmp_posix.go b/ingress/icmp_posix.go index a5353917324..a79ac54c02e 100644 --- a/ingress/icmp_posix.go +++ b/ingress/icmp_posix.go @@ -19,10 +19,27 @@ import ( // Opens a non-privileged ICMP socket on Linux and Darwin func newICMPConn(listenIP netip.Addr) (*icmp.PacketConn, error) { + var ( + network string + err error + ) if listenIP.Is4() { - return icmp.ListenPacket("udp4", listenIP.String()) + network = "udp4" + } else { + network = "udp6" } - return icmp.ListenPacket("udp6", listenIP.String()) + + conn, err := icmp.ListenPacket(network, listenIP.String()) + if err != nil { + return nil, err + } + if err := enableReceiveTTL(conn, listenIP); err != nil { + if closeErr := conn.Close(); closeErr != nil { + return nil, fmt.Errorf("%w; failed to close ICMP socket after error: %v", err, closeErr) + } + return nil, err + } + return conn, nil } func netipAddr(addr net.Addr) (netip.Addr, bool) { @@ -120,8 +137,12 @@ func (ief *icmpEchoFlow) sendToDst(dst netip.Addr, msg *icmp.Message) error { } // returnToSrc rewrites the echo ID to the original echo ID from the eyeball -func (ief *icmpEchoFlow) returnToSrc(reply *echoReply) error { +func (ief *icmpEchoFlow) returnToSrc(reply *echoReply) (bool, error) { ief.UpdateLastActive() + ttl, shouldForward := reply.ttl.forwardedTTL() + if !shouldForward { + return false, nil + } reply.echo.ID = ief.originalEchoID reply.msg.Body = reply.echo pk := packet.ICMP{ @@ -129,20 +150,21 @@ func (ief *icmpEchoFlow) returnToSrc(reply *echoReply) error { Src: reply.from, Dst: ief.src, Protocol: layers.IPProtocol(reply.msg.Type.Protocol()), - TTL: packet.DefaultTTL, + TTL: ttl, }, Message: reply.msg, } - return ief.responder.ReturnPacket(&pk) + return true, ief.responder.ReturnPacket(&pk) } type echoReply struct { from netip.Addr msg *icmp.Message echo *icmp.Echo + ttl receivedTTL } -func parseReply(from net.Addr, rawMsg []byte) (*echoReply, error) { +func parseReply(from net.Addr, rawMsg []byte, ttl receivedTTL) (*echoReply, error) { fromAddr, ok := netipAddr(from) if !ok { return nil, fmt.Errorf("cannot convert %s to netip.Addr", from) @@ -163,6 +185,7 @@ func parseReply(from net.Addr, rawMsg []byte) (*echoReply, error) { from: fromAddr, msg: msg, echo: echo, + ttl: ttl, }, nil } diff --git a/ingress/icmp_posix_test.go b/ingress/icmp_posix_test.go index 6231e1b9e64..71d671b84b0 100644 --- a/ingress/icmp_posix_test.go +++ b/ingress/icmp_posix_test.go @@ -18,6 +18,75 @@ import ( "github.com/cloudflare/cloudflared/packet" ) +func TestReturnToSrcUsesReplyTTL(t *testing.T) { + t.Parallel() + + const originalEchoID = 42573 + muxer := newMockMuxer(1) + responder := newPacketResponder(muxer, 0, packet.NewEncoder()) + flow := newICMPEchoFlow(localhostIP, func() error { return nil }, nil, responder, 0, originalEchoID) + + sent, err := flow.returnToSrc(&echoReply{ + from: localhostIP, + msg: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + }, + echo: &icmp.Echo{ + ID: 12345, + Seq: 6789, + Data: []byte(t.Name()), + }, + ttl: receivedTTLFromIPHeader(42), + }) + require.NoError(t, err) + require.True(t, sent) + + resp := <-muxer.cfdToEdge + decoder := packet.NewICMPDecoder() + decoded, err := decoder.Decode(packet.RawPacket{Data: resp.Payload()}) + require.NoError(t, err) + require.Equal(t, uint8(41), decoded.TTL) + require.Equal(t, localhostIP, decoded.Src) + require.Equal(t, localhostIP, decoded.Dst) + require.Equal(t, ipv4.ICMPTypeEchoReply, decoded.Type) + require.Equal(t, &icmp.Echo{ + ID: originalEchoID, + Seq: 6789, + Data: []byte(t.Name()), + }, decoded.Body) +} + +func TestReturnToSrcDropsExpiredReplyTTL(t *testing.T) { + t.Parallel() + + muxer := newMockMuxer(1) + responder := newPacketResponder(muxer, 0, packet.NewEncoder()) + flow := newICMPEchoFlow(localhostIP, func() error { return nil }, nil, responder, 0, 42573) + + sent, err := flow.returnToSrc(&echoReply{ + from: localhostIP, + msg: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + }, + echo: &icmp.Echo{ + ID: 12345, + Seq: 6789, + Data: []byte(t.Name()), + }, + ttl: receivedTTLFromIPHeader(1), + }) + require.NoError(t, err) + require.False(t, sent) + + select { + case pk := <-muxer.cfdToEdge: + t.Fatalf("received unexpected ICMP reply: %+v", pk) + default: + } +} + func TestFunnelIdleTimeout(t *testing.T) { defer leaktest.Check(t)() diff --git a/ingress/icmp_windows.go b/ingress/icmp_windows.go index 23c3eb50e8a..c49af6feb43 100644 --- a/ingress/icmp_windows.go +++ b/ingress/icmp_windows.go @@ -285,12 +285,17 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM responder.ExportSpan() _, replySpan := responder.ReplySpan(ctx, ip.logger) - err = ip.handleEchoReply(pk, echo, resp, responder) + sent, err := ip.handleEchoReply(pk, echo, resp, responder) if err != nil { ip.logger.Err(err).Msg("Failed to send ICMP reply") tracing.EndWithErrorStatus(replySpan, err) return errors.Wrap(err, "failed to handle ICMP echo reply") } + if !sent { + ip.logger.Debug().Str("dst", pk.Dst.String()).Msg("Drop ICMP echo reply because TTL expired") + tracing.End(replySpan) + return nil + } observeICMPReply(ip.logger, replySpan, pk.Dst.String(), echo.ID, echo.Seq) replySpan.SetAttributes( attribute.Int64("rtt", int64(resp.rtt())), @@ -300,7 +305,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM return nil } -func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, resp echoResp, responder ICMPResponder) error { +func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, resp echoResp, responder ICMPResponder) (bool, error) { var replyType icmp.Type if request.Dst.Is4() { replyType = ipv4.ICMPTypeEchoReply @@ -308,12 +313,21 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, r replyType = ipv6.ICMPTypeEchoReply } + ttl := packet.DefaultTTL + if received, ok := resp.ttl(); ok { + forwarded, shouldForward := receivedTTLFromIPHeader(received).forwardedTTL() + if !shouldForward { + return false, nil + } + ttl = forwarded + } + pk := packet.ICMP{ IP: &packet.IP{ Src: request.Dst, Dst: request.Src, Protocol: layers.IPProtocol(request.Type.Protocol()), - TTL: packet.DefaultTTL, + TTL: ttl, }, Message: &icmp.Message{ Type: replyType, @@ -325,7 +339,7 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, r }, }, } - return responder.ReturnPacket(&pk) + return true, responder.ReturnPacket(&pk) } func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) (echoResp, error) { @@ -410,6 +424,7 @@ type echoResp interface { status() ipStatus rtt() uint32 payload() []byte + ttl() (uint8, bool) } type echoV4Resp struct { @@ -429,6 +444,10 @@ func (r *echoV4Resp) payload() []byte { return r.data } +func (r *echoV4Resp) ttl() (uint8, bool) { + return r.reply.Options.TTL, true +} + func newEchoV4Resp(replyBuf []byte) (*echoV4Resp, error) { if len(replyBuf) == 0 { return nil, fmt.Errorf("reply buffer is empty") @@ -527,6 +546,10 @@ func (r *echoV6Resp) payload() []byte { return r.data } +func (r *echoV6Resp) ttl() (uint8, bool) { + return 0, false +} + func newEchoV6Resp(replyBuf []byte, dataSize int) (*echoV6Resp, error) { if len(replyBuf) == 0 { return nil, fmt.Errorf("reply buffer is empty") diff --git a/ingress/icmp_windows_test.go b/ingress/icmp_windows_test.go index 5cd53a3074c..9842f1883fc 100644 --- a/ingress/icmp_windows_test.go +++ b/ingress/icmp_windows_test.go @@ -12,8 +12,11 @@ import ( "time" "unsafe" + "github.com/google/gopacket/layers" "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "github.com/cloudflare/cloudflared/packet" "github.com/stretchr/testify/require" ) @@ -125,6 +128,51 @@ func TestParseEchoV6Reply(t *testing.T) { } } +func TestHandleEchoReplyUsesIPv4TTL(t *testing.T) { + t.Parallel() + + echo := &icmp.Echo{ + ID: 6193, + Seq: 25712, + Data: []byte(t.Name()), + } + request := &packet.ICMP{ + IP: &packet.IP{ + Src: localhostIP, + Dst: netip.MustParseAddr("192.0.2.200"), + Protocol: layers.IPProtocolICMPv4, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: echo, + }, + } + resp := &echoV4Resp{ + reply: &echoReply{ + Status: success, + Options: ipOption{ + TTL: 59, + }, + }, + data: []byte(t.Name()), + } + muxer := newMockMuxer(1) + responder := newPacketResponder(muxer, 0, packet.NewEncoder()) + + sent, err := (&icmpProxy{}).handleEchoReply(request, echo, resp, responder) + require.NoError(t, err) + require.True(t, sent) + + pk := <-muxer.cfdToEdge + decoded, err := packet.NewICMPDecoder().Decode(packet.RawPacket{Data: pk.Payload()}) + require.NoError(t, err) + require.Equal(t, uint8(58), decoded.TTL) + require.Equal(t, request.Dst, decoded.Src) + require.Equal(t, request.Src, decoded.Dst) + require.Equal(t, ipv4.ICMPTypeEchoReply, decoded.Type) +} + // TestSendEchoErrors makes sure icmpSendEcho handles error cases func TestSendEchoErrors(t *testing.T) { testSendEchoErrors(t, netip.IPv4Unspecified()) diff --git a/ingress/origin_icmp_proxy.go b/ingress/origin_icmp_proxy.go index 981b86671bc..677933ad5e1 100644 --- a/ingress/origin_icmp_proxy.go +++ b/ingress/origin_icmp_proxy.go @@ -3,6 +3,7 @@ package ingress import ( "context" "fmt" + "math" "net/netip" "time" @@ -27,6 +28,38 @@ var ( errPacketNil = fmt.Errorf("packet is nil") ) +type receivedTTL struct { + value uint8 + ok bool +} + +func receivedTTLFromControlMessage(value int) receivedTTL { + if value <= 0 || value > math.MaxUint8 { + return receivedTTL{} + } + return receivedTTL{ + value: uint8(value), + ok: true, + } +} + +func receivedTTLFromIPHeader(value uint8) receivedTTL { + return receivedTTL{ + value: value, + ok: true, + } +} + +func (ttl receivedTTL) forwardedTTL() (uint8, bool) { + if !ttl.ok { + return packet.DefaultTTL, true + } + if ttl.value <= 1 { + return 0, false + } + return ttl.value - 1, true +} + // ICMPRouterServer is a parent interface over-top of ICMPRouter that allows for the operation of the proxy origin listeners. type ICMPRouterServer interface { ICMPRouter