diff --git a/cmd/root.go b/cmd/root.go index a2a145e..a232016 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -105,6 +105,7 @@ func init() { rootCmd.Flags().StringVar(&config.UITLSKey, "ui-tls-key", config.UITLSKey, "TLS key for web UI (HTTPS) - requires ui-tls-cert") rootCmd.Flags().StringVar(&server.AccessControlAllowOrigin, "api-cors", server.AccessControlAllowOrigin, "Set CORS origin(s) for the API, comma-separated (eg: example.com,foo.com)") rootCmd.Flags().BoolVar(&config.BlockRemoteCSSAndFonts, "block-remote-css-and-fonts", config.BlockRemoteCSSAndFonts, "Block access to remote CSS & fonts") + rootCmd.Flags().BoolVar(&config.AllowInternalHTTPRequests, "allow-internal-http-requests", config.AllowInternalHTTPRequests, "Allow link-checker & screenshots to access internal IP addresses") rootCmd.Flags().StringVar(&config.EnableSpamAssassin, "enable-spamassassin", config.EnableSpamAssassin, "Enable integration with SpamAssassin") rootCmd.Flags().BoolVar(&config.AllowUntrustedTLS, "allow-untrusted-tls", config.AllowUntrustedTLS, "Do not verify HTTPS certificates (link checker & screenshots)") rootCmd.Flags().BoolVar(&config.DisableHTTPCompression, "disable-http-compression", config.DisableHTTPCompression, "Disable HTTP compression support (web UI & API)") @@ -250,6 +251,9 @@ func initConfigFromEnv() { if getEnabledFromEnv("MP_BLOCK_REMOTE_CSS_AND_FONTS") { config.BlockRemoteCSSAndFonts = true } + if getEnabledFromEnv("MP_ALLOW_INTERNAL_HTTP_REQUESTS") { + config.AllowInternalHTTPRequests = true + } if len(os.Getenv("MP_ENABLE_SPAMASSASSIN")) > 0 { config.EnableSpamAssassin = os.Getenv("MP_ENABLE_SPAMASSASSIN") } diff --git a/config/config.go b/config/config.go index 5ad3dbf..ff2bfea 100644 --- a/config/config.go +++ b/config/config.go @@ -127,6 +127,10 @@ var ( // BlockRemoteCSSAndFonts used to disable remote CSS & fonts BlockRemoteCSSAndFonts = false + // AllowInternalHTTPRequests will allow HTTP requests to internal IP addresses (e.g., loopback, private, link-local, or multicast) when set to true. + // This policy applies to both link checking and screenshot generation (proxy) features and is disabled by default for security reasons. + AllowInternalHTTPRequests = false + // CLITagsArg is used to map the CLI args CLITagsArg string diff --git a/internal/linkcheck/status.go b/internal/linkcheck/status.go index e3b59c8..703d1b1 100644 --- a/internal/linkcheck/status.go +++ b/internal/linkcheck/status.go @@ -1,14 +1,20 @@ package linkcheck import ( + "context" "crypto/tls" + "errors" + "fmt" + "net" "net/http" "regexp" + "strings" "sync" "time" "github.com/axllent/mailpit/config" "github.com/axllent/mailpit/internal/logger" + "github.com/axllent/mailpit/internal/tools" ) func getHTTPStatuses(links []string, followRedirects bool) []Link { @@ -34,6 +40,10 @@ func getHTTPStatuses(links []string, followRedirects bool) []Link { if err != nil { l.StatusCode = 0 l.Status = httpErrorSummary(err) + if strings.Contains(l.Status, "private/reserved address") { + l.Status = "Blocked private/reserved address" + l.StatusCode = 451 + } } else { l.StatusCode = code l.Status = http.StatusText(code) @@ -57,23 +67,37 @@ func getHTTPStatuses(links []string, followRedirects bool) []Link { // Do a HEAD request to return HTTP status code func doHead(link string, followRedirects bool) (int, error) { + if !tools.IsValidLinkURL(link) { + return 0, fmt.Errorf("invalid URL: %s", link) + } - timeout := time.Duration(10 * time.Second) + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } - tr := &http.Transport{} + tr := &http.Transport{ + DialContext: safeDialContext(dialer), + } if config.AllowUntrustedTLS { tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec } client := http.Client{ - Timeout: timeout, + Timeout: 10 * time.Second, Transport: tr, CheckRedirect: func(req *http.Request, via []*http.Request) error { - if followRedirects { - return nil + if len(via) >= 3 { + return errors.New("too many redirects") } - return http.ErrUseLastResponse + if !followRedirects { + return http.ErrUseLastResponse + } + if !tools.IsValidLinkURL(req.URL.String()) { + return fmt.Errorf("blocked redirect to invalid URL: %s", req.URL) + } + return nil }, } @@ -92,7 +116,6 @@ func doHead(link string, followRedirects bool) (int, error) { } return 0, err - } return res.StatusCode, nil @@ -107,8 +130,33 @@ func httpErrorSummary(err error) string { if !re.MatchString(e) { return e } - parts := re.FindAllStringSubmatch(e, -1) return parts[0][len(parts[0])-1] } + +// SafeDialContext is a custom dialer that checks if the resolved IP addresses are internal before allowing the connection. +func safeDialContext(dialer *net.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + + if !config.AllowInternalHTTPRequests { + for _, ip := range ips { + if tools.IsInternalIP(ip.IP) { + logger.Log().Warnf("[link-check] Blocked HEAD request to private/reserved address: %s (%s)", host, ip) + return nil, fmt.Errorf("blocked request to %s (%s): private/reserved address", host, ip) + } + } + } + + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + } +} diff --git a/internal/tools/net.go b/internal/tools/net.go new file mode 100644 index 0000000..bc7f7f9 --- /dev/null +++ b/internal/tools/net.go @@ -0,0 +1,28 @@ +package tools + +import ( + "net" + "net/url" +) + +// IsInternalIP checks if the given IP address is an internal IP address (e.g., loopback, private, link-local, or multicast). +// IsLoopback — 127.0.0.0/8, ::1 +// IsPrivate — 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7 +// IsLinkLocalUnicast — 169.254.0.0/16, fe80::/10 (covers cloud metadata 169.254.169.254) +// IsLinkLocalMulticast — 224.0.0.0/24, ff02::/16 +// IsUnspecified — 0.0.0.0, :: +// IsMulticast — 224.0.0.0/4, ff00::/8 +func IsInternalIP(ip net.IP) bool { + return ip.IsLoopback() || + ip.IsPrivate() || + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || + ip.IsUnspecified() || + ip.IsMulticast() +} + +// IsValidLinkURL checks if the provided string is a valid URL with http or https scheme and a non-empty hostname. +func IsValidLinkURL(str string) bool { + u, err := url.Parse(str) + return err == nil && (u.Scheme == "http" || u.Scheme == "https") && u.Hostname() != "" +} diff --git a/server/handlers/proxy.go b/server/handlers/proxy.go index 3b95134..ab0fc51 100644 --- a/server/handlers/proxy.go +++ b/server/handlers/proxy.go @@ -2,10 +2,13 @@ package handlers import ( + "context" "crypto/tls" "encoding/base64" + "errors" "fmt" "io" + "net" "net/http" "net/url" "regexp" @@ -96,21 +99,37 @@ func ProxyHandler(w http.ResponseWriter, r *http.Request) { return } - if !linkRe.MatchString(uri) { - logger.Log().Warnf("[proxy] invalid request %s", uri) - httpError(w, "Error: invalid request") + if !linkRe.MatchString(uri) || !tools.IsValidLinkURL(uri) { + logger.Log().Warnf("[proxy] invalid URL %s", uri) + httpError(w, "Error: invalid URL") return } - tr := &http.Transport{} + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + } + + tr := &http.Transport{ + DialContext: safeDialContext(dialer), + } if config.AllowUntrustedTLS { tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec } client := &http.Client{ - Transport: tr, Timeout: 10 * time.Second, + Transport: tr, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 3 { + return errors.New("too many redirects") + } + if !tools.IsValidLinkURL(req.URL.String()) { + return fmt.Errorf("blocked redirect to invalid URL: %s", req.URL) + } + return nil + }, } req, err := http.NewRequest("GET", uri, nil) @@ -357,3 +376,28 @@ func supportedProxyContentType(ct string) bool { return false } + +// SafeDialContext is a custom dialer that checks if the resolved IP addresses are internal before allowing the connection. +func safeDialContext(dialer *net.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + + if !config.AllowInternalHTTPRequests { + for _, ip := range ips { + if tools.IsInternalIP(ip.IP) { + return nil, fmt.Errorf("blocked request to %s (%s): private/reserved address", host, ip) + } + } + } + + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + } +}