Chore: Add support for multi-origin CORS settings and apply to events websocket (#630)

This commit is contained in:
Ralph Slooten
2026-01-31 22:54:32 +13:00
parent ff47ba96b8
commit a63bcd9bd3
3 changed files with 148 additions and 6 deletions

127
server/cors.go Normal file
View File

@@ -0,0 +1,127 @@
package server
import (
"net/http"
"net/url"
"sort"
"strings"
"github.com/axllent/mailpit/internal/logger"
)
var (
// AccessControlAllowOrigin CORS policy - set with flags/env
AccessControlAllowOrigin string
// CorsAllowOrigins are optional allowed origins by hostname, set via setCORSOrigins().
corsAllowOrigins = make(map[string]bool)
)
// equalASCIIFold reports whether s and t, interpreted as UTF-8 strings, are equal
// under Unicode case folding, ignoring any difference in length.
func asciiFoldString(s string) string {
b := make([]byte, len(s))
for i := range s {
b[i] = toLowerASCIIFold(s[i])
}
return string(b)
}
// toLowerASCIIFold returns the Unicode case-folded equivalent of the ASCII character c.
// It is equivalent to the Unicode 13.0.0 function foldCase(c, CaseFoldingMapping).
func toLowerASCIIFold(c byte) byte {
if 'A' <= c && c <= 'Z' {
return c + 'a' - 'A'
}
return c
}
// CorsOriginAccessControl checks if the request origin is allowed based on the configured CORS origins.
func corsOriginAccessControl(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) != 0 {
u, err := url.Parse(origin[0])
if err != nil {
logger.Log().Errorf("CORS origin parse error: %v", err)
return false
}
_, allAllowed := corsAllowOrigins["*"]
// allow same origin || is "*" is defined as an origin
if asciiFoldString(u.Host) == asciiFoldString(r.Host) || allAllowed {
return true
}
originHostFold := asciiFoldString(u.Hostname())
if corsAllowOrigins[originHostFold] {
return true
}
return false
}
return true
}
// SetCORSOrigins sets the allowed CORS origins from a comma-separated string.
// It does not consider port or protocol, only the hostname.
func setCORSOrigins() {
corsAllowOrigins = make(map[string]bool)
hosts := extractOrigins(AccessControlAllowOrigin)
for _, host := range hosts {
corsAllowOrigins[asciiFoldString(host)] = true
}
if _, wildCard := corsAllowOrigins["*"]; wildCard {
// reset to just wildcard
corsAllowOrigins = make(map[string]bool)
corsAllowOrigins["*"] = true
logger.Log().Info("[cors] all origins are allowed due to wildcard \"*\"")
} else {
keys := make([]string, 0)
for k := range corsAllowOrigins {
keys = append(keys, k)
}
sort.Strings(keys)
logger.Log().Infof("[cors] allowed API origins: %v", strings.Join(keys, ", "))
}
}
// extractOrigins extracts and returns a sorted list of origins from a comma-separated string.
func extractOrigins(str string) []string {
origins := make([]string, 0)
s := strings.TrimSpace(str)
if s == "" {
return origins
}
hosts := strings.FieldsFunc(s, func(r rune) bool {
return r == ',' || r == ' '
})
for _, host := range hosts {
h := strings.TrimSpace(host)
if h != "" {
if h == "*" {
return []string{"*"}
}
if !strings.HasPrefix(h, "http://") && !strings.HasPrefix(h, "https://") {
h = "http://" + h
}
u, err := url.Parse(h)
if err != nil || u.Hostname() == "" || strings.Contains(h, "*") {
logger.Log().Warnf("[cors] invalid CORS origin \"%s\", ignoring", h)
continue
}
origins = append(origins, u.Hostname())
}
}
sort.Strings(origins)
return origins
}

View File

@@ -32,21 +32,23 @@ import (
)
var (
// AccessControlAllowOrigin CORS policy
AccessControlAllowOrigin string
// htmlPreviewRouteRe is a regexp to match the HTML preview route
htmlPreviewRouteRe *regexp.Regexp
)
// Listen will start the httpd
func Listen() {
setCORSOrigins()
isReady := &atomic.Value{}
isReady.Store(false)
stats.Track()
websockets.MessageHub = websockets.NewHub()
// set allowed websocket origins from configuration
// websockets.SetAllowedOrigins(AccessControlAllowWSOrigins)
go websockets.MessageHub.Run()
go pop3.Run()
@@ -287,9 +289,12 @@ func middleWareFunc(fn http.HandlerFunc) http.HandlerFunc {
htmlPreviewRouteRe = regexp.MustCompile(`^` + regexp.QuoteMeta(config.Webroot) + `view/[a-zA-Z0-9]+\.html$`)
}
if AccessControlAllowOrigin != "" &&
(strings.HasPrefix(r.RequestURI, config.Webroot+"api/") || htmlPreviewRouteRe.MatchString(r.RequestURI)) {
w.Header().Set("Access-Control-Allow-Origin", AccessControlAllowOrigin)
if strings.HasPrefix(r.RequestURI, config.Webroot+"api/") || htmlPreviewRouteRe.MatchString(r.RequestURI) {
if allowed := corsOriginAccessControl(r); !allowed {
http.Error(w, "Unauthorised.", http.StatusForbidden)
return
}
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, PUT, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "*")
}
@@ -331,6 +336,12 @@ func addSlashToWebroot(w http.ResponseWriter, r *http.Request) {
// Websocket to broadcast changes
func apiWebsocket(w http.ResponseWriter, r *http.Request) {
if allowed := corsOriginAccessControl(r); !allowed {
http.Error(w, "Unauthorised.", http.StatusForbidden)
return
}
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
w.Header().Set("Access-Control-Allow-Headers", "*")
websockets.ServeWs(websockets.MessageHub, w, r)
storage.BroadcastMailboxStats()
}

View File

@@ -35,6 +35,10 @@ var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
EnableCompression: true,
CheckOrigin: func(_ *http.Request) bool {
// origin is checked via server.go's CORS settings
return true
},
}
// Client is a middleman between the websocket connection and the hub.