mirror of
https://github.com/cloudreve/cloudreve.git
synced 2026-03-05 05:27:00 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fb419d998 | ||
|
|
3f0f33b4fc | ||
|
|
052e6be393 | ||
|
|
a4b0ad81e9 | ||
|
|
8431906b94 | ||
|
|
40476953aa | ||
|
|
270f617b9d | ||
|
|
170f2279c1 | ||
|
|
d1377262e3 | ||
|
|
c9acf7e64e | ||
|
|
4e2f243436 | ||
|
|
a54acd71c2 | ||
|
|
fec2fe14f8 | ||
|
|
1f1bc056e3 | ||
|
|
e44ec0e6bf | ||
|
|
a93b964d8b | ||
|
|
d9cff24c75 | ||
|
|
e2488841b4 | ||
|
|
a276be4098 | ||
|
|
4cf6c81534 | ||
|
|
5a66af3105 | ||
|
|
fc5c67cc20 | ||
|
|
5e226efea1 | ||
|
|
c949d47161 | ||
|
|
e699287ffd | ||
|
|
9c78515c72 |
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. iOS]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Smartphone (please complete the following information):**
|
||||
- Device: [e.g. iPhone6]
|
||||
- OS: [e.g. iOS8.1]
|
||||
- Browser [e.g. stock browser, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
4
.github/workflows/build.yml
vendored
4
.github/workflows/build.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-16.04
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
build:
|
||||
name: Build
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-16.04
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -27,3 +27,4 @@ version.lock
|
||||
*.ini
|
||||
conf/conf.ini
|
||||
/statik/
|
||||
/vendor/
|
||||
2
assets
2
assets
Submodule assets updated: 92f6981cb3...35c5966f66
11
main.go
11
main.go
@@ -51,12 +51,11 @@ func main() {
|
||||
|
||||
// 如果启用了Unix
|
||||
if conf.UnixConfig.Listen != "" {
|
||||
go func() {
|
||||
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
|
||||
if err := api.RunUnix(conf.UnixConfig.Listen); err != nil {
|
||||
util.Log().Error("无法监听[%s],%s", conf.UnixConfig.Listen, err)
|
||||
}
|
||||
}()
|
||||
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
|
||||
if err := api.RunUnix(conf.UnixConfig.Listen); err != nil {
|
||||
util.Log().Error("无法监听[%s],%s", conf.UnixConfig.Listen, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("开始监听 %s", conf.SystemConfig.Listen)
|
||||
|
||||
@@ -90,7 +90,7 @@ func WebDAVAuth() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
expectedUser, err := model.GetUserByEmail(username)
|
||||
expectedUser, err := model.GetActiveUserByEmail(username)
|
||||
if err != nil {
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
|
||||
@@ -186,17 +186,17 @@ func (file *File) Rename(new string) error {
|
||||
|
||||
// UpdatePicInfo 更新文件的图像信息
|
||||
func (file *File) UpdatePicInfo(value string) error {
|
||||
return DB.Model(&file).Update("pic_info", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("pic_info", value).Error
|
||||
}
|
||||
|
||||
// UpdateSize 更新文件的大小信息
|
||||
func (file *File) UpdateSize(value uint64) error {
|
||||
return DB.Model(&file).Update("size", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value).Error
|
||||
}
|
||||
|
||||
// UpdateSourceName 更新文件的源文件名
|
||||
func (file *File) UpdateSourceName(value string) error {
|
||||
return DB.Model(&file).Update("source_name", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("source_name", value).Error
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -44,6 +44,26 @@ func (folder *Folder) GetChild(name string) (*Folder, error) {
|
||||
return &resFolder, err
|
||||
}
|
||||
|
||||
// TraceRoot 向上递归查找父目录
|
||||
func (folder *Folder) TraceRoot() error {
|
||||
if folder.ParentID == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parentFolder Folder
|
||||
err := DB.
|
||||
Where("id = ? AND owner_id = ?", folder.ParentID, folder.OwnerID).
|
||||
First(&parentFolder).Error
|
||||
|
||||
if err == nil {
|
||||
err := parentFolder.TraceRoot()
|
||||
folder.Position = path.Join(parentFolder.Position, parentFolder.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetChildFolder 查找子目录
|
||||
func (folder *Folder) GetChildFolder() ([]Folder, error) {
|
||||
var folders []Folder
|
||||
|
||||
@@ -530,3 +530,37 @@ func TestFolder_FileInfoInterface(t *testing.T) {
|
||||
asserts.True(folder.IsDir())
|
||||
asserts.Equal("/test", folder.GetPosition())
|
||||
}
|
||||
|
||||
func TestTraceRoot(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
var parentId uint
|
||||
parentId = 5
|
||||
folder := Folder{
|
||||
ParentID: &parentId,
|
||||
OwnerID: 1,
|
||||
Name: "test_name",
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "/"))
|
||||
asserts.NoError(folder.TraceRoot())
|
||||
asserts.Equal("/parent", folder.Position)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 出现错误
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
|
||||
WillReturnError(errors.New("error"))
|
||||
asserts.Error(folder.TraceRoot())
|
||||
asserts.Equal("parent", folder.Position)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,7 @@ func addDefaultSettings() {
|
||||
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
|
||||
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
|
||||
{Name: "aria2_call_timeout", Value: `5`, Type: "timeout"},
|
||||
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
|
||||
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
|
||||
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
|
||||
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
|
||||
|
||||
@@ -51,6 +51,8 @@ type PolicyOption struct {
|
||||
OdRedirect string `json:"od_redirect,omitempty"`
|
||||
// OdProxy Onedrive 反代地址
|
||||
OdProxy string `json:"od_proxy,omitempty"`
|
||||
// OdDriver OneDrive 驱动器定位符
|
||||
OdDriver string `json:"od_driver,omitempty"`
|
||||
// Region 区域代码
|
||||
Region string `json:"region,omitempty"`
|
||||
// ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段
|
||||
@@ -268,9 +270,8 @@ func (policy *Policy) GetUploadURL() string {
|
||||
return server.ResolveReference(controller).String()
|
||||
}
|
||||
|
||||
// UpdateAccessKey 更新 AccessKey
|
||||
func (policy *Policy) UpdateAccessKey(key string) error {
|
||||
policy.AccessKey = key
|
||||
// SaveAndClearCache 更新并清理缓存
|
||||
func (policy *Policy) SaveAndClearCache() error {
|
||||
err := DB.Save(policy).Error
|
||||
policy.ClearCache()
|
||||
return err
|
||||
|
||||
@@ -257,7 +257,8 @@ func TestPolicy_UpdateAccessKey(t *testing.T) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
err := policy.UpdateAccessKey("123")
|
||||
policy.AccessKey = "123"
|
||||
err := policy.SaveAndClearCache()
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
35
models/scripts/reset.go
Normal file
35
models/scripts/reset.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
type ResetAdminPassword int
|
||||
|
||||
func init() {
|
||||
register("ResetAdminPassword", ResetAdminPassword(0))
|
||||
}
|
||||
|
||||
// Run 运行脚本从社区版升级至 Pro 版
|
||||
func (script ResetAdminPassword) Run(ctx context.Context) {
|
||||
// 查找用户
|
||||
user, err := model.GetUserByID(1)
|
||||
if err != nil {
|
||||
util.Log().Panic("初始管理员用户不存在, %s", err)
|
||||
}
|
||||
|
||||
// 生成密码
|
||||
password := util.RandStringRunes(8)
|
||||
|
||||
// 更改为新密码
|
||||
user.SetPassword(password)
|
||||
if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil {
|
||||
util.Log().Panic("密码更改失败, %s", err)
|
||||
}
|
||||
|
||||
c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold)
|
||||
util.Log().Info("初始管理员密码已更改为:" + c.Sprint(password))
|
||||
}
|
||||
50
models/scripts/reset_test.go
Normal file
50
models/scripts/reset_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResetAdminPassword_Run(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
script := ResetAdminPassword(0)
|
||||
|
||||
// 初始用户不存在
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}))
|
||||
asserts.Panics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 密码更新失败
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
asserts.Panics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
asserts.NotPanics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
@@ -139,6 +139,13 @@ func GetActiveUserByOpenID(openid string) (User, error) {
|
||||
|
||||
// GetUserByEmail 用Email获取用户
|
||||
func GetUserByEmail(email string) (User, error) {
|
||||
var user User
|
||||
result := DB.Set("gorm:auto_preload", true).Where("email = ?", email).First(&user)
|
||||
return user, result.Error
|
||||
}
|
||||
|
||||
// GetActiveUserByEmail 用Email获取可登录用户
|
||||
func GetActiveUserByEmail(email string) (User, error) {
|
||||
var user User
|
||||
result := DB.Set("gorm:auto_preload", true).Where("status = ? and email = ?", Active, email).First(&user)
|
||||
return user, result.Error
|
||||
|
||||
@@ -352,10 +352,20 @@ func TestUser_IncreaseStorageWithoutCheck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
func TestGetActiveUserByEmail(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
|
||||
_, err := GetActiveUserByEmail("abslant@foxmail.com")
|
||||
|
||||
asserts.Error(err)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
|
||||
_, err := GetUserByEmail("abslant@foxmail.com")
|
||||
|
||||
asserts.Error(err)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package conf
|
||||
|
||||
// BackendVersion 当前后端版本号
|
||||
var BackendVersion = "3.2.1"
|
||||
var BackendVersion = "3.3.0"
|
||||
|
||||
// RequiredDBVersion 与当前版本匹配的数据库版本
|
||||
var RequiredDBVersion = "3.2.0"
|
||||
var RequiredDBVersion = "3.3.0"
|
||||
|
||||
// RequiredStaticVersion 与当前版本匹配的静态资源版本
|
||||
var RequiredStaticVersion = "3.2.1"
|
||||
var RequiredStaticVersion = "3.3.0"
|
||||
|
||||
// IsPro 是否为Pro版本
|
||||
var IsPro = "false"
|
||||
|
||||
@@ -100,6 +100,14 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
defer file.Close()
|
||||
dst = util.RelativePath(filepath.FromSlash(dst))
|
||||
|
||||
// 如果禁止了 Overwrite,则检查是否有重名冲突
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
if util.Exists(dst) {
|
||||
util.Log().Warning("物理同名文件已存在或不可用: %s", dst)
|
||||
return errors.New("物理同名文件已存在或不可用")
|
||||
}
|
||||
}
|
||||
|
||||
// 如果目标目录不存在,创建
|
||||
basePath := filepath.Dir(dst)
|
||||
if !util.Exists(basePath) {
|
||||
@@ -130,11 +138,14 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
|
||||
var retErr error
|
||||
|
||||
for _, value := range files {
|
||||
err := os.Remove(util.RelativePath(filepath.FromSlash(value)))
|
||||
if err != nil {
|
||||
util.Log().Warning("无法删除文件,%s", err)
|
||||
retErr = err
|
||||
deleteFailed = append(deleteFailed, value)
|
||||
filePath := util.RelativePath(filepath.FromSlash(value))
|
||||
if util.Exists(filePath) {
|
||||
err := os.Remove(filePath)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法删除文件,%s", err)
|
||||
retErr = err
|
||||
deleteFailed = append(deleteFailed, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试删除文件的缩略图(如果有)
|
||||
|
||||
@@ -2,13 +2,6 @@ package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
@@ -16,12 +9,19 @@ import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandler_Put(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
handler := Driver{}
|
||||
ctx := context.Background()
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
os.Remove(util.RelativePath("test/test/txt"))
|
||||
|
||||
testCases := []struct {
|
||||
file io.ReadCloser
|
||||
@@ -33,6 +33,11 @@ func TestHandler_Put(t *testing.T) {
|
||||
dst: "test/test/txt",
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
file: ioutil.NopCloser(strings.NewReader("test input file")),
|
||||
dst: "test/test/txt",
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
file: ioutil.NopCloser(strings.NewReader("test input file")),
|
||||
dst: "/notexist:/S.TXT",
|
||||
@@ -55,24 +60,34 @@ func TestHandler_Delete(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
handler := Driver{}
|
||||
ctx := context.Background()
|
||||
filePath := util.RelativePath("test.file")
|
||||
|
||||
file, err := os.Create(util.RelativePath("test.file"))
|
||||
file, err := os.Create(filePath)
|
||||
asserts.NoError(err)
|
||||
_ = file.Close()
|
||||
list, err := handler.Delete(ctx, []string{"test.file"})
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
|
||||
file, err = os.Create(util.RelativePath("test.file"))
|
||||
asserts.NoError(err)
|
||||
file, err = os.Create(filePath)
|
||||
_ = file.Close()
|
||||
file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0))
|
||||
asserts.NoError(err)
|
||||
list, err = handler.Delete(ctx, []string{"test.file", "test.notexist"})
|
||||
asserts.Equal([]string{"test.notexist"}, list)
|
||||
asserts.Error(err)
|
||||
file.Close()
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
|
||||
list, err = handler.Delete(ctx, []string{"test.notexist"})
|
||||
asserts.Equal([]string{"test.notexist"}, list)
|
||||
asserts.Error(err)
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
|
||||
file, err = os.Create(filePath)
|
||||
asserts.NoError(err)
|
||||
list, err = handler.Delete(ctx, []string{"test.file"})
|
||||
_ = file.Close()
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
func TestHandler_Get(t *testing.T) {
|
||||
|
||||
@@ -53,12 +53,23 @@ func (err RespError) Error() string {
|
||||
return err.APIError.Message
|
||||
}
|
||||
|
||||
func (client *Client) getRequestURL(api string) string {
|
||||
func (client *Client) getRequestURL(api string, opts ...Option) string {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
base, _ := url.Parse(client.Endpoints.EndpointURL)
|
||||
if base == nil {
|
||||
return ""
|
||||
}
|
||||
base.Path = path.Join(base.Path, api)
|
||||
|
||||
if options.useDriverResource {
|
||||
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
|
||||
} else {
|
||||
base.Path = path.Join(base.Path, api)
|
||||
}
|
||||
|
||||
return base.String()
|
||||
}
|
||||
|
||||
@@ -67,9 +78,9 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
var requestURL string
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
if dst == "" {
|
||||
requestURL = client.getRequestURL("me/drive/root/children")
|
||||
requestURL = client.getRequestURL("root/children")
|
||||
} else {
|
||||
requestURL = client.getRequestURL("me/drive/root:/" + dst + ":/children")
|
||||
requestURL = client.getRequestURL("root:/" + dst + ":/children")
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
|
||||
@@ -103,10 +114,10 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
|
||||
var requestURL string
|
||||
if id != "" {
|
||||
requestURL = client.getRequestURL("/me/drive/items/" + id)
|
||||
requestURL = client.getRequestURL("items/" + id)
|
||||
} else {
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
requestURL = client.getRequestURL("me/drive/root:/" + dst)
|
||||
requestURL = client.getRequestURL("root:/" + dst)
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
|
||||
@@ -129,14 +140,13 @@ func (client *Client) Meta(ctx context.Context, id string, path string) (*FileIn
|
||||
|
||||
// CreateUploadSession 创建分片上传会话
|
||||
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
|
||||
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/createUploadSession")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
|
||||
body := map[string]map[string]interface{}{
|
||||
"item": {
|
||||
"@microsoft.graph.conflictBehavior": options.conflictBehavior,
|
||||
@@ -161,6 +171,33 @@ func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts
|
||||
return uploadSession.UploadURL, nil
|
||||
}
|
||||
|
||||
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
|
||||
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
|
||||
siteUrlParsed, err := url.Parse(siteUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hostName := siteUrlParsed.Hostname()
|
||||
relativePath := strings.Trim(siteUrlParsed.Path, "/")
|
||||
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
|
||||
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
if reqErr != nil {
|
||||
return "", reqErr
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
siteInfo Site
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
|
||||
if decodeErr != nil {
|
||||
return "", decodeErr
|
||||
}
|
||||
|
||||
return siteInfo.ID, nil
|
||||
}
|
||||
|
||||
// GetUploadSessionStatus 查询上传会话状态
|
||||
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
|
||||
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
|
||||
@@ -220,15 +257,21 @@ func (client *Client) UploadChunk(ctx context.Context, uploadURL string, chunk *
|
||||
|
||||
// Upload 上传文件
|
||||
func (client *Client) Upload(ctx context.Context, dst string, size int, file io.Reader) error {
|
||||
// 决定是否覆盖文件
|
||||
overwrite := "replace"
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
overwrite = "fail"
|
||||
}
|
||||
|
||||
// 小文件,使用简单上传接口上传
|
||||
if size <= int(SmallFileSize) {
|
||||
_, err := client.SimpleUpload(ctx, dst, file, int64(size))
|
||||
_, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite))
|
||||
return err
|
||||
}
|
||||
|
||||
// 大文件,进行分片
|
||||
// 创建上传会话
|
||||
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior("replace"))
|
||||
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior(overwrite))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -287,9 +330,15 @@ func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string)
|
||||
}
|
||||
|
||||
// SimpleUpload 上传小文件到dst
|
||||
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64) (*UploadResult, error) {
|
||||
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/content")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/content")
|
||||
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
|
||||
|
||||
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
|
||||
request.WithTimeout(time.Duration(150)*time.Second),
|
||||
@@ -303,7 +352,7 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read
|
||||
retried++
|
||||
util.Log().Debug("文件[%s]上传失败[%s],5秒钟后重试", dst, err)
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
return client.SimpleUpload(context.WithValue(ctx, fsctx.RetryCtx, retried), dst, body, size)
|
||||
return client.SimpleUpload(context.WithValue(ctx, fsctx.RetryCtx, retried), dst, body, size, opts...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -345,7 +394,8 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string,
|
||||
// 由于API限制,最多删除20个
|
||||
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
|
||||
body := client.makeBatchDeleteRequestsBody(dst)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch"), body, 200)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
|
||||
WithDriverResource(false)), body, 200)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
@@ -370,7 +420,7 @@ func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error
|
||||
func getDeleteFailed(res *BatchResponses) []string {
|
||||
var failed = make([]string, 0, len(res.Responses))
|
||||
for _, v := range res.Responses {
|
||||
if v.Status != 204 {
|
||||
if v.Status != 204 && v.Status != 404 {
|
||||
failed = append(failed, v.ID)
|
||||
}
|
||||
}
|
||||
@@ -384,7 +434,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
}
|
||||
for i, v := range files {
|
||||
v = strings.TrimPrefix(v, "/")
|
||||
filePath, _ := url.Parse("/me/drive/root:/")
|
||||
filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
|
||||
filePath.Path = path.Join(filePath.Path, v)
|
||||
req.Requests[i] = BatchRequest{
|
||||
ID: v,
|
||||
@@ -400,17 +450,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
// GetThumbURL 获取给定尺寸的缩略图URL
|
||||
func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) {
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
var (
|
||||
cropOption string
|
||||
requestURL string
|
||||
)
|
||||
if client.Endpoints.isInChina {
|
||||
cropOption = "large"
|
||||
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails/0") + "/" + cropOption
|
||||
} else {
|
||||
cropOption = fmt.Sprintf("c%dx%d_Crop", w, h)
|
||||
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails") + "?select=" + cropOption
|
||||
}
|
||||
requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
if err != nil {
|
||||
@@ -431,7 +471,7 @@ func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (s
|
||||
}
|
||||
|
||||
if len(thumbRes.Value) == 1 {
|
||||
if res, ok := thumbRes.Value[0][cropOption]; ok {
|
||||
if res, ok := thumbRes.Value[0]["large"]; ok {
|
||||
return res.(map[string]interface{})["url"].(string), nil
|
||||
}
|
||||
}
|
||||
@@ -456,7 +496,7 @@ func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size ui
|
||||
case <-time.After(time.Duration(ttl) * time.Second):
|
||||
// 上传会话到期,仍未完成上传,创建占位符
|
||||
client.DeleteUploadSession(context.Background(), uploadURL)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
|
||||
if err != nil {
|
||||
util.Log().Debug("无法创建占位文件,%s", err)
|
||||
}
|
||||
@@ -504,7 +544,7 @@ func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size ui
|
||||
// 取消上传会话,实测OneDrive取消上传会话后,客户端还是可以上传,
|
||||
// 所以上传一个空文件占位,阻止客户端上传
|
||||
client.DeleteUploadSession(context.Background(), uploadURL)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
|
||||
if err != nil {
|
||||
util.Log().Debug("无法创建占位文件,%s", err)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
@@ -166,6 +167,82 @@ func TestClient_GetRequestURL(t *testing.T) {
|
||||
client.Endpoints.EndpointURL = string([]byte{0x7f})
|
||||
asserts.Equal("", client.getRequestURL("123"))
|
||||
}
|
||||
|
||||
// 使用DriverResource
|
||||
{
|
||||
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
|
||||
asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123"))
|
||||
}
|
||||
|
||||
// 不使用DriverResource
|
||||
{
|
||||
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
|
||||
asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetSiteIDByURL(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
client, _ := NewClient(&model.Policy{})
|
||||
client.Credential.AccessToken = "AccessToken"
|
||||
|
||||
// 请求失败
|
||||
{
|
||||
client.Credential.ExpiresIn = 0
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
asserts.Error(err)
|
||||
asserts.Empty(res)
|
||||
|
||||
}
|
||||
|
||||
// 返回未知响应
|
||||
{
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"GET",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`???`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
clientMock.AssertExpectations(t)
|
||||
asserts.Error(err)
|
||||
asserts.Empty(res)
|
||||
}
|
||||
|
||||
// 返回正常
|
||||
{
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"GET",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
clientMock.AssertExpectations(t)
|
||||
asserts.NoError(err)
|
||||
asserts.NotEmpty(res)
|
||||
asserts.Equal("123321", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Meta(t *testing.T) {
|
||||
@@ -499,11 +576,12 @@ func TestClient_Upload(t *testing.T) {
|
||||
client, _ := NewClient(&model.Policy{})
|
||||
client.Credential.AccessToken = "AccessToken"
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
|
||||
// 小文件,简单上传,失败
|
||||
{
|
||||
client.Credential.ExpiresIn = 0
|
||||
err := client.Upload(context.Background(), "123.jpg", 3, strings.NewReader("123"))
|
||||
err := client.Upload(ctx, "123.jpg", 3, strings.NewReader("123"))
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
@@ -888,7 +966,7 @@ func TestClient_GetThumbURL(t *testing.T) {
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"c1x1_Crop":{"url":"thumb"}}]}`)),
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"large":{"url":"thumb"}}]}`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
|
||||
@@ -37,14 +37,16 @@ type Endpoints struct {
|
||||
OAuthEndpoints *oauthEndpoint
|
||||
EndpointURL string // 接口请求的基URL
|
||||
isInChina bool // 是否为世纪互联
|
||||
DriverResource string // 要使用的驱动器
|
||||
}
|
||||
|
||||
// NewClient 根据存储策略获取新的client
|
||||
func NewClient(policy *model.Policy) (*Client, error) {
|
||||
client := &Client{
|
||||
Endpoints: &Endpoints{
|
||||
OAuthURL: policy.BaseURL,
|
||||
EndpointURL: policy.Server,
|
||||
OAuthURL: policy.BaseURL,
|
||||
EndpointURL: policy.Server,
|
||||
DriverResource: policy.OptionsSerialized.OdDriver,
|
||||
},
|
||||
Credential: &Credential{
|
||||
RefreshToken: policy.AccessKey,
|
||||
@@ -56,6 +58,10 @@ func NewClient(policy *model.Policy) (*Client, error) {
|
||||
Request: request.HTTPClient{},
|
||||
}
|
||||
|
||||
if client.Endpoints.DriverResource == "" {
|
||||
client.Endpoints.DriverResource = "me/drive"
|
||||
}
|
||||
|
||||
oauthBase := client.getOAuthEndpoint()
|
||||
if oauthBase == nil {
|
||||
return nil, ErrAuthEndpoint
|
||||
|
||||
@@ -152,8 +152,13 @@ func (handler Driver) Source(
|
||||
isDownload bool,
|
||||
speed int,
|
||||
) (string, error) {
|
||||
cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID)
|
||||
}
|
||||
|
||||
// 尝试从缓存中查找
|
||||
if cachedURL, ok := cache.Get(fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)); ok {
|
||||
if cachedURL, ok := cache.Get(cacheKey); ok {
|
||||
return handler.replaceSourceHost(cachedURL.(string))
|
||||
}
|
||||
|
||||
@@ -162,7 +167,7 @@ func (handler Driver) Source(
|
||||
if err == nil {
|
||||
// 写入新的缓存
|
||||
cache.Set(
|
||||
fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path),
|
||||
cacheKey,
|
||||
res.DownloadURL,
|
||||
model.GetIntSetting("onedrive_source_timeout", 1800),
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -150,6 +151,21 @@ func TestDriver_Source(t *testing.T) {
|
||||
asserts.Equal("res", res)
|
||||
}
|
||||
|
||||
// 命中缓存 上下文存在文件 成功
|
||||
{
|
||||
file := model.File{}
|
||||
file.ID = 1
|
||||
file.UpdatedAt = time.Now()
|
||||
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
|
||||
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
handler.Client.Credential.AccessToken = "1"
|
||||
cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
|
||||
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 0, true, 0)
|
||||
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
|
||||
asserts.NoError(err)
|
||||
asserts.Equal("res", res)
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
|
||||
@@ -160,7 +160,8 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
|
||||
client.Credential = credential
|
||||
|
||||
// 更新存储策略的 RefreshToken
|
||||
client.Policy.UpdateAccessKey(credential.RefreshToken)
|
||||
client.Policy.AccessKey = credential.RefreshToken
|
||||
client.Policy.SaveAndClearCache()
|
||||
|
||||
// 更新缓存
|
||||
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
|
||||
|
||||
@@ -8,11 +8,12 @@ type Option interface {
|
||||
}
|
||||
|
||||
type options struct {
|
||||
redirect string
|
||||
code string
|
||||
refreshToken string
|
||||
conflictBehavior string
|
||||
expires time.Time
|
||||
redirect string
|
||||
code string
|
||||
refreshToken string
|
||||
conflictBehavior string
|
||||
expires time.Time
|
||||
useDriverResource bool
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
@@ -38,13 +39,21 @@ func WithConflictBehavior(t string) Option {
|
||||
})
|
||||
}
|
||||
|
||||
// WithConflictBehavior 设置文件重名后的处理方式
|
||||
func WithDriverResource(t bool) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.useDriverResource = t
|
||||
})
|
||||
}
|
||||
|
||||
func (f optionFunc) apply(o *options) {
|
||||
f(o)
|
||||
}
|
||||
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
conflictBehavior: "fail",
|
||||
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
|
||||
conflictBehavior: "fail",
|
||||
useDriverResource: true,
|
||||
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,6 +131,15 @@ type OAuthError struct {
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
}
|
||||
|
||||
// Site SharePoint 站点信息
|
||||
type Site struct {
|
||||
Description string `json:"description"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
WebUrl string `json:"webUrl"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Credential{})
|
||||
}
|
||||
|
||||
@@ -235,8 +235,15 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
// 凭证有效期
|
||||
credentialTTL := model.GetIntSetting("upload_credential_timeout", 3600)
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := true
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
overwrite = false
|
||||
}
|
||||
|
||||
options := []oss.Option{
|
||||
oss.Expires(time.Now().Add(time.Duration(credentialTTL) * time.Second)),
|
||||
oss.ForbidOverWrite(!overwrite),
|
||||
}
|
||||
|
||||
// 上传文件
|
||||
|
||||
@@ -265,10 +265,11 @@ func TestDriver_Put(t *testing.T) {
|
||||
},
|
||||
}
|
||||
cache.Set("setting_upload_credential_timeout", "3600", 0)
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
|
||||
// 失败
|
||||
{
|
||||
err := handler.Put(context.Background(), ioutil.NopCloser(strings.NewReader("123")), "/123.txt", 3)
|
||||
err := handler.Put(ctx, ioutil.NopCloser(strings.NewReader("123")), "/123.txt", 3)
|
||||
asserts.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
|
||||
if err != nil {
|
||||
failed := make([]string, 0, len(rets))
|
||||
for k, ret := range rets {
|
||||
if ret.Code != 200 {
|
||||
if ret.Code != 200 && ret.Code != 612 {
|
||||
failed = append(failed, files[k])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,6 +155,13 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 决定是否要禁用文件覆盖
|
||||
overwrite := "true"
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
overwrite = "false"
|
||||
}
|
||||
|
||||
// 上传文件
|
||||
resp, err := handler.Client.Request(
|
||||
"POST",
|
||||
@@ -164,6 +171,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
"Authorization": {credential.Token},
|
||||
"X-Policy": {credential.Policy},
|
||||
"X-FileName": {fileName},
|
||||
"X-Overwrite": {overwrite},
|
||||
}),
|
||||
request.WithContentLength(int64(size)),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
@@ -321,7 +329,8 @@ func (handler Driver) getUploadCredential(ctx context.Context, policy serializer
|
||||
// 签名上传策略
|
||||
uploadRequest, _ := http.NewRequest("POST", "/api/v3/slave/upload", nil)
|
||||
uploadRequest.Header = map[string][]string{
|
||||
"X-Policy": {policyEncoded},
|
||||
"X-Policy": {policyEncoded},
|
||||
"X-Overwrite": {"false"},
|
||||
}
|
||||
auth.SignRequest(handler.AuthInstance, uploadRequest, TTL)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestHandler_Token(t *testing.T) {
|
||||
},
|
||||
AuthInstance: auth.HMACAuth{},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
auth.General = auth.HMACAuth{SecretKey: []byte("test")}
|
||||
|
||||
// 成功
|
||||
@@ -49,6 +49,7 @@ func TestHandler_Token(t *testing.T) {
|
||||
asserts.Equal(true, policy.AutoRename)
|
||||
asserts.Equal("dir", policy.SavePath)
|
||||
asserts.Equal("file", policy.FileName)
|
||||
asserts.Equal("file", policy.FileName)
|
||||
asserts.Equal([]string{"txt"}, policy.AllowedExtension)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
@@ -60,7 +60,7 @@ func (handler *Driver) InitS3Client() error {
|
||||
Credentials: credentials.NewStaticCredentials(handler.Policy.AccessKey, handler.Policy.SecretKey, ""),
|
||||
Endpoint: &handler.Policy.Server,
|
||||
Region: &handler.Policy.OptionsSerialized.Region,
|
||||
S3ForcePathStyle: aws.Bool(false),
|
||||
S3ForcePathStyle: aws.Bool(true),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -229,53 +229,35 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
|
||||
return files, err
|
||||
}
|
||||
|
||||
var (
|
||||
failed = make([]string, 0, len(files))
|
||||
lastErr error
|
||||
currentIndex = 0
|
||||
indexLock sync.Mutex
|
||||
failedLock sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
routineNum = 4
|
||||
)
|
||||
wg.Add(routineNum)
|
||||
failed := make([]string, 0, len(files))
|
||||
deleted := make([]string, 0, len(files))
|
||||
|
||||
// S3不支持批量操作,这里开四个协程并行操作
|
||||
for i := 0; i < routineNum; i++ {
|
||||
go func() {
|
||||
for {
|
||||
// 取得待删除文件
|
||||
indexLock.Lock()
|
||||
if currentIndex >= len(files) {
|
||||
// 所有文件处理完成
|
||||
wg.Done()
|
||||
indexLock.Unlock()
|
||||
return
|
||||
}
|
||||
path := files[currentIndex]
|
||||
currentIndex++
|
||||
indexLock.Unlock()
|
||||
|
||||
// 发送异步删除请求
|
||||
_, err := handler.svc.DeleteObject(
|
||||
&s3.DeleteObjectInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &path,
|
||||
})
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
failedLock.Lock()
|
||||
lastErr = err
|
||||
failed = append(failed, path)
|
||||
failedLock.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
keys := make([]*s3.ObjectIdentifier, 0, len(files))
|
||||
for _, file := range files {
|
||||
filePath := file
|
||||
keys = append(keys, &s3.ObjectIdentifier{Key: &filePath})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return failed, lastErr
|
||||
// 发送异步删除请求
|
||||
res, err := handler.svc.DeleteObjects(
|
||||
&s3.DeleteObjectsInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Delete: &s3.Delete{
|
||||
Objects: keys,
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
// 统计未删除的文件
|
||||
for _, deleteRes := range res.Deleted {
|
||||
deleted = append(deleted, *deleteRes.Key)
|
||||
}
|
||||
failed = util.SliceDifference(failed, deleted)
|
||||
|
||||
return failed, nil
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@ package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
@@ -11,6 +9,8 @@ import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/juju/ratelimit"
|
||||
"io"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
/* ============
|
||||
@@ -288,8 +288,11 @@ func (fs *FileSystem) signURL(ctx context.Context, file *model.File, ttl int64,
|
||||
if err != nil {
|
||||
return "", serializer.NewError(serializer.CodeNotSet, "无法获取外链", err)
|
||||
}
|
||||
|
||||
return source, nil
|
||||
// 阿里云的 golang SDK 会把整个object KEY也编码 临时解决方案是清空`RawPath`让golang的`url.EscapedPath`修正这个问题
|
||||
// https://github.com/cloudreve/Cloudreve/issues/677 https://github.com/aliyun/aliyun-oss-go-sdk/blob/6f7e8f88c64181cc2d86d8bd46090b68851e645a/oss/conn.go#L767
|
||||
sourceUrl, _ := url.Parse(source)
|
||||
sourceUrl.RawPath = ""
|
||||
return sourceUrl.String(), nil
|
||||
}
|
||||
|
||||
// ResetFileIfNotExist 重设当前目标文件为 path,如果当前目标为空
|
||||
|
||||
@@ -29,8 +29,8 @@ const (
|
||||
ShareKeyCtx
|
||||
// LimitParentCtx 限制父目录
|
||||
LimitParentCtx
|
||||
// IgnoreConflictCtx 忽略重名冲突
|
||||
IgnoreConflictCtx
|
||||
// IgnoreDirectoryConflictCtx 忽略目录重名冲突
|
||||
IgnoreDirectoryConflictCtx
|
||||
// RetryCtx 失败重试次数
|
||||
RetryCtx
|
||||
// ForceUsePublicEndpointCtx 强制使用公网 Endpoint
|
||||
@@ -39,4 +39,6 @@ const (
|
||||
CancelFuncCtx
|
||||
// ValidateCapacityOnceCtx 限定归还容量的操作只執行一次
|
||||
ValidateCapacityOnceCtx
|
||||
// 禁止上传时同名覆盖操作
|
||||
DisableOverwrite
|
||||
)
|
||||
|
||||
@@ -39,8 +39,8 @@ func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentR
|
||||
res.MaxAge = model.GetIntSetting("preview_timeout", 60)
|
||||
}
|
||||
|
||||
// 出错时重新生成缩略图
|
||||
if err != nil {
|
||||
// 本地存储策略出错时重新生成缩略图
|
||||
if err != nil && fs.Policy.Type == "local" {
|
||||
fs.GenerateThumbnail(ctx, &fs.FileTarget[0])
|
||||
}
|
||||
|
||||
|
||||
@@ -403,8 +403,8 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*mo
|
||||
isExist, parent := fs.IsPathExist(base)
|
||||
if !isExist {
|
||||
// 递归创建父目录
|
||||
if _, ok := ctx.Value(fsctx.IgnoreConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreConflictCtx, true)
|
||||
if _, ok := ctx.Value(fsctx.IgnoreDirectoryConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true)
|
||||
}
|
||||
newParent, err := fs.CreateDirectory(ctx, base)
|
||||
if err != nil {
|
||||
@@ -427,7 +427,7 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*mo
|
||||
_, err := newFolder.Create()
|
||||
|
||||
if err != nil {
|
||||
if _, ok := ctx.Value(fsctx.IgnoreConflictCtx).(bool); !ok {
|
||||
if _, ok := ctx.Value(fsctx.IgnoreDirectoryConflictCtx).(bool); !ok {
|
||||
return nil, ErrFolderExisted
|
||||
}
|
||||
|
||||
|
||||
23
pkg/serializer/explorer.go
Normal file
23
pkg/serializer/explorer.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package serializer
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(ObjectProps{})
|
||||
}
|
||||
|
||||
// ObjectProps 文件、目录对象的详细属性信息
|
||||
type ObjectProps struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Policy string `json:"policy"`
|
||||
Size uint64 `json:"size"`
|
||||
ChildFolderNum int `json:"child_folder_num"`
|
||||
ChildFileNum int `json:"child_file_num"`
|
||||
Path string `json:"path"`
|
||||
|
||||
QueryDate time.Time `json:"query_date"`
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
)
|
||||
|
||||
// DecompressTask 文件压缩任务
|
||||
@@ -81,7 +82,12 @@ func (job *DecompressTask) Do() {
|
||||
}
|
||||
|
||||
job.TaskModel.SetProgress(DecompressingProgress)
|
||||
err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst)
|
||||
|
||||
// 禁止重名覆盖
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
|
||||
err = fs.Decompress(ctx, job.TaskProps.Src, job.TaskProps.Dst)
|
||||
if err != nil {
|
||||
job.SetErrorMsg("解压缩失败", err)
|
||||
return
|
||||
|
||||
@@ -102,7 +102,7 @@ func (job *ImportTask) Do() {
|
||||
|
||||
// 列取目录、对象
|
||||
job.TaskModel.SetProgress(ListingProgress)
|
||||
coxIgnoreConflict := context.WithValue(context.Background(), fsctx.IgnoreConflictCtx,
|
||||
coxIgnoreConflict := context.WithValue(context.Background(), fsctx.IgnoreDirectoryConflictCtx,
|
||||
true)
|
||||
objects, err := fs.Handler.List(ctx, job.TaskProps.Src, job.TaskProps.Recursive)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
@@ -102,7 +103,8 @@ func (job *TransferTask) Do() {
|
||||
dst = path.Join(job.TaskProps.Dst, strings.TrimPrefix(src, trim))
|
||||
}
|
||||
|
||||
err = fs.UploadFromPath(context.Background(), file, dst)
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
err = fs.UploadFromPath(ctx, file, dst)
|
||||
if err != nil {
|
||||
job.SetErrorMsg("文件转存失败", err)
|
||||
}
|
||||
|
||||
@@ -373,6 +373,9 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst
|
||||
fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity)
|
||||
}
|
||||
|
||||
// 禁止覆盖
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
|
||||
// 执行上传
|
||||
err = fs.Upload(ctx, fileData)
|
||||
if err != nil {
|
||||
@@ -407,8 +410,8 @@ func (h *Handler) handleMkcol(w http.ResponseWriter, r *http.Request, fs *filesy
|
||||
return http.StatusUnsupportedMediaType, nil
|
||||
}
|
||||
if strings.Contains(r.UserAgent(), "rclone") {
|
||||
if _, ok := ctx.Value(fsctx.IgnoreConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreConflictCtx, true)
|
||||
if _, ok := ctx.Value(fsctx.IgnoreDirectoryConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true)
|
||||
}
|
||||
}
|
||||
if _, err := fs.CreateDirectory(ctx, reqPath); err != nil {
|
||||
|
||||
@@ -319,6 +319,7 @@ func FileUploadStream(c *gin.Context) {
|
||||
|
||||
// 执行上传
|
||||
ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{})
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c)
|
||||
err = fs.Upload(uploadCtx, fileData)
|
||||
if err != nil {
|
||||
|
||||
@@ -66,3 +66,19 @@ func Rename(c *gin.Context) {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Rename 重命名文件或目录
|
||||
func GetProperty(c *gin.Context) {
|
||||
// 创建上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var service explorer.ItemPropertyService
|
||||
service.ID = c.Param("id")
|
||||
if err := c.ShouldBindQuery(&service); err == nil {
|
||||
res := service.GetProperty(ctx, c)
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +71,11 @@ func SlaveUpload(c *gin.Context) {
|
||||
fs.Use("AfterUpload", filesystem.SlaveAfterUpload)
|
||||
fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
|
||||
|
||||
// 是否允许覆盖
|
||||
if c.Request.Header.Get("X-Overwrite") == "false" {
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
}
|
||||
|
||||
// 执行上传
|
||||
err = fs.Upload(ctx, fileData)
|
||||
if err != nil {
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
// StartLoginAuthn 开始注册WebAuthn登录
|
||||
func StartLoginAuthn(c *gin.Context) {
|
||||
userName := c.Param("username")
|
||||
expectedUser, err := model.GetUserByEmail(userName)
|
||||
expectedUser, err := model.GetActiveUserByEmail(userName)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNotFound, "用户不存在", err))
|
||||
return
|
||||
@@ -52,7 +52,7 @@ func StartLoginAuthn(c *gin.Context) {
|
||||
// FinishLoginAuthn 完成注册WebAuthn登录
|
||||
func FinishLoginAuthn(c *gin.Context) {
|
||||
userName := c.Param("username")
|
||||
expectedUser, err := model.GetUserByEmail(userName)
|
||||
expectedUser, err := model.GetActiveUserByEmail(userName)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, "用户邮箱或密码错误", err))
|
||||
return
|
||||
@@ -349,6 +349,8 @@ func UpdateOption(c *gin.Context) {
|
||||
subService = &user.DeleteWebAuthn{}
|
||||
case "theme":
|
||||
subService = &user.ThemeChose{}
|
||||
default:
|
||||
subService = &user.ChangerNick{}
|
||||
}
|
||||
|
||||
subErr = c.ShouldBindJSON(subService)
|
||||
|
||||
@@ -510,6 +510,8 @@ func InitMasterRouter() *gin.Engine {
|
||||
object.POST("copy", controllers.Copy)
|
||||
// 重命名对象
|
||||
object.POST("rename", controllers.Rename)
|
||||
// 获取对象属性
|
||||
object.GET("property/:id", controllers.GetProperty)
|
||||
}
|
||||
|
||||
// 分享
|
||||
|
||||
@@ -2,6 +2,7 @@ package callback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OneDriveOauthService OneDrive 授权回调服务
|
||||
@@ -41,17 +43,42 @@ func (service *OneDriveOauthService) Auth(c *gin.Context) serializer.Response {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法初始化 OneDrive 客户端", err)
|
||||
}
|
||||
|
||||
credential, err := client.ObtainToken(context.Background(), onedrive.WithCode(service.Code))
|
||||
credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code))
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "AccessToken 获取失败", err)
|
||||
}
|
||||
|
||||
// 更新存储策略的 RefreshToken
|
||||
if err := client.Policy.UpdateAccessKey(credential.RefreshToken); err != nil {
|
||||
client.Policy.AccessKey = credential.RefreshToken
|
||||
if err := client.Policy.SaveAndClearCache(); err != nil {
|
||||
return serializer.DBErr("无法更新 RefreshToken", err)
|
||||
}
|
||||
|
||||
cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_")
|
||||
if client.Policy.OptionsSerialized.OdDriver != "" && strings.Contains(client.Policy.OptionsSerialized.OdDriver, "http") {
|
||||
if err := querySharePointSiteID(c, client.Policy); err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法查询 SharePoint 站点 ID", err)
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.Response{}
|
||||
}
|
||||
|
||||
func querySharePointSiteID(ctx context.Context, policy *model.Policy) error {
|
||||
client, err := onedrive.NewClient(policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id)
|
||||
if err := client.Policy.SaveAndClearCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -78,6 +78,7 @@ func (service *SingleFileService) Create(c *gin.Context) serializer.Response {
|
||||
// 上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
|
||||
// 给文件系统分配钩子
|
||||
fs.Use("BeforeUpload", filesystem.HookValidateFile)
|
||||
|
||||
@@ -60,6 +60,13 @@ type ItemDecompressService struct {
|
||||
Dst string `json:"dst" binding:"required,min=1,max=65535"`
|
||||
}
|
||||
|
||||
// ItemPropertyService 获取对象属性服务
|
||||
type ItemPropertyService struct {
|
||||
ID string `binding:"required"`
|
||||
TraceRoot bool `form:"trace_root"`
|
||||
IsFolder bool `form:"is_folder"`
|
||||
}
|
||||
|
||||
// Raw 批量解码HashID,获取原始ID
|
||||
func (service *ItemIDService) Raw() *ItemService {
|
||||
if service.Source != nil {
|
||||
@@ -353,3 +360,100 @@ func (service *ItemRenameService) Rename(ctx context.Context, c *gin.Context) se
|
||||
Code: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// GetProperty 获取对象的属性
|
||||
func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Context) serializer.Response {
|
||||
userCtx, _ := c.Get("user")
|
||||
user := userCtx.(*model.User)
|
||||
|
||||
var props serializer.ObjectProps
|
||||
props.QueryDate = time.Now()
|
||||
|
||||
// 如果是文件对象
|
||||
if !service.IsFolder {
|
||||
res, err := hashid.DecodeHashID(service.ID, hashid.FileID)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotFound, "对象不存在", err)
|
||||
}
|
||||
|
||||
file, err := model.GetFilesByIDs([]uint{res}, user.ID)
|
||||
if err != nil {
|
||||
return serializer.DBErr("找不到文件", err)
|
||||
}
|
||||
|
||||
props.CreatedAt = file[0].CreatedAt
|
||||
props.UpdatedAt = file[0].UpdatedAt
|
||||
props.Policy = file[0].GetPolicy().Name
|
||||
props.Size = file[0].Size
|
||||
|
||||
// 查找父目录
|
||||
if service.TraceRoot {
|
||||
parent, err := model.GetFoldersByIDs([]uint{file[0].FolderID}, user.ID)
|
||||
if err != nil {
|
||||
return serializer.DBErr("找不到父目录", err)
|
||||
}
|
||||
|
||||
if err := parent[0].TraceRoot(); err != nil {
|
||||
return serializer.DBErr("无法溯源父目录", err)
|
||||
}
|
||||
|
||||
props.Path = path.Join(parent[0].Position, parent[0].Name)
|
||||
}
|
||||
} else {
|
||||
res, err := hashid.DecodeHashID(service.ID, hashid.FolderID)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotFound, "对象不存在", err)
|
||||
}
|
||||
|
||||
// 如果对象是目录, 先尝试返回缓存结果
|
||||
if cacheRes, ok := cache.Get(fmt.Sprintf("folder_props_%d", res)); ok {
|
||||
return serializer.Response{Data: cacheRes.(serializer.ObjectProps)}
|
||||
}
|
||||
|
||||
folder, err := model.GetFoldersByIDs([]uint{res}, user.ID)
|
||||
if err != nil {
|
||||
return serializer.DBErr("找不到目录", err)
|
||||
}
|
||||
|
||||
props.CreatedAt = folder[0].CreatedAt
|
||||
props.UpdatedAt = folder[0].UpdatedAt
|
||||
|
||||
// 统计子目录
|
||||
childFolders, err := model.GetRecursiveChildFolder([]uint{folder[0].ID},
|
||||
user.ID, true)
|
||||
if err != nil {
|
||||
return serializer.DBErr("无法列取子目录", err)
|
||||
}
|
||||
props.ChildFolderNum = len(childFolders) - 1
|
||||
|
||||
// 统计子文件
|
||||
files, err := model.GetChildFilesOfFolders(&childFolders)
|
||||
if err != nil {
|
||||
return serializer.DBErr("无法列取子文件", err)
|
||||
}
|
||||
|
||||
// 统计子文件个数和大小
|
||||
props.ChildFileNum = len(files)
|
||||
for i := 0; i < len(files); i++ {
|
||||
props.Size += files[i].Size
|
||||
}
|
||||
|
||||
// 查找父目录
|
||||
if service.TraceRoot {
|
||||
if err := folder[0].TraceRoot(); err != nil {
|
||||
return serializer.DBErr("无法溯源父目录", err)
|
||||
}
|
||||
|
||||
props.Path = folder[0].Position
|
||||
}
|
||||
|
||||
// 如果列取对象是目录,则缓存结果
|
||||
cache.Set(fmt.Sprintf("folder_props_%d", res), props,
|
||||
model.GetIntSetting("folder_props_timeout", 300))
|
||||
}
|
||||
|
||||
return serializer.Response{
|
||||
Code: 0,
|
||||
Data: props,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +94,12 @@ func (service *UserResetEmailService) Reset(c *gin.Context) serializer.Response
|
||||
// 查找用户
|
||||
if user, err := model.GetUserByEmail(service.UserName); err == nil {
|
||||
|
||||
if user.Status == model.Baned || user.Status == model.OveruseBaned {
|
||||
return serializer.Err(403, "该账号已被封禁", nil)
|
||||
}
|
||||
if user.Status == model.NotActivicated {
|
||||
return serializer.Err(403, "该账号未激活", nil)
|
||||
}
|
||||
// 创建密码重设会话
|
||||
secret := util.RandStringRunes(32)
|
||||
cache.Set(fmt.Sprintf("user_reset_%d", user.ID), secret, 3600)
|
||||
|
||||
@@ -64,10 +64,17 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response
|
||||
user.Status = model.NotActivicated
|
||||
}
|
||||
user.GroupID = uint(defaultGroup)
|
||||
|
||||
userNotActivated := false
|
||||
// 创建用户
|
||||
if err := model.DB.Create(&user).Error; err != nil {
|
||||
return serializer.DBErr("此邮箱已被使用", err)
|
||||
//检查已存在使用者是否尚未激活
|
||||
expectedUser, err := model.GetUserByEmail(service.UserName)
|
||||
if expectedUser.Status == model.NotActivicated {
|
||||
userNotActivated = true
|
||||
user = expectedUser
|
||||
} else {
|
||||
return serializer.DBErr("此邮箱已被使用", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 发送激活邮件
|
||||
@@ -100,8 +107,12 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response
|
||||
if err := email.Send(user.Email, title, body); err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法发送激活邮件", err)
|
||||
}
|
||||
|
||||
return serializer.Response{Code: 203}
|
||||
if userNotActivated == true {
|
||||
//原本在上面要抛出的DBErr,放来这边抛出
|
||||
return serializer.DBErr("用户未激活,已重新发送激活邮件", nil)
|
||||
} else {
|
||||
return serializer.Response{Code: 203}
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.Response{}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
@@ -200,8 +201,8 @@ func (service *AvatarService) Get(c *gin.Context) serializer.Response {
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法解析 Gravatar 服务器地址", err)
|
||||
}
|
||||
|
||||
has := md5.Sum([]byte(user.Email))
|
||||
email_lowered := strings.ToLower(user.Email)
|
||||
has := md5.Sum([]byte(email_lowered))
|
||||
avatar, _ := url.Parse(fmt.Sprintf("/avatar/%x?d=mm&s=%s", has, sizes[service.Size]))
|
||||
|
||||
return serializer.Response{
|
||||
|
||||
Reference in New Issue
Block a user