diff --git a/internal/storage/tags.go b/internal/storage/tags.go index 464d75b..a2d3c19 100644 --- a/internal/storage/tags.go +++ b/internal/storage/tags.go @@ -8,7 +8,6 @@ import ( "regexp" "sort" "strings" - "sync" "github.com/axllent/mailpit/config" "github.com/axllent/mailpit/internal/logger" @@ -19,7 +18,6 @@ import ( var ( addressPlusRe = regexp.MustCompile(`(?U)^(.*){1,}\+(.*)@`) - addTagMutex sync.RWMutex ) // SetMessageTags will set the tags for a given database ID, removing any not in the array @@ -91,57 +89,43 @@ func SetMessageTags(id string, tags []string) ([]string, error) { // AddMessageTag adds a tag to a message func addMessageTag(id, name string) (string, error) { - // prevent two identical tags being added at the same time - addTagMutex.Lock() - - var tagID int - var foundName sql.NullString - - q := sqlf.From(tenant("tags")). - Select("ID").To(&tagID). - Select("Name").To(&foundName). - Where("Name = ?", name) - - // if tag exists - add tag to message - if err := q.QueryRowAndClose(context.TODO(), db); err == nil { - addTagMutex.Unlock() - // check message does not already have this tag - var exists int - - if err := sqlf.From(tenant("message_tags")). - Select("COUNT(ID)").To(&exists). - Where("ID = ?", id). - Where("TagID = ?", tagID). - QueryRowAndClose(context.Background(), db); err != nil { - return "", err - } - if exists > 0 { - // already exists - return foundName.String, nil - } - - logger.Log().Debugf("[tags] adding tag \"%s\" to %s", name, id) - - _, err := sqlf.InsertInto(tenant("message_tags")). - Set("ID", id). - Set("TagID", tagID). - ExecAndClose(context.TODO(), db) - - return foundName.String, err - } - - // new tag, add to the database - if _, err := sqlf.InsertInto(tenant("tags")). - Set("Name", name). - ExecAndClose(context.TODO(), db); err != nil { - addTagMutex.Unlock() + // Ensure the tag row exists; the UNIQUE index on Name makes concurrent inserts safe + if _, err := db.Exec(fmt.Sprintf(`INSERT OR IGNORE INTO %s (Name) VALUES (?)`, tenant("tags")), name); err != nil { // #nosec return name, err } - addTagMutex.Unlock() + var tagID int + var foundName string - // add tag to the message - return addMessageTag(id, name) + if err := sqlf.From(tenant("tags")). + Select("ID").To(&tagID). + Select("Name").To(&foundName). + Where("Name = ?", name). + QueryRowAndClose(context.TODO(), db); err != nil { + return name, err + } + + // Check message does not already have this tag + var exists int + if err := sqlf.From(tenant("message_tags")). + Select("COUNT(ID)").To(&exists). + Where("ID = ?", id). + Where("TagID = ?", tagID). + QueryRowAndClose(context.Background(), db); err != nil { + return "", err + } + if exists > 0 { + return foundName, nil + } + + logger.Log().Debugf("[tags] adding tag \"%s\" to %s", name, id) + + _, err := sqlf.InsertInto(tenant("message_tags")). + Set("ID", id). + Set("TagID", tagID). + ExecAndClose(context.TODO(), db) + + return foundName, err } // deleteMessageTags deletes multiple tags from a message in a single query