mirror of
https://github.com/axllent/mailpit.git
synced 2026-06-27 22:46:09 +00:00
Chore: Add message dump max-message-size flag and refactor message handling
This commit is contained in:
@@ -30,7 +30,8 @@ func init() {
|
||||
dumpCmd.Flags().SortFlags = false
|
||||
|
||||
dumpCmd.Flags().StringVar(&config.Database, "database", config.Database, "Dump messages directly from a database file")
|
||||
dumpCmd.Flags().StringVar(&config.TenantID, "tenant-id", config.TenantID, "Database tenant ID to isolate data (optional)")
|
||||
dumpCmd.Flags().StringVar(&dump.URL, "http", dump.URL, "Dump messages via HTTP API (base URL of running Mailpit instance)")
|
||||
dumpCmd.Flags().IntVar(&config.MaxMessageSize, "max-message-size", config.MaxMessageSize, "Maximum message size in MB (0 = unlimited)")
|
||||
dumpCmd.Flags().StringVar(&config.TenantID, "tenant-id", config.TenantID, "Database tenant ID to isolate data (optional)")
|
||||
dumpCmd.Flags().BoolVarP(&logger.VerboseLogging, "verbose", "v", logger.VerboseLogging, "Verbose logging")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,17 +21,16 @@ import (
|
||||
)
|
||||
|
||||
// httpClient bounds each remote request so a slow or hostile --http endpoint
|
||||
// cannot hang the dump indefinitely. Body size is independently capped by
|
||||
// maxRawSize / maxSummarySize via io.LimitReader.
|
||||
// cannot hang the dump indefinitely.
|
||||
var httpClient = &http.Client{Timeout: time.Minute}
|
||||
|
||||
// maxRawSize caps the bytes read per remote message to prevent a hostile
|
||||
// server from exhausting local disk via an unbounded response body.
|
||||
const maxRawSize = 50 * 1024 * 1024 // 50 MiB
|
||||
|
||||
// maxSummarySize caps the bytes read from the remote messages-summary endpoint
|
||||
// to prevent a hostile server from exhausting memory via an unbounded response.
|
||||
const maxSummarySize = 1000 * 1024 * 1024 // 1000 MiB
|
||||
const maxSummarySize = 20 * 1024 * 1024 // 20 MiB
|
||||
|
||||
// pageSize is the per-request limit when paging through the remote messages
|
||||
// summary endpoint.
|
||||
const pageSize = 10000
|
||||
|
||||
var (
|
||||
linkRe = regexp.MustCompile(`(?i)^https?:\/\/`)
|
||||
@@ -46,7 +46,9 @@ var (
|
||||
// URL is the base URL of a remove Mailpit instance
|
||||
URL string
|
||||
|
||||
summary = []storage.MessageSummary{}
|
||||
dumpIDs = make(map[string]struct {
|
||||
Timestamp time.Time
|
||||
})
|
||||
)
|
||||
|
||||
// Sync will sync all messages from the specified database or API to the specified output directory
|
||||
@@ -88,67 +90,117 @@ func loadIDs() error {
|
||||
if base != "" {
|
||||
// remote
|
||||
logger.Log().Debugf("Fetching messages summary from %s", base)
|
||||
res, err := httpClient.Get(base + "api/v1/messages?limit=0")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
start := 0
|
||||
var total uint64
|
||||
for {
|
||||
data, err := fetchSummaryPage(start)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if start == 0 {
|
||||
total = data.Total
|
||||
}
|
||||
|
||||
for _, m := range data.Messages {
|
||||
dumpIDs[m.ID] = struct {
|
||||
Timestamp time.Time
|
||||
}{Timestamp: m.Created}
|
||||
}
|
||||
|
||||
logger.Log().Debugf("Fetched messages summary page start=%d size=%d (%d/%d)", start, len(data.Messages), len(dumpIDs), total)
|
||||
|
||||
// stop on empty page to guard against stale/inconsistent Total
|
||||
if len(data.Messages) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if uint64(len(dumpIDs)) >= total {
|
||||
break
|
||||
}
|
||||
|
||||
start += pageSize
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
res.Body.Close()
|
||||
return errors.New("error fetching messages summary: HTTP " + res.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(res.Body, maxSummarySize+1))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if int64(len(body)) > maxSummarySize {
|
||||
return errors.New("messages summary exceeds size cap")
|
||||
}
|
||||
|
||||
var data apiv1.MessagesSummary
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
summary = data.Messages
|
||||
|
||||
} else {
|
||||
// make sure the database isn't pruned while open
|
||||
config.MaxMessages = 0
|
||||
|
||||
var err error
|
||||
// local database
|
||||
if err = storage.InitDB(); err != nil {
|
||||
if err := storage.InitDB(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Log().Debugf("Fetching messages summary from %s", config.Database)
|
||||
|
||||
summary, err = storage.List(0, 0, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
start := 0
|
||||
for {
|
||||
page, err := storage.List(start, 0, pageSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, m := range page {
|
||||
dumpIDs[m.ID] = struct {
|
||||
Timestamp time.Time
|
||||
}{Timestamp: m.Created}
|
||||
}
|
||||
|
||||
if len(page) < pageSize {
|
||||
break
|
||||
}
|
||||
|
||||
start += pageSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(summary) == 0 {
|
||||
if len(dumpIDs) == 0 {
|
||||
return errors.New("no messages found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchSummaryPage fetches a single page of the remote messages summary,
|
||||
// starting at the given offset.
|
||||
func fetchSummaryPage(start int) (*apiv1.MessagesSummary, error) {
|
||||
url := base + "api/v1/messages?limit=" + strconv.Itoa(pageSize) + "&start=" + strconv.Itoa(start)
|
||||
res, err := httpClient.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, errors.New("error fetching messages summary: HTTP " + res.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(res.Body, maxSummarySize+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if int64(len(body)) > maxSummarySize {
|
||||
return nil, errors.New("messages summary exceeds size cap")
|
||||
}
|
||||
|
||||
var data apiv1.MessagesSummary
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func saveMessages() error {
|
||||
for _, m := range summary {
|
||||
if !idRe.MatchString(m.ID) {
|
||||
logger.Log().Errorf("skipping message with invalid ID: %q", m.ID)
|
||||
for id, m := range dumpIDs {
|
||||
if !idRe.MatchString(id) {
|
||||
logger.Log().Errorf("skipping message with invalid ID: %q", id)
|
||||
continue
|
||||
}
|
||||
|
||||
out := filepath.Join(outDir, m.ID+".eml")
|
||||
out := filepath.Join(outDir, id+".eml")
|
||||
|
||||
// skip if message exists
|
||||
if tools.IsFile(out) {
|
||||
@@ -157,49 +209,66 @@ func saveMessages() error {
|
||||
|
||||
var b []byte
|
||||
|
||||
limit := int64(config.MaxMessageSize) * 1024 * 1024
|
||||
|
||||
if base != "" {
|
||||
res, err := httpClient.Get(base + "api/v1/message/" + m.ID + "/raw")
|
||||
res, err := httpClient.Get(base + "api/v1/message/" + id + "/raw")
|
||||
|
||||
if err != nil {
|
||||
logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error())
|
||||
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
res.Body.Close()
|
||||
logger.Log().Errorf("error fetching message %s: HTTP %d", m.ID, res.StatusCode)
|
||||
logger.Log().Errorf("error fetching message %s: HTTP %d", id, res.StatusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
b, err = io.ReadAll(io.LimitReader(res.Body, maxRawSize+1))
|
||||
res.Body.Close()
|
||||
if config.MaxMessageSize > 0 {
|
||||
b, err = io.ReadAll(io.LimitReader(res.Body, limit+1))
|
||||
res.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error())
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if len(b) > maxRawSize {
|
||||
logger.Log().Errorf("message %s exceeds size cap (%d bytes), skipping", m.ID, maxRawSize)
|
||||
continue
|
||||
if int64(len(b)) > limit {
|
||||
logger.Log().Warnf("message %s exceeds %d MiB size cap, skipping", id, config.MaxMessageSize)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
b, err = io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
b, err = storage.GetMessageRaw(m.ID)
|
||||
b, err = storage.GetMessageRaw(id)
|
||||
if err != nil {
|
||||
logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error())
|
||||
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
if config.MaxMessageSize > 0 && int64(len(b)) > limit {
|
||||
logger.Log().Warnf("message %s exceeds %d MiB size cap, skipping", id, config.MaxMessageSize)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.WriteFile(out, b, 0644); /* #nosec */ err != nil {
|
||||
logger.Log().Errorf("error writing message %s: %s", m.ID, err.Error())
|
||||
logger.Log().Errorf("error writing message %s: %s", id, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
_ = os.Chtimes(out, m.Created, m.Created)
|
||||
_ = os.Chtimes(out, m.Timestamp, m.Timestamp)
|
||||
|
||||
logger.Log().Debugf("Saved message %s to %s", m.ID, out)
|
||||
logger.Log().Debugf("Saved message %s to %s", id, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user