diff --git a/cmd/dump.go b/cmd/dump.go index a487a15..9954c99 100644 --- a/cmd/dump.go +++ b/cmd/dump.go @@ -30,7 +30,8 @@ func init() { dumpCmd.Flags().SortFlags = false dumpCmd.Flags().StringVar(&config.Database, "database", config.Database, "Dump messages directly from a database file") - dumpCmd.Flags().StringVar(&config.TenantID, "tenant-id", config.TenantID, "Database tenant ID to isolate data (optional)") dumpCmd.Flags().StringVar(&dump.URL, "http", dump.URL, "Dump messages via HTTP API (base URL of running Mailpit instance)") + dumpCmd.Flags().IntVar(&config.MaxMessageSize, "max-message-size", config.MaxMessageSize, "Maximum message size in MB (0 = unlimited)") + dumpCmd.Flags().StringVar(&config.TenantID, "tenant-id", config.TenantID, "Database tenant ID to isolate data (optional)") dumpCmd.Flags().BoolVarP(&logger.VerboseLogging, "verbose", "v", logger.VerboseLogging, "Verbose logging") } diff --git a/internal/dump/dump.go b/internal/dump/dump.go index 4e8b609..c9258a2 100644 --- a/internal/dump/dump.go +++ b/internal/dump/dump.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "regexp" + "strconv" "strings" "time" @@ -20,17 +21,16 @@ import ( ) // httpClient bounds each remote request so a slow or hostile --http endpoint -// cannot hang the dump indefinitely. Body size is independently capped by -// maxRawSize / maxSummarySize via io.LimitReader. +// cannot hang the dump indefinitely. var httpClient = &http.Client{Timeout: time.Minute} -// maxRawSize caps the bytes read per remote message to prevent a hostile -// server from exhausting local disk via an unbounded response body. -const maxRawSize = 50 * 1024 * 1024 // 50 MiB - // maxSummarySize caps the bytes read from the remote messages-summary endpoint // to prevent a hostile server from exhausting memory via an unbounded response. -const maxSummarySize = 1000 * 1024 * 1024 // 1000 MiB +const maxSummarySize = 20 * 1024 * 1024 // 20 MiB + +// pageSize is the per-request limit when paging through the remote messages +// summary endpoint. +const pageSize = 10000 var ( linkRe = regexp.MustCompile(`(?i)^https?:\/\/`) @@ -46,7 +46,9 @@ var ( // URL is the base URL of a remove Mailpit instance URL string - summary = []storage.MessageSummary{} + dumpIDs = make(map[string]struct { + Timestamp time.Time + }) ) // Sync will sync all messages from the specified database or API to the specified output directory @@ -88,67 +90,117 @@ func loadIDs() error { if base != "" { // remote logger.Log().Debugf("Fetching messages summary from %s", base) - res, err := httpClient.Get(base + "api/v1/messages?limit=0") - if err != nil { - return err + start := 0 + var total uint64 + for { + data, err := fetchSummaryPage(start) + if err != nil { + return err + } + + if start == 0 { + total = data.Total + } + + for _, m := range data.Messages { + dumpIDs[m.ID] = struct { + Timestamp time.Time + }{Timestamp: m.Created} + } + + logger.Log().Debugf("Fetched messages summary page start=%d size=%d (%d/%d)", start, len(data.Messages), len(dumpIDs), total) + + // stop on empty page to guard against stale/inconsistent Total + if len(data.Messages) == 0 { + break + } + + if uint64(len(dumpIDs)) >= total { + break + } + + start += pageSize } - if res.StatusCode != http.StatusOK { - res.Body.Close() - return errors.New("error fetching messages summary: HTTP " + res.Status) - } - - body, err := io.ReadAll(io.LimitReader(res.Body, maxSummarySize+1)) - - if err != nil { - return err - } - - if int64(len(body)) > maxSummarySize { - return errors.New("messages summary exceeds size cap") - } - - var data apiv1.MessagesSummary - if err := json.Unmarshal(body, &data); err != nil { - return err - } - - summary = data.Messages - } else { // make sure the database isn't pruned while open config.MaxMessages = 0 - var err error // local database - if err = storage.InitDB(); err != nil { + if err := storage.InitDB(); err != nil { return err } logger.Log().Debugf("Fetching messages summary from %s", config.Database) - summary, err = storage.List(0, 0, 0) - if err != nil { - return err + start := 0 + for { + page, err := storage.List(start, 0, pageSize) + if err != nil { + return err + } + + for _, m := range page { + dumpIDs[m.ID] = struct { + Timestamp time.Time + }{Timestamp: m.Created} + } + + if len(page) < pageSize { + break + } + + start += pageSize } } - if len(summary) == 0 { + if len(dumpIDs) == 0 { return errors.New("no messages found") } return nil } +// fetchSummaryPage fetches a single page of the remote messages summary, +// starting at the given offset. +func fetchSummaryPage(start int) (*apiv1.MessagesSummary, error) { + url := base + "api/v1/messages?limit=" + strconv.Itoa(pageSize) + "&start=" + strconv.Itoa(start) + res, err := httpClient.Get(url) + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return nil, errors.New("error fetching messages summary: HTTP " + res.Status) + } + + body, err := io.ReadAll(io.LimitReader(res.Body, maxSummarySize+1)) + if err != nil { + return nil, err + } + + if int64(len(body)) > maxSummarySize { + return nil, errors.New("messages summary exceeds size cap") + } + + var data apiv1.MessagesSummary + if err := json.Unmarshal(body, &data); err != nil { + return nil, err + } + + return &data, nil +} + func saveMessages() error { - for _, m := range summary { - if !idRe.MatchString(m.ID) { - logger.Log().Errorf("skipping message with invalid ID: %q", m.ID) + for id, m := range dumpIDs { + if !idRe.MatchString(id) { + logger.Log().Errorf("skipping message with invalid ID: %q", id) continue } - out := filepath.Join(outDir, m.ID+".eml") + out := filepath.Join(outDir, id+".eml") // skip if message exists if tools.IsFile(out) { @@ -157,49 +209,66 @@ func saveMessages() error { var b []byte + limit := int64(config.MaxMessageSize) * 1024 * 1024 + if base != "" { - res, err := httpClient.Get(base + "api/v1/message/" + m.ID + "/raw") + res, err := httpClient.Get(base + "api/v1/message/" + id + "/raw") if err != nil { - logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error()) + logger.Log().Errorf("error fetching message %s: %s", id, err.Error()) continue } if res.StatusCode != http.StatusOK { res.Body.Close() - logger.Log().Errorf("error fetching message %s: HTTP %d", m.ID, res.StatusCode) + logger.Log().Errorf("error fetching message %s: HTTP %d", id, res.StatusCode) continue } - b, err = io.ReadAll(io.LimitReader(res.Body, maxRawSize+1)) - res.Body.Close() + if config.MaxMessageSize > 0 { + b, err = io.ReadAll(io.LimitReader(res.Body, limit+1)) + res.Body.Close() - if err != nil { - logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error()) - continue - } + if err != nil { + logger.Log().Errorf("error fetching message %s: %s", id, err.Error()) + continue + } - if len(b) > maxRawSize { - logger.Log().Errorf("message %s exceeds size cap (%d bytes), skipping", m.ID, maxRawSize) - continue + if int64(len(b)) > limit { + logger.Log().Warnf("message %s exceeds %d MiB size cap, skipping", id, config.MaxMessageSize) + continue + } + } else { + b, err = io.ReadAll(res.Body) + res.Body.Close() + + if err != nil { + logger.Log().Errorf("error fetching message %s: %s", id, err.Error()) + continue + } } } else { var err error - b, err = storage.GetMessageRaw(m.ID) + b, err = storage.GetMessageRaw(id) if err != nil { - logger.Log().Errorf("error fetching message %s: %s", m.ID, err.Error()) + logger.Log().Errorf("error fetching message %s: %s", id, err.Error()) + continue + } + + if config.MaxMessageSize > 0 && int64(len(b)) > limit { + logger.Log().Warnf("message %s exceeds %d MiB size cap, skipping", id, config.MaxMessageSize) continue } } if err := os.WriteFile(out, b, 0644); /* #nosec */ err != nil { - logger.Log().Errorf("error writing message %s: %s", m.ID, err.Error()) + logger.Log().Errorf("error writing message %s: %s", id, err.Error()) continue } - _ = os.Chtimes(out, m.Created, m.Created) + _ = os.Chtimes(out, m.Timestamp, m.Timestamp) - logger.Log().Debugf("Saved message %s to %s", m.ID, out) + logger.Log().Debugf("Saved message %s to %s", id, out) } return nil