mirror of
https://github.com/axllent/mailpit.git
synced 2026-03-03 02:17:01 +00:00
Chore: Add support for multi-origin CORS settings and apply to events websocket (#630)
This commit is contained in:
127
server/cors.go
Normal file
127
server/cors.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/axllent/mailpit/internal/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// AccessControlAllowOrigin CORS policy - set with flags/env
|
||||
AccessControlAllowOrigin string
|
||||
|
||||
// CorsAllowOrigins are optional allowed origins by hostname, set via setCORSOrigins().
|
||||
corsAllowOrigins = make(map[string]bool)
|
||||
)
|
||||
|
||||
// equalASCIIFold reports whether s and t, interpreted as UTF-8 strings, are equal
|
||||
// under Unicode case folding, ignoring any difference in length.
|
||||
func asciiFoldString(s string) string {
|
||||
b := make([]byte, len(s))
|
||||
for i := range s {
|
||||
b[i] = toLowerASCIIFold(s[i])
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// toLowerASCIIFold returns the Unicode case-folded equivalent of the ASCII character c.
|
||||
// It is equivalent to the Unicode 13.0.0 function foldCase(c, CaseFoldingMapping).
|
||||
func toLowerASCIIFold(c byte) byte {
|
||||
if 'A' <= c && c <= 'Z' {
|
||||
return c + 'a' - 'A'
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// CorsOriginAccessControl checks if the request origin is allowed based on the configured CORS origins.
|
||||
func corsOriginAccessControl(r *http.Request) bool {
|
||||
origin := r.Header["Origin"]
|
||||
|
||||
if len(origin) != 0 {
|
||||
u, err := url.Parse(origin[0])
|
||||
if err != nil {
|
||||
logger.Log().Errorf("CORS origin parse error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
_, allAllowed := corsAllowOrigins["*"]
|
||||
// allow same origin || is "*" is defined as an origin
|
||||
if asciiFoldString(u.Host) == asciiFoldString(r.Host) || allAllowed {
|
||||
return true
|
||||
}
|
||||
|
||||
originHostFold := asciiFoldString(u.Hostname())
|
||||
if corsAllowOrigins[originHostFold] {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// SetCORSOrigins sets the allowed CORS origins from a comma-separated string.
|
||||
// It does not consider port or protocol, only the hostname.
|
||||
func setCORSOrigins() {
|
||||
corsAllowOrigins = make(map[string]bool)
|
||||
|
||||
hosts := extractOrigins(AccessControlAllowOrigin)
|
||||
for _, host := range hosts {
|
||||
corsAllowOrigins[asciiFoldString(host)] = true
|
||||
}
|
||||
|
||||
if _, wildCard := corsAllowOrigins["*"]; wildCard {
|
||||
// reset to just wildcard
|
||||
corsAllowOrigins = make(map[string]bool)
|
||||
corsAllowOrigins["*"] = true
|
||||
logger.Log().Info("[cors] all origins are allowed due to wildcard \"*\"")
|
||||
} else {
|
||||
keys := make([]string, 0)
|
||||
for k := range corsAllowOrigins {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
logger.Log().Infof("[cors] allowed API origins: %v", strings.Join(keys, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
// extractOrigins extracts and returns a sorted list of origins from a comma-separated string.
|
||||
func extractOrigins(str string) []string {
|
||||
origins := make([]string, 0)
|
||||
s := strings.TrimSpace(str)
|
||||
if s == "" {
|
||||
return origins
|
||||
}
|
||||
|
||||
hosts := strings.FieldsFunc(s, func(r rune) bool {
|
||||
return r == ',' || r == ' '
|
||||
})
|
||||
|
||||
for _, host := range hosts {
|
||||
h := strings.TrimSpace(host)
|
||||
if h != "" {
|
||||
if h == "*" {
|
||||
return []string{"*"}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(h, "http://") && !strings.HasPrefix(h, "https://") {
|
||||
h = "http://" + h
|
||||
}
|
||||
|
||||
u, err := url.Parse(h)
|
||||
if err != nil || u.Hostname() == "" || strings.Contains(h, "*") {
|
||||
logger.Log().Warnf("[cors] invalid CORS origin \"%s\", ignoring", h)
|
||||
continue
|
||||
}
|
||||
|
||||
origins = append(origins, u.Hostname())
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(origins)
|
||||
|
||||
return origins
|
||||
}
|
||||
@@ -32,21 +32,23 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
// AccessControlAllowOrigin CORS policy
|
||||
AccessControlAllowOrigin string
|
||||
|
||||
// htmlPreviewRouteRe is a regexp to match the HTML preview route
|
||||
htmlPreviewRouteRe *regexp.Regexp
|
||||
)
|
||||
|
||||
// Listen will start the httpd
|
||||
func Listen() {
|
||||
setCORSOrigins()
|
||||
|
||||
isReady := &atomic.Value{}
|
||||
isReady.Store(false)
|
||||
stats.Track()
|
||||
|
||||
websockets.MessageHub = websockets.NewHub()
|
||||
|
||||
// set allowed websocket origins from configuration
|
||||
// websockets.SetAllowedOrigins(AccessControlAllowWSOrigins)
|
||||
|
||||
go websockets.MessageHub.Run()
|
||||
|
||||
go pop3.Run()
|
||||
@@ -287,9 +289,12 @@ func middleWareFunc(fn http.HandlerFunc) http.HandlerFunc {
|
||||
htmlPreviewRouteRe = regexp.MustCompile(`^` + regexp.QuoteMeta(config.Webroot) + `view/[a-zA-Z0-9]+\.html$`)
|
||||
}
|
||||
|
||||
if AccessControlAllowOrigin != "" &&
|
||||
(strings.HasPrefix(r.RequestURI, config.Webroot+"api/") || htmlPreviewRouteRe.MatchString(r.RequestURI)) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", AccessControlAllowOrigin)
|
||||
if strings.HasPrefix(r.RequestURI, config.Webroot+"api/") || htmlPreviewRouteRe.MatchString(r.RequestURI) {
|
||||
if allowed := corsOriginAccessControl(r); !allowed {
|
||||
http.Error(w, "Unauthorised.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, PUT, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
}
|
||||
@@ -331,6 +336,12 @@ func addSlashToWebroot(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Websocket to broadcast changes
|
||||
func apiWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
if allowed := corsOriginAccessControl(r); !allowed {
|
||||
http.Error(w, "Unauthorised.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
w.Header().Set("Access-Control-Allow-Headers", "*")
|
||||
websockets.ServeWs(websockets.MessageHub, w, r)
|
||||
storage.BroadcastMailboxStats()
|
||||
}
|
||||
|
||||
@@ -35,6 +35,10 @@ var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
EnableCompression: true,
|
||||
CheckOrigin: func(_ *http.Request) bool {
|
||||
// origin is checked via server.go's CORS settings
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// Client is a middleman between the websocket connection and the hub.
|
||||
|
||||
Reference in New Issue
Block a user