Chore: Add message dump max-message-size flag and refactor message handling

This commit is contained in:
Ralph Slooten
2026-05-14 15:58:47 +12:00
parent 5ec074208c
commit b7e4146dbf
2 changed files with 131 additions and 61 deletions

View File

@@ -30,7 +30,8 @@ func init() {
dumpCmd.Flags().SortFlags = false
dumpCmd.Flags().StringVar(&config.Database, "database", config.Database, "Dump messages directly from a database file")
dumpCmd.Flags().StringVar(&config.TenantID, "tenant-id", config.TenantID, "Database tenant ID to isolate data (optional)")
dumpCmd.Flags().StringVar(&dump.URL, "http", dump.URL, "Dump messages via HTTP API (base URL of running Mailpit instance)")
dumpCmd.Flags().IntVar(&config.MaxMessageSize, "max-message-size", config.MaxMessageSize, "Maximum message size in MB (0 = unlimited)")
dumpCmd.Flags().StringVar(&config.TenantID, "tenant-id", config.TenantID, "Database tenant ID to isolate data (optional)")
dumpCmd.Flags().BoolVarP(&logger.VerboseLogging, "verbose", "v", logger.VerboseLogging, "Verbose logging")
}

View File

@@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
@@ -20,17 +21,16 @@ import (
)
// httpClient bounds each remote request so a slow or hostile --http endpoint
// cannot hang the dump indefinitely. Body size is independently capped by
// maxRawSize / maxSummarySize via io.LimitReader.
// cannot hang the dump indefinitely.
var httpClient = &http.Client{Timeout: time.Minute}
// maxRawSize caps the bytes read per remote message to prevent a hostile
// server from exhausting local disk via an unbounded response body.
const maxRawSize = 50 * 1024 * 1024 // 50 MiB
// maxSummarySize caps the bytes read from the remote messages-summary endpoint
// to prevent a hostile server from exhausting memory via an unbounded response.
const maxSummarySize = 1000 * 1024 * 1024 // 1000 MiB
const maxSummarySize = 20 * 1024 * 1024 // 20 MiB
// pageSize is the per-request limit when paging through the remote messages
// summary endpoint.
const pageSize = 10000
var (
linkRe = regexp.MustCompile(`(?i)^https?:\/\/`)
@@ -46,7 +46,9 @@ var (
// URL is the base URL of a remove Mailpit instance
URL string
summary = []storage.MessageSummary{}
dumpIDs = make(map[string]struct {
Timestamp time.Time
})
)
// Sync will sync all messages from the specified database or API to the specified output directory
@@ -88,67 +90,117 @@ func loadIDs() error {
if base != "" {
// remote
logger.Log().Debugf("Fetching messages summary from %s", base)
res, err := httpClient.Get(base + "api/v1/messages?limit=0")
if err != nil {
return err
start := 0
var total uint64
for {
data, err := fetchSummaryPage(start)
if err != nil {
return err
}
if start == 0 {
total = data.Total
}
for _, m := range data.Messages {
dumpIDs[m.ID] = struct {
Timestamp time.Time
}{Timestamp: m.Created}
}
logger.Log().Debugf("Fetched messages summary page start=%d size=%d (%d/%d)", start, len(data.Messages), len(dumpIDs), total)
// stop on empty page to guard against stale/inconsistent Total
if len(data.Messages) == 0 {
break
}
if uint64(len(dumpIDs)) >= total {
break
}
start += pageSize
}
if res.StatusCode != http.StatusOK {
res.Body.Close()
return errors.New("error fetching messages summary: HTTP " + res.Status)
}
body, err := io.ReadAll(io.LimitReader(res.Body, maxSummarySize+1))
if err != nil {
return err
}
if int64(len(body)) > maxSummarySize {
return errors.New("messages summary exceeds size cap")
}
var data apiv1.MessagesSummary
if err := json.Unmarshal(body, &data); err != nil {
return err
}
summary = data.Messages
} else {
// make sure the database isn't pruned while open
config.MaxMessages = 0
var err error
// local database
if err = storage.InitDB(); err != nil {
if err := storage.InitDB(); err != nil {
return err
}
logger.Log().Debugf("Fetching messages summary from %s", config.Database)
summary, err = storage.List(0, 0, 0)
if err != nil {
return err
start := 0
for {
page, err := storage.List(start, 0, pageSize)
if err != nil {
return err
}
for _, m := range page {
dumpIDs[m.ID] = struct {
Timestamp time.Time
}{Timestamp: m.Created}
}
if len(page) < pageSize {
break
}
start += pageSize
}
}
if len(summary) == 0 {
if len(dumpIDs) == 0 {
return errors.New("no messages found")
}
return nil
}
// fetchSummaryPage fetches a single page of the remote messages summary,
// starting at the given offset.
func fetchSummaryPage(start int) (*apiv1.MessagesSummary, error) {
url := base + "api/v1/messages?limit=" + strconv.Itoa(pageSize) + "&start=" + strconv.Itoa(start)
res, err := httpClient.Get(url)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, errors.New("error fetching messages summary: HTTP " + res.Status)
}
body, err := io.ReadAll(io.LimitReader(res.Body, maxSummarySize+1))
if err != nil {
return nil, err
}
if int64(len(body)) > maxSummarySize {
return nil, errors.New("messages summary exceeds size cap")
}
var data apiv1.MessagesSummary
if err := json.Unmarshal(body, &data); err != nil {
return nil, err
}
return &data, nil
}
func saveMessages() error {
for _, m := range summary {
if !idRe.MatchString(m.ID) {
logger.Log().Errorf("skipping message with invalid ID: %q", m.ID)
for id, m := range dumpIDs {
if !idRe.MatchString(id) {
logger.Log().Errorf("skipping message with invalid ID: %q", id)
continue
}
out := filepath.Join(outDir, m.ID+".eml")
out := filepath.Join(outDir, id+".eml")
// skip if message exists
if tools.IsFile(out) {
@@ -157,49 +209,66 @@ func saveMessages() error {
var b []byte
limit := int64(config.MaxMessageSize) * 1024 * 1024
if base != "" {
res, err := httpClient.Get(base + "api/v1/message/" + m.ID + "/raw")
res, err := httpClient.Get(base + "api/v1/message/" + id + "/raw")
if err != nil {
logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error())
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
continue
}
if res.StatusCode != http.StatusOK {
res.Body.Close()
logger.Log().Errorf("error fetching message %s: HTTP %d", m.ID, res.StatusCode)
logger.Log().Errorf("error fetching message %s: HTTP %d", id, res.StatusCode)
continue
}
b, err = io.ReadAll(io.LimitReader(res.Body, maxRawSize+1))
res.Body.Close()
if config.MaxMessageSize > 0 {
b, err = io.ReadAll(io.LimitReader(res.Body, limit+1))
res.Body.Close()
if err != nil {
logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error())
continue
}
if err != nil {
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
continue
}
if len(b) > maxRawSize {
logger.Log().Errorf("message %s exceeds size cap (%d bytes), skipping", m.ID, maxRawSize)
continue
if int64(len(b)) > limit {
logger.Log().Warnf("message %s exceeds %d MiB size cap, skipping", id, config.MaxMessageSize)
continue
}
} else {
b, err = io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
continue
}
}
} else {
var err error
b, err = storage.GetMessageRaw(m.ID)
b, err = storage.GetMessageRaw(id)
if err != nil {
logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error())
logger.Log().Errorf("error fetching message %s: %s", id, err.Error())
continue
}
if config.MaxMessageSize > 0 && int64(len(b)) > limit {
logger.Log().Warnf("message %s exceeds %d MiB size cap, skipping", id, config.MaxMessageSize)
continue
}
}
if err := os.WriteFile(out, b, 0644); /* #nosec */ err != nil {
logger.Log().Errorf("error writing message %s: %s", m.ID, err.Error())
logger.Log().Errorf("error writing message %s: %s", id, err.Error())
continue
}
_ = os.Chtimes(out, m.Created, m.Created)
_ = os.Chtimes(out, m.Timestamp, m.Timestamp)
logger.Log().Debugf("Saved message %s to %s", m.ID, out)
logger.Log().Debugf("Saved message %s to %s", id, out)
}
return nil