Files
cloudreve/inventory/share.go
WittF e4e6beb52d feat(share): add Open Graph preview for social media crawlers (#3234)
* feat(share): add Open Graph preview for social media crawlers

Add middleware to intercept social media bot requests and return
OG meta tags for share links, enabling rich previews on platforms
like Facebook, Twitter, Discord, etc.

* feat(og): reuse existing share service and show thumbnail if possible

* resolve comments

---------

Co-authored-by: Aaron Liu <abslant.liu@gmail.com>
2026-01-25 12:48:14 +08:00

428 lines
12 KiB
Go

package inventory
import (
"context"
"fmt"
"time"
"github.com/cloudreve/Cloudreve/v4/inventory/types"
"entgo.io/ent/dialect/sql"
"github.com/cloudreve/Cloudreve/v4/ent"
"github.com/cloudreve/Cloudreve/v4/ent/file"
"github.com/cloudreve/Cloudreve/v4/ent/predicate"
"github.com/cloudreve/Cloudreve/v4/ent/share"
"github.com/cloudreve/Cloudreve/v4/ent/user"
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
"github.com/samber/lo"
)
// LoadShareFile is a context key. When a context carries LoadShareFile{} with
// the value true, queries built through withShareEagerLoading also fetch each
// share's file edge.
type LoadShareFile struct{}

// LoadShareUser is a context key. When a context carries LoadShareUser{} with
// the value true, queries built through withShareEagerLoading also fetch each
// share's owning user edge.
type LoadShareUser struct{}
// ErrShareLinkExpired is returned by IsShareExpired (and therefore IsValidShare)
// when a share is past its expiry time or has no remaining downloads.
var ErrShareLinkExpired = fmt.Errorf("share link expired")

// ErrOwnerInactive is returned by IsValidShare when the share's owner edge is
// unavailable or the owner's status is not active.
var ErrOwnerInactive = fmt.Errorf("owner is inactive")

// ErrSourceFileInvalid is returned by IsValidShare when the shared file edge is
// unavailable, its FileChildren value is 0, or it is no longer owned by the
// share's owner.
var ErrSourceFileInvalid = fmt.Errorf("source file is deleted")
// ShareClient is the data-access interface for share links.
type ShareClient interface {
	TxOperator

	// GetByIDs returns the shares whose IDs are listed in ids.
	GetByIDs(ctx context.Context, ids []int) ([]*ent.Share, error)
	// GetByID returns the share with the given ID.
	GetByID(ctx context.Context, id int) (*ent.Share, error)
	// GetByIDUser returns the share with the given ID only if it belongs to user uid.
	GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error)
	// GetByHashID decodes the public hash ID idRaw and returns the matching share.
	GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error)
	// Upsert updates params.Existed when it is set, and creates a new share otherwise.
	Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error)
	// Viewed increments the share's view counter.
	Viewed(ctx context.Context, share *ent.Share) error
	// Downloaded increments the share's download counter and, for shares with a
	// download limit, decrements the remaining-downloads counter.
	Downloaded(ctx context.Context, share *ent.Share) error
	// Delete removes the share with the given ID.
	Delete(ctx context.Context, shareId int) error
	// List returns one page of shares matching args.
	List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error)
	// CountByTimeRange counts shares created in [start, end); if either bound is
	// nil, every share is counted.
	CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
	// DeleteBatch removes all shares whose IDs are listed in shareIds.
	DeleteBatch(ctx context.Context, shareIds []int) error
}

// CreateShareParams carries the input to ShareClient.Upsert.
type CreateShareParams struct {
	// Existed, when non-nil, is the share to update in place instead of creating one.
	Existed *ent.Share
	// Password is applied only when creating a share, and only if non-empty.
	Password string
	// RemainDownloads limits the number of downloads when > 0; any other value
	// means the share has no download limit.
	RemainDownloads int
	// Expires is the expiry time; nil means the share does not expire.
	Expires *time.Time
	// OwnerID and FileID identify the owner and the shared file; they are only
	// used when a new share is created.
	OwnerID int
	FileID  int
	// Props holds optional share properties; nil leaves them unset/unchanged.
	Props *types.ShareProps
}

// ListShareArgs holds the filters and pagination settings for ShareClient.List.
// Zero-valued filters are ignored.
type ListShareArgs struct {
	*PaginationArgs
	// UserID, when > 0, restricts results to shares owned by this user.
	UserID int
	// FileID, when > 0, restricts results to shares of this file.
	FileID int
	// PublicOnly restricts results to shares without a password.
	PublicOnly bool
	// ShareIDs, when non-empty, restricts results to these share IDs.
	ShareIDs []int
}

// ListShareResult is one page of shares together with its pagination metadata.
type ListShareResult struct {
	*PaginationResults
	Shares []*ent.Share
}
// NewShareClient returns a ShareClient that queries through client. The SQL
// parameter limit used to cap page sizes is derived from dbType via
// sqlParamLimit, and hasher is used to encode/decode public share IDs.
func NewShareClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) ShareClient {
	paramLimit := sqlParamLimit(dbType)
	return &shareClient{
		hasher:      hasher,
		client:      client,
		maxSQlParam: paramLimit,
	}
}
// shareClient is the ent-backed implementation of ShareClient.
type shareClient struct {
	// maxSQlParam is the per-statement SQL parameter limit for the configured
	// database type; it is passed to capPageSize when paginating.
	maxSQlParam int
	client      *ent.Client
	// hasher decodes share hash IDs and encodes/decodes pagination tokens.
	hasher hashid.Encoder
}
// SetClient returns a copy of c that issues all queries through newClient;
// every other setting of c is carried over unchanged. c itself is not modified.
func (c *shareClient) SetClient(newClient *ent.Client) TxOperator {
	clone := *c
	clone.client = newClient
	return &clone
}

// GetClient returns the ent client c currently queries through.
func (c *shareClient) GetClient() *ent.Client {
	current := c.client
	return current
}
// CountByTimeRange counts shares whose creation time lies in [start, end).
// If either bound is nil the time filter is dropped entirely and every share
// is counted.
func (c *shareClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
	q := c.client.Share.Query()
	if start != nil && end != nil {
		q = q.Where(share.CreatedAtGTE(*start), share.CreatedAtLT(*end))
	}
	return q.Count(ctx)
}
// Upsert persists a share.
//
// When params.Existed is non-nil, that record is updated:
//   - RemainDownloads > 0 sets the download limit; otherwise the limit is cleared.
//   - A non-nil Expires sets the expiry; nil clears it.
//   - A non-nil Props replaces the properties.
//
// Password, OwnerID and FileID are ignored on this path.
//
// Otherwise a new share is created for params.OwnerID / params.FileID. Password,
// RemainDownloads, Expires and Props are applied only when non-empty, positive
// or non-nil respectively.
func (c *shareClient) Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) {
	if existing := params.Existed; existing != nil {
		upd := c.client.Share.UpdateOne(existing)

		if params.RemainDownloads <= 0 {
			upd.ClearRemainDownloads()
		} else {
			upd.SetRemainDownloads(params.RemainDownloads)
		}

		if params.Expires == nil {
			upd.ClearExpires()
		} else {
			upd.SetNillableExpires(params.Expires)
		}

		if props := params.Props; props != nil {
			upd.SetProps(props)
		}

		return upd.Save(ctx)
	}

	create := c.client.Share.Create().
		SetUserID(params.OwnerID).
		SetFileID(params.FileID)

	if pwd := params.Password; pwd != "" {
		create.SetPassword(pwd)
	}
	if limit := params.RemainDownloads; limit > 0 {
		create.SetRemainDownloads(limit)
	}
	if expiry := params.Expires; expiry != nil {
		create.SetNillableExpires(expiry)
	}
	if props := params.Props; props != nil {
		create.SetProps(props)
	}

	return create.Save(ctx)
}
// GetByHashID decodes idRaw as a share hash ID and loads the referenced share
// through GetByID, so the eager-loading flags carried by ctx apply. A decode
// failure is wrapped with the offending input.
func (c *shareClient) GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) {
	shareID, decodeErr := c.hasher.Decode(idRaw, hashid.ShareID)
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to decode hash id %q: %w", idRaw, decodeErr)
	}

	return c.GetByID(ctx, shareID)
}
// GetByID loads the share with the given ID, eager-loading whichever edges ctx
// requests (see withShareEagerLoading). Query errors are wrapped with %w so
// callers can still inspect the underlying ent error.
func (c *shareClient) GetByID(ctx context.Context, id int) (*ent.Share, error) {
	byID := c.client.Share.Query().Where(share.ID(id))

	found, err := withShareEagerLoading(ctx, byID).First(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to query share %d: %w", id, err)
	}
	return found, nil
}
// GetByIDUser loads the share with the given ID, but only if it is owned by the
// user with ID uid; otherwise the query fails just as for a missing share.
// Eager-loading flags in ctx are honoured and errors are wrapped with %w.
func (c *shareClient) GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) {
	owned := withShareEagerLoading(ctx, c.client.Share.Query()).
		Where(share.ID(id), share.HasUserWith(user.ID(uid)))

	found, err := owned.First(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to query share %d: %w", id, err)
	}
	return found, nil
}
// GetByIDs loads every share whose ID appears in ids, eager-loading the edges
// requested by ctx. No ordering is applied to the result. Query errors are
// wrapped together with the requested IDs.
func (c *shareClient) GetByIDs(ctx context.Context, ids []int) ([]*ent.Share, error) {
	byIDs := c.client.Share.Query().Where(share.IDIn(ids...))

	shares, err := withShareEagerLoading(ctx, byIDs).All(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to query shares %v: %w", ids, err)
	}
	return shares, nil
}
// DeleteBatch removes every share whose ID is listed in shareIds. The number of
// affected rows is discarded; only the execution error is returned.
func (c *shareClient) DeleteBatch(ctx context.Context, shareIds []int) error {
	if _, err := c.client.Share.Delete().Where(share.IDIn(shareIds...)).Exec(ctx); err != nil {
		return err
	}
	return nil
}

// Delete removes the single share identified by shareId.
func (c *shareClient) Delete(ctx context.Context, shareId int) error {
	deletion := c.client.Share.DeleteOneID(shareId)
	return deletion.Exec(ctx)
}
// Viewed adds one to the stored view counter of share. Only share.ID is read
// from the argument; the updated entity is discarded.
func (c *shareClient) Viewed(ctx context.Context, share *ent.Share) error {
	return c.client.Share.UpdateOneID(share.ID).AddViews(1).Exec(ctx)
}
// Downloaded records one download of share. The download counter is always
// incremented. When the share has a download limit (RemainDownloads non-nil and
// >= 0), the remaining-downloads counter is decremented in the same update.
func (c *shareClient) Downloaded(ctx context.Context, share *ent.Share) error {
	upd := c.client.Share.UpdateOneID(share.ID).AddDownloads(1)

	// The limit check uses the caller's in-memory copy of share; it is not
	// re-read from the database before the update.
	if remain := share.RemainDownloads; remain != nil && *remain >= 0 {
		upd.AddRemainDownloads(-1)
	}

	return upd.Exec(ctx)
}
// IsValidShare returns nil when share may be served, and otherwise the first
// reason it may not. Checks run in this order:
//   - ErrShareLinkExpired (or whatever IsShareExpired returns) if the share has
//     expired;
//   - ErrOwnerInactive if the User edge is unavailable or the owner's status is
//     not active;
//   - ErrSourceFileInvalid if the File edge is unavailable, its FileChildren is 0,
//     or its OwnerID differs from the share owner's ID.
//
// Both edges must have been loaded on share (see LoadShareUser / LoadShareFile);
// if Edges.UserOrErr / Edges.FileOrErr return an error, the share is treated as
// invalid.
func IsValidShare(share *ent.Share) error {
	if expiredErr := IsShareExpired(share); expiredErr != nil {
		return expiredErr
	}

	owner, ownerErr := share.Edges.UserOrErr()
	if ownerErr != nil || owner.Status != user.StatusActive {
		return ErrOwnerInactive
	}

	src, srcErr := share.Edges.FileOrErr()
	if srcErr != nil || src.FileChildren == 0 || src.OwnerID != owner.ID {
		return ErrSourceFileInvalid
	}

	return nil
}
// IsShareExpired returns ErrShareLinkExpired when share's expiry time is set and
// already in the past, or when its remaining-downloads counter is set and <= 0.
// A nil Expires or nil RemainDownloads means that limit does not apply.
func IsShareExpired(share *ent.Share) error {
	now := time.Now()
	switch {
	case share.Expires != nil && share.Expires.Before(now):
		return ErrShareLinkExpired
	case share.RemainDownloads != nil && *share.RemainDownloads <= 0:
		return ErrShareLinkExpired
	default:
		return nil
	}
}
// List returns one page of shares matching the filters in args, eager-loading
// the edges requested by ctx. args.UseCursorPagination selects cursor-based
// pagination (cursorPagination); otherwise offset pagination (offsetPagination)
// is used.
func (c *shareClient) List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) {
	// paramMargin is forwarded to capPageSize together with maxSQlParam.
	// NOTE(review): presumably headroom for the filter predicates' own SQL
	// parameters — confirm against capPageSize.
	const paramMargin = 10

	query := withShareEagerLoading(ctx, c.listQuery(args))

	var (
		shares        []*ent.Share
		paginationRes *PaginationResults
		err           error
	)
	if args.UseCursorPagination {
		shares, paginationRes, err = c.cursorPagination(ctx, query, args, paramMargin)
	} else {
		shares, paginationRes, err = c.offsetPagination(ctx, query, args, paramMargin)
	}
	if err != nil {
		return nil, fmt.Errorf("query failed with pagination: %w", err)
	}

	return &ListShareResult{
		Shares:            shares,
		PaginationResults: paginationRes,
	}, nil
}
// cursorPagination fetches one keyset-paginated page of query.
//
// It orders query by getShareOrderOption(args) and, when args.PageToken is set,
// resumes after the share that token encodes. It requests pageSize+1 rows; the
// extra row only signals that another page exists, in which case NextPageToken
// encodes the last share actually returned. The page size is capped through
// capPageSize(c.maxSQlParam, args.PageSize, paramMargin).
//
// NOTE(review): the cursor predicate (shareCursorQuery) and the token
// (getShareNextPageToken) only carry the share ID. Ordering by views, downloads
// or remain_downloads together with a page token therefore filters on ID rather
// than on the sorted column — confirm callers only use ID ordering with cursors.
func (c *shareClient) cursorPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getShareOrderOption(args)...)

	var cursor *PageToken
	if args.PageToken != "" {
		parsed, err := pageTokenFromString(args.PageToken, c.hasher, hashid.ShareID)
		if err != nil {
			return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
		}
		cursor = parsed
	}

	// Over-fetch by one row: its presence means another page follows.
	rows, err := getShareCursorQuery(args, cursor, query).Limit(pageSize + 1).All(ctx)
	if err != nil {
		return nil, nil, err
	}

	next := ""
	if len(rows) > pageSize {
		// rows[len(rows)-1] is the look-ahead row; the page ends one before it.
		token, tokenErr := getShareNextPageToken(c.hasher, rows[len(rows)-2], args)
		if tokenErr != nil {
			return nil, nil, fmt.Errorf("failed to generate next page token: %w", tokenErr)
		}
		next = token
	}

	return lo.Subset(rows, 0, uint(pageSize)), &PaginationResults{
		PageSize:      pageSize,
		NextPageToken: next,
		IsCursor:      true,
	}, nil
}
// offsetPagination fetches page args.Page (zero-based) of query using
// LIMIT/OFFSET, ordered by getShareOrderOption(args). It also returns the total
// number of matching shares.
//
// The effective page size is capPageSize(c.maxSQlParam, args.PageSize,
// paramMargin). That same capped size is used both as the LIMIT and to compute
// the OFFSET, so consecutive pages neither overlap nor skip rows. It also makes
// the returned PageSize/Page describe exactly what was fetched.
func (c *shareClient) offsetPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
	pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
	query.Order(getShareOrderOption(args)...)

	total, err := query.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	// Use the capped pageSize, not args.PageSize. Otherwise a requested size
	// above the cap skips rows between pages, and a requested size of 0 (if
	// capPageSize substitutes a default) returns the first page for every
	// page number.
	shares, err := query.Limit(pageSize).Offset(args.Page * pageSize).All(ctx)
	if err != nil {
		return nil, nil, err
	}

	return shares, &PaginationResults{
		PageSize:   pageSize,
		TotalItems: total,
		Page:       args.Page,
	}, nil
}
// listQuery builds the filtered (but not yet ordered or paginated) share query
// used by List. Each filter in args is applied only when it is set:
//   - UserID > 0 restricts by owner;
//   - PublicOnly restricts to shares without a password;
//   - FileID > 0 restricts by file;
//   - a non-empty ShareIDs restricts to those IDs.
func (c *shareClient) listQuery(args *ListShareArgs) *ent.ShareQuery {
	var filters []predicate.Share

	if args.UserID > 0 {
		filters = append(filters, share.HasUserWith(user.ID(args.UserID)))
	}
	if args.PublicOnly {
		filters = append(filters, share.PasswordIsNil())
	}
	if args.FileID > 0 {
		filters = append(filters, share.HasFileWith(file.ID(args.FileID)))
	}
	if len(args.ShareIDs) > 0 {
		filters = append(filters, share.IDIn(args.ShareIDs...))
	}

	return c.client.Share.Query().Where(filters...)
}
// getShareNextPageToken encodes a pagination cursor pointing at last, the final
// share of the current page. Only last.ID is recorded; args is not consulted
// at present.
func getShareNextPageToken(hasher hashid.Encoder, last *ent.Share, args *ListShareArgs) (string, error) {
	return (&PageToken{ID: last.ID}).Encode(hasher, hashid.EncodeShareID)
}
// getShareCursorQuery restricts query to rows past token in the current sort
// direction, using the predicates registered in shareCursorQuery for
// args.OrderBy. It falls back to the ID predicates for unregistered fields.
// With a nil token (first page) query is returned unchanged.
func getShareCursorQuery(args *ListShareArgs, token *PageToken, query *ent.ShareQuery) *ent.ShareQuery {
	termOpts := &sql.OrderTermOptions{}
	getOrderTerm(args.Order)(termOpts)
	descending := termOpts.Desc

	byDirection, registered := shareCursorQuery[args.OrderBy]
	if !registered {
		byDirection = shareCursorQuery[share.FieldID]
	}

	if token == nil {
		return query
	}
	return query.Where(byDirection[descending](token))
}
// shareCursorQuery maps an order-by field name to the keyset predicates that
// select rows after a cursor, keyed by "is the order descending". Only the ID
// column is registered: descending order continues with smaller IDs, ascending
// order with larger ones.
var shareCursorQuery = map[string]map[bool]func(token *PageToken) predicate.Share{
	share.FieldID: {
		false: func(t *PageToken) predicate.Share { return share.IDGT(t.ID) },
		true:  func(t *PageToken) predicate.Share { return share.IDLT(t.ID) },
	},
}
// getShareOrderOption returns the ORDER BY terms for a share listing.
// args.OrderBy may select views, downloads or remain_downloads as the primary
// key; any other value sorts by ID alone. The ID is always appended as the final
// key (in the same direction) so the ordering is total.
func getShareOrderOption(args *ListShareArgs) []share.OrderOption {
	term := getOrderTerm(args.Order)
	opts := make([]share.OrderOption, 0, 2)

	switch args.OrderBy {
	case share.FieldViews:
		opts = append(opts, share.ByViews(term))
	case share.FieldDownloads:
		opts = append(opts, share.ByDownloads(term))
	case share.FieldRemainDownloads:
		opts = append(opts, share.ByRemainDownloads(term))
	}

	return append(opts, share.ByID(term))
}
// withShareEagerLoading configures q to eager-load the share's edges that ctx
// asks for. The file edge is loaded when ctx holds LoadShareFile{} == true, and
// the owning user when ctx holds LoadShareUser{} == true. Nested loading of
// those edges is delegated to withFileEagerLoading / withUserEagerLoading with
// the same ctx. q is modified in place and also returned for chaining.
func withShareEagerLoading(ctx context.Context, q *ent.ShareQuery) *ent.ShareQuery {
	flagSet := func(key any) bool {
		on, _ := ctx.Value(key).(bool)
		return on
	}

	if flagSet(LoadShareFile{}) {
		q.WithFile(func(fq *ent.FileQuery) {
			withFileEagerLoading(ctx, fq)
		})
	}
	if flagSet(LoadShareUser{}) {
		q.WithUser(func(uq *ent.UserQuery) {
			withUserEagerLoading(ctx, uq)
		})
	}

	return q
}