Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 10 additions & 22 deletions acme/acme.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,24 @@ func IssueCertificates(cacheDir, email, challengeType string, domains []string,
cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(_ certmagic.Certificate) (*certmagic.Config, error) {
return &certmagic.Config{
RenewalWindowRatio: 0,
MustStaple: false,
OCSP: certmagic.OCSPConfig{},
Storage: &certmagic.FileStorage{Path: cacheDir},
Storage: &certmagic.FileStorage{Path: cacheDir},
}, nil
},
OCSPCheckInterval: 0,
RenewCheckInterval: 0,
Capacity: 0,
})

cfg := certmagic.New(cache, certmagic.Config{
RenewalWindowRatio: 0,
MustStaple: false,
OCSP: certmagic.OCSPConfig{},
Storage: &certmagic.FileStorage{Path: cacheDir},
Storage: &certmagic.FileStorage{Path: cacheDir},
})

myAcme := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
CA: certmagic.LetsEncryptProductionCA,
TestCA: certmagic.LetsEncryptStagingCA,
Email: email,
Agreed: true,
DisableHTTPChallenge: false,
DisableTLSALPNChallenge: false,
ListenHost: "0.0.0.0",
AltHTTPPort: altHTTPPort,
AltTLSALPNPort: altTLSAlpnPort,
CertObtainTimeout: time.Second * 240,
PreferredChains: certmagic.ChainPreference{},
CA: certmagic.LetsEncryptProductionCA,
TestCA: certmagic.LetsEncryptStagingCA,
Email: email,
Agreed: true,
ListenHost: "0.0.0.0",
AltHTTPPort: altHTTPPort,
AltTLSALPNPort: altTLSAlpnPort,
CertObtainTimeout: time.Second * 240,
})

if !useProduction {
Expand Down
12 changes: 10 additions & 2 deletions attributes/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ func Get(r *http.Request, key string) any {
return nil
}

return v.(attrs).get(key)
a, ok := v.(attrs)
if !ok {
return nil
}
return a.get(key)
}

// Set sets the key to value. It replaces any existing
Expand All @@ -83,7 +87,11 @@ func Set(r *http.Request, key string, value string) error {
return errors.New("unable to find `psr:attributes` context key")
}

v.(attrs).set(key, value)
a, ok := v.(attrs)
if !ok {
return errors.New("unexpected type stored under `psr:attributes` context key")
}
Comment thread
rustatian marked this conversation as resolved.
a.set(key, value)
return nil
}

Expand Down
12 changes: 2 additions & 10 deletions handler/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ const (
// FetchIP extracts the client IP from net/http's RemoteAddr ("host:port"
// or a bare IP). Returns the empty string for unparseable input.
func FetchIP(pair string, log *slog.Logger) string {
if !strings.ContainsRune(pair, ':') {
return pair
}

addr, _, err := net.SplitHostPort(pair)
if err == nil {
return addr
Expand All @@ -44,9 +40,7 @@ func FetchIP(pair string, log *slog.Logger) string {
// URI returns the fully-qualified request URI, stripping CR/LF to prevent
// header smuggling via the URL.
func URI(r *http.Request) string {
uri := r.URL.String()
uri = strings.ReplaceAll(uri, "\n", "")
uri = strings.ReplaceAll(uri, "\r", "")
uri := strings.ReplaceAll(strings.ReplaceAll(r.URL.String(), "\n", ""), "\r", "")

if r.URL.Host != "" {
return uri
Expand Down Expand Up @@ -88,9 +82,7 @@ func extractCookies(r *http.Request) map[string]string {

// cleanRawQuery strips CR/LF from the URL raw query before exposing it to PHP.
func cleanRawQuery(q string) string {
q = strings.ReplaceAll(q, "\n", "")
q = strings.ReplaceAll(q, "\r", "")
return q
return strings.ReplaceAll(strings.ReplaceAll(q, "\n", ""), "\r", "")
}

// populateBody fills req.Body / req.Parsed based on the request content-type.
Expand Down
8 changes: 4 additions & 4 deletions handler/uploads.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package handler

import (
"encoding/json"
"errors"
"io"
"io/fs"
"log/slog"
"mime/multipart"
"os"
Expand Down Expand Up @@ -166,8 +168,6 @@ func (f *FileUpload) Open(dir string, forbid, allow map[string]struct{}) error {
// exists if file exists.
func exists(path string) bool {
// path is RR-generated TempFilename, not user-controlled.
if _, err := os.Stat(path); os.IsNotExist(err) { //nolint:gosec // G703
return false
}
return true
_, err := os.Stat(path) //nolint:gosec // G703
return !errors.Is(err, fs.ErrNotExist)
}
65 changes: 30 additions & 35 deletions middleware/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import (
var _ io.ReadCloser = (*wrapper)(nil)
var _ http.ResponseWriter = (*wrapper)(nil)

// stripCRLF removes CR/LF to prevent log injection (CWE-117).
func stripCRLF(s string) string {
return strings.ReplaceAll(strings.ReplaceAll(s, "\n", ""), "\r", "")
}

type wrapper struct {
io.ReadCloser
read int
Expand All @@ -27,7 +32,6 @@ type wrapper struct {

w http.ResponseWriter
code int
data []byte
}

func (w *wrapper) Read(b []byte) (int, error) {
Expand Down Expand Up @@ -85,7 +89,6 @@ func (w *wrapper) reset() {
w.wc = false
w.write = 0
w.w = nil
w.data = nil
w.ReadCloser = nil
}

Expand Down Expand Up @@ -128,8 +131,7 @@ func (l *lm) Log(next http.Handler, accessLogs bool) http.Handler {
}

func (l *lm) writeLog(accessLog bool, r *http.Request, bw *wrapper, start time.Time) {
switch accessLog {
case false:
if !accessLog {
l.log.Info("http log",
"status", bw.code,
"method", r.Method,
Expand All @@ -140,38 +142,31 @@ func (l *lm) writeLog(accessLog bool, r *http.Request, bw *wrapper, start time.T
"write_bytes", bw.write,
"start", start,
"elapsed", time.Since(start).Milliseconds())
case true:
// external/cwe/cwe-117
usrA := r.UserAgent()
usrA = strings.ReplaceAll(usrA, "\n", "")
usrA = strings.ReplaceAll(usrA, "\r", "")

rfr := r.Referer()
rfr = strings.ReplaceAll(rfr, "\n", "")
rfr = strings.ReplaceAll(rfr, "\r", "")

rq := r.URL.RawQuery
rq = strings.ReplaceAll(rq, "\n", "")
rq = strings.ReplaceAll(rq, "\r", "")

l.log.Info("http access log",
"read_bytes", bw.read,
"write_bytes", bw.write,
"status", bw.code,
"method", r.Method,
"URI", r.RequestURI,
"URL", r.URL.String(),
"remote_address", r.RemoteAddr,
"query", rq,
"content_len", r.ContentLength,
"host", r.Host,
"user_agent", usrA,
"referer", rfr,
"time_local", time.Now().Format("02/Jan/06:15:04:05 -0700"),
"request_time", time.Now(),
"start", start,
"elapsed", time.Since(start).Milliseconds())
return
}

// external/cwe/cwe-117
usrA := stripCRLF(r.UserAgent())
rfr := stripCRLF(r.Referer())
rq := stripCRLF(r.URL.RawQuery)

l.log.Info("http access log",
"read_bytes", bw.read,
"write_bytes", bw.write,
"status", bw.code,
"method", r.Method,
"URI", r.RequestURI,
"URL", r.URL.String(),
"remote_address", r.RemoteAddr,
"query", rq,
"content_len", r.ContentLength,
"host", r.Host,
"user_agent", usrA,
"referer", rfr,
"time_local", start.Format("02/Jan/06:15:04:05 -0700"),
"request_time", start,
"start", start,
"elapsed", time.Since(start).Milliseconds())
}

func (l *lm) getW(w http.ResponseWriter) *wrapper {
Expand Down
8 changes: 2 additions & 6 deletions middleware/maxRequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@ import (

func MaxRequestSize(next http.Handler, maxReqSize uint64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// validating request size

r2 := r.Clone(r.Context())
r2.Body = http.MaxBytesReader(w, r2.Body, int64(maxReqSize)) //nolint:gosec

// use max_request_size limit in megabytes
next.ServeHTTP(w, r2)
r.Body = http.MaxBytesReader(w, r.Body, int64(maxReqSize)) //nolint:gosec
next.ServeHTTP(w, r)
})
}
32 changes: 13 additions & 19 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/roadrunner-server/http/v6/handler"
"github.com/roadrunner-server/http/v6/proxy"
"github.com/roadrunner-server/http/v6/servers"
"github.com/roadrunner-server/pool/v2/pool/static_pool"
"github.com/roadrunner-server/pool/v2/state/process"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
jprop "go.opentelemetry.io/contrib/propagators/jaeger"
Expand Down Expand Up @@ -128,12 +127,18 @@ func (p *Plugin) Serve() chan error {
p.mu.Lock()
defer p.mu.Unlock()

var err error
p.pool, err = p.server.NewPool(context.Background(), p.cfg.Pool, map[string]string{RrMode: RrModeHTTP}, p.log)
// NewPool returns a concrete *static_pool.Pool; assign it to the api.Pool
// interface field only when non-nil so p.pool never holds a typed nil
// (which would make the p.pool != nil guard in Stop spuriously true and
// panic inside Destroy on a nil receiver).
np, err := p.server.NewPool(context.Background(), p.cfg.Pool, map[string]string{RrMode: RrModeHTTP}, p.log)
if err != nil {
errCh <- err
return errCh
}
if np != nil {
p.pool = np
}

// request queue + worker-facing ConnectRPC server
p.queue = proxy.NewQueue(p.cfg.Proxy.InboxSize)
Expand Down Expand Up @@ -161,14 +166,12 @@ func (p *Plugin) Serve() chan error {
p.applyBundledMiddleware()

// start all servers
for i := range p.servers {
go func(idx int) {
errSt := p.servers[idx].Serve(p.mdwr, p.cfg.Middleware)
if errSt != nil {
for _, srv := range p.servers {
go func(s servers.InternalServer[any]) {
if errSt := s.Serve(p.mdwr, p.cfg.Middleware); errSt != nil {
errCh <- errSt
return
}
}(i)
}(srv)
}

return errCh
Expand Down Expand Up @@ -201,14 +204,7 @@ func (p *Plugin) Stop(ctx context.Context) error {
}

if p.pool != nil {
switch pp := p.pool.(type) {
case *static_pool.Pool:
if pp != nil {
pp.Destroy(ctx)
}
default:
// pool is nil, nothing to do
}
p.pool.Destroy(ctx)
}

doneCh <- struct{}{}
Expand Down Expand Up @@ -242,8 +238,6 @@ func (p *Plugin) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.handler.ServeHTTP(w, r)
p.mu.RUnlock()

_ = r.Body.Close()

if span != nil {
span.End()
}
Expand Down
22 changes: 12 additions & 10 deletions servers/https/config.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package https

import (
"errors"
"io/fs"
"net"
"os"
"strconv"

"github.com/roadrunner-server/errors"
rrerrors "github.com/roadrunner-server/errors"
"github.com/roadrunner-server/http/v6/acme"
)

Expand Down Expand Up @@ -95,11 +97,11 @@ func (s *SSL) EnableACME() bool {
}

func (s *SSL) Valid() error {
const op = errors.Op("ssl_valid")
const op = rrerrors.Op("ssl_valid")

host, portStr, err := net.SplitHostPort(s.Address)
if err != nil {
return errors.E(op, err)
return rrerrors.E(op, err)
}
if host == "" {
s.host = "127.0.0.1"
Expand All @@ -108,23 +110,23 @@ func (s *SSL) Valid() error {
}
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return errors.E(op, err)
return rrerrors.E(op, err)
}
s.Port = int(port)

// the user use they own certificates
if s.Acme == nil {
if _, err := os.Stat(s.Key); err != nil {
if os.IsNotExist(err) {
return errors.E(op, errors.Errorf("key file '%s' does not exists", s.Key))
if errors.Is(err, fs.ErrNotExist) {
return rrerrors.E(op, rrerrors.Errorf("key file '%s' does not exists", s.Key))
}

return err
}

if _, err := os.Stat(s.Cert); err != nil {
if os.IsNotExist(err) {
return errors.E(op, errors.Errorf("cert file '%s' does not exists", s.Cert))
if errors.Is(err, fs.ErrNotExist) {
return rrerrors.E(op, rrerrors.Errorf("cert file '%s' does not exists", s.Cert))
}

return err
Expand All @@ -134,8 +136,8 @@ func (s *SSL) Valid() error {
// RootCA is optional, but if provided - check it
if s.RootCA != "" {
if _, err := os.Stat(s.RootCA); err != nil {
if os.IsNotExist(err) {
return errors.E(op, errors.Errorf("root ca path provided, but path '%s' does not exists", s.RootCA))
if errors.Is(err, fs.ErrNotExist) {
return rrerrors.E(op, rrerrors.Errorf("root ca path provided, but path '%s' does not exists", s.RootCA))
}
return err
}
Expand Down
Loading
Loading