mirror of
https://github.com/cloudreve/cloudreve.git
synced 2026-03-06 01:07:01 +00:00
Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1252c810b | ||
|
|
e781185ad2 | ||
|
|
95802efcec | ||
|
|
233648b956 | ||
|
|
53acadf098 | ||
|
|
c0f7214cdb | ||
|
|
ccaefdab33 | ||
|
|
6efd8e8183 | ||
|
|
144b534486 | ||
|
|
e160154d3b | ||
|
|
2381eca230 | ||
|
|
adde486a30 | ||
|
|
a9c0d6ed17 | ||
|
|
595f4a1350 | ||
|
|
a5f80a4431 | ||
|
|
6fb419d998 | ||
|
|
3f0f33b4fc | ||
|
|
052e6be393 | ||
|
|
a4b0ad81e9 | ||
|
|
8431906b94 | ||
|
|
40476953aa | ||
|
|
270f617b9d | ||
|
|
170f2279c1 | ||
|
|
d1377262e3 | ||
|
|
c9acf7e64e | ||
|
|
4e2f243436 | ||
|
|
a54acd71c2 | ||
|
|
fec2fe14f8 | ||
|
|
1f1bc056e3 | ||
|
|
e44ec0e6bf | ||
|
|
a93b964d8b | ||
|
|
d9cff24c75 | ||
|
|
e2488841b4 | ||
|
|
a276be4098 | ||
|
|
4cf6c81534 | ||
|
|
5a66af3105 | ||
|
|
fc5c67cc20 | ||
|
|
5e226efea1 | ||
|
|
c949d47161 | ||
|
|
e699287ffd | ||
|
|
9c78515c72 |
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. iOS]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Smartphone (please complete the following information):**
|
||||
- Device: [e.g. iPhone6]
|
||||
- OS: [e.g. iOS8.1]
|
||||
- Browser [e.g. stock browser, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
29
.github/workflows/build.yml
vendored
29
.github/workflows/build.yml
vendored
@@ -5,36 +5,9 @@ on:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
uses: actions/setup-go@v1
|
||||
with:
|
||||
go-version: 1.13
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: 'recursive'
|
||||
|
||||
- name: Get dependencies
|
||||
run: |
|
||||
go get github.com/rakyll/statik
|
||||
export PATH=$PATH:~/go/bin/
|
||||
statik -src=models -f
|
||||
|
||||
- name: Test
|
||||
run: go test -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
build:
|
||||
name: Build
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-16.04
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
|
||||
47
.github/workflows/test.yml
vendored
Normal file
47
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
push:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-16.04
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
uses: actions/setup-go@v1
|
||||
with:
|
||||
go-version: 1.13
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: 'recursive'
|
||||
|
||||
- name: Get dependencies
|
||||
run: |
|
||||
go get github.com/rakyll/statik
|
||||
export PATH=$PATH:~/go/bin/
|
||||
statik -src=models -f
|
||||
|
||||
- name: Test
|
||||
run: go test -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
- name: Upload binary files (linux_arm)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_linux_arm
|
||||
path: release/cloudreve*linux_arm.*
|
||||
|
||||
- name: Upload binary files (linux_arm64)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_linux_arm64
|
||||
path: release/cloudreve*linux_arm64.*
|
||||
2
assets
2
assets
Submodule assets updated: 92f6981cb3...522a75c750
11
main.go
11
main.go
@@ -51,12 +51,11 @@ func main() {
|
||||
|
||||
// 如果启用了Unix
|
||||
if conf.UnixConfig.Listen != "" {
|
||||
go func() {
|
||||
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
|
||||
if err := api.RunUnix(conf.UnixConfig.Listen); err != nil {
|
||||
util.Log().Error("无法监听[%s],%s", conf.UnixConfig.Listen, err)
|
||||
}
|
||||
}()
|
||||
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
|
||||
if err := api.RunUnix(conf.UnixConfig.Listen); err != nil {
|
||||
util.Log().Error("无法监听[%s],%s", conf.UnixConfig.Listen, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("开始监听 %s", conf.SystemConfig.Listen)
|
||||
|
||||
@@ -90,7 +90,7 @@ func WebDAVAuth() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
expectedUser, err := model.GetUserByEmail(username)
|
||||
expectedUser, err := model.GetActiveUserByEmail(username)
|
||||
if err != nil {
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
|
||||
122
middleware/captcha.go
Normal file
122
middleware/captcha.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/recaptcha"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
captcha "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha/v20190722"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type req struct {
|
||||
CaptchaCode string `json:"captchaCode"`
|
||||
Ticket string `json:"ticket"`
|
||||
Randstr string `json:"randstr"`
|
||||
}
|
||||
|
||||
// CaptchaRequired 验证请求签名
|
||||
func CaptchaRequired(configName string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 相关设定
|
||||
options := model.GetSettingByNames(configName,
|
||||
"captcha_type",
|
||||
"captcha_ReCaptchaSecret",
|
||||
"captcha_TCaptcha_SecretId",
|
||||
"captcha_TCaptcha_SecretKey",
|
||||
"captcha_TCaptcha_CaptchaAppId",
|
||||
"captcha_TCaptcha_AppSecretKey")
|
||||
// 检查验证码
|
||||
isCaptchaRequired := model.IsTrueVal(options[configName])
|
||||
|
||||
if isCaptchaRequired {
|
||||
var service req
|
||||
bodyCopy := new(bytes.Buffer)
|
||||
_, err := io.Copy(bodyCopy, c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.ParamErr("验证码错误", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
bodyData := bodyCopy.Bytes()
|
||||
err = json.Unmarshal(bodyData, &service)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.ParamErr("验证码错误", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = ioutil.NopCloser(bytes.NewReader(bodyData))
|
||||
switch options["captcha_type"] {
|
||||
case "normal":
|
||||
captchaID := util.GetSession(c, "captchaID")
|
||||
util.DeleteSession(c, "captchaID")
|
||||
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
|
||||
c.JSON(200, serializer.ParamErr("验证码错误", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
case "recaptcha":
|
||||
reCAPTCHA, err := recaptcha.NewReCAPTCHA(options["captcha_ReCaptchaSecret"], recaptcha.V2, 10*time.Second)
|
||||
if err != nil {
|
||||
util.Log().Warning("reCAPTCHA 验证错误, %s", err)
|
||||
c.Abort()
|
||||
break
|
||||
}
|
||||
|
||||
err = reCAPTCHA.Verify(service.CaptchaCode)
|
||||
if err != nil {
|
||||
util.Log().Warning("reCAPTCHA 验证错误, %s", err)
|
||||
c.JSON(200, serializer.ParamErr("验证失败,请刷新网页后再次验证", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
case "tcaptcha":
|
||||
credential := common.NewCredential(
|
||||
options["captcha_TCaptcha_SecretId"],
|
||||
options["captcha_TCaptcha_SecretKey"],
|
||||
)
|
||||
cpf := profile.NewClientProfile()
|
||||
cpf.HttpProfile.Endpoint = "captcha.tencentcloudapi.com"
|
||||
client, _ := captcha.NewClient(credential, "", cpf)
|
||||
request := captcha.NewDescribeCaptchaResultRequest()
|
||||
request.CaptchaType = common.Uint64Ptr(9)
|
||||
appid, _ := strconv.Atoi(options["captcha_TCaptcha_CaptchaAppId"])
|
||||
request.CaptchaAppId = common.Uint64Ptr(uint64(appid))
|
||||
request.AppSecretKey = common.StringPtr(options["captcha_TCaptcha_AppSecretKey"])
|
||||
request.Ticket = common.StringPtr(service.Ticket)
|
||||
request.Randstr = common.StringPtr(service.Randstr)
|
||||
request.UserIp = common.StringPtr(c.ClientIP())
|
||||
response, err := client.DescribeCaptchaResult(request)
|
||||
if err != nil {
|
||||
util.Log().Warning("TCaptcha 验证错误, %s", err)
|
||||
c.Abort()
|
||||
break
|
||||
}
|
||||
|
||||
if *response.Response.CaptchaCode != int64(1) {
|
||||
c.JSON(200, serializer.ParamErr("验证失败,请刷新网页后再次验证", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
177
middleware/captcha_test.go
Normal file
177
middleware/captcha_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type errReader int
|
||||
|
||||
func (errReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("test error")
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_General(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 未启用验证码
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "0",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// body 无法读取
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", errReader(1))
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// body JSON 解析失败
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("123"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Normal(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 验证码错误
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "normal",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
Session("233")(c)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Recaptcha(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 无法初始化reCaptcha实例
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "recaptcha",
|
||||
"captcha_ReCaptchaSecret": "",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 验证码错误
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "recaptcha",
|
||||
"captcha_ReCaptchaSecret": "233",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Tcaptcha(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 验证出错
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "tcaptcha",
|
||||
"captcha_ReCaptchaSecret": "",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
@@ -186,17 +186,17 @@ func (file *File) Rename(new string) error {
|
||||
|
||||
// UpdatePicInfo 更新文件的图像信息
|
||||
func (file *File) UpdatePicInfo(value string) error {
|
||||
return DB.Model(&file).Update("pic_info", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("pic_info", value).Error
|
||||
}
|
||||
|
||||
// UpdateSize 更新文件的大小信息
|
||||
func (file *File) UpdateSize(value uint64) error {
|
||||
return DB.Model(&file).Update("size", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value).Error
|
||||
}
|
||||
|
||||
// UpdateSourceName 更新文件的源文件名
|
||||
func (file *File) UpdateSourceName(value string) error {
|
||||
return DB.Model(&file).Update("source_name", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("source_name", value).Error
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -44,6 +44,26 @@ func (folder *Folder) GetChild(name string) (*Folder, error) {
|
||||
return &resFolder, err
|
||||
}
|
||||
|
||||
// TraceRoot 向上递归查找父目录
|
||||
func (folder *Folder) TraceRoot() error {
|
||||
if folder.ParentID == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parentFolder Folder
|
||||
err := DB.
|
||||
Where("id = ? AND owner_id = ?", folder.ParentID, folder.OwnerID).
|
||||
First(&parentFolder).Error
|
||||
|
||||
if err == nil {
|
||||
err := parentFolder.TraceRoot()
|
||||
folder.Position = path.Join(parentFolder.Position, parentFolder.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetChildFolder 查找子目录
|
||||
func (folder *Folder) GetChildFolder() ([]Folder, error) {
|
||||
var folders []Folder
|
||||
|
||||
@@ -530,3 +530,37 @@ func TestFolder_FileInfoInterface(t *testing.T) {
|
||||
asserts.True(folder.IsDir())
|
||||
asserts.Equal("/test", folder.GetPosition())
|
||||
}
|
||||
|
||||
func TestTraceRoot(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
var parentId uint
|
||||
parentId = 5
|
||||
folder := Folder{
|
||||
ParentID: &parentId,
|
||||
OwnerID: 1,
|
||||
Name: "test_name",
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "/"))
|
||||
asserts.NoError(folder.TraceRoot())
|
||||
asserts.Equal("/parent", folder.Position)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 出现错误
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
|
||||
WillReturnError(errors.New("error"))
|
||||
asserts.Error(folder.TraceRoot())
|
||||
asserts.Equal("parent", folder.Position)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,7 @@ func addDefaultSettings() {
|
||||
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
|
||||
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
|
||||
{Name: "aria2_call_timeout", Value: `5`, Type: "timeout"},
|
||||
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
|
||||
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
|
||||
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
|
||||
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
|
||||
@@ -148,6 +149,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
{Name: "share_view_method", Value: "list", Type: "view"},
|
||||
{Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"},
|
||||
{Name: "authn_enabled", Value: "0", Type: "authn"},
|
||||
{Name: "captcha_type", Value: "normal", Type: "captcha"},
|
||||
{Name: "captcha_height", Value: "60", Type: "captcha"},
|
||||
{Name: "captcha_width", Value: "240", Type: "captcha"},
|
||||
{Name: "captcha_mode", Value: "3", Type: "captcha"},
|
||||
@@ -159,9 +161,12 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
{Name: "captcha_IsShowSlimeLine", Value: "1", Type: "captcha"},
|
||||
{Name: "captcha_IsShowSineLine", Value: "0", Type: "captcha"},
|
||||
{Name: "captcha_CaptchaLen", Value: "6", Type: "captcha"},
|
||||
{Name: "captcha_IsUseReCaptcha", Value: "0", Type: "captcha"},
|
||||
{Name: "captcha_ReCaptchaKey", Value: "defaultKey", Type: "captcha"},
|
||||
{Name: "captcha_ReCaptchaSecret", Value: "defaultSecret", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_CaptchaAppId", Value: "", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_AppSecretKey", Value: "", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_SecretId", Value: "", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_SecretKey", Value: "", Type: "captcha"},
|
||||
{Name: "thumb_width", Value: "400", Type: "thumb"},
|
||||
{Name: "thumb_height", Value: "300", Type: "thumb"},
|
||||
{Name: "pwa_small_icon", Value: "/static/img/favicon.ico", Type: "pwa"},
|
||||
|
||||
@@ -51,6 +51,8 @@ type PolicyOption struct {
|
||||
OdRedirect string `json:"od_redirect,omitempty"`
|
||||
// OdProxy Onedrive 反代地址
|
||||
OdProxy string `json:"od_proxy,omitempty"`
|
||||
// OdDriver OneDrive 驱动器定位符
|
||||
OdDriver string `json:"od_driver,omitempty"`
|
||||
// Region 区域代码
|
||||
Region string `json:"region,omitempty"`
|
||||
// ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段
|
||||
@@ -268,9 +270,8 @@ func (policy *Policy) GetUploadURL() string {
|
||||
return server.ResolveReference(controller).String()
|
||||
}
|
||||
|
||||
// UpdateAccessKey 更新 AccessKey
|
||||
func (policy *Policy) UpdateAccessKey(key string) error {
|
||||
policy.AccessKey = key
|
||||
// SaveAndClearCache 更新并清理缓存
|
||||
func (policy *Policy) SaveAndClearCache() error {
|
||||
err := DB.Save(policy).Error
|
||||
policy.ClearCache()
|
||||
return err
|
||||
|
||||
@@ -257,7 +257,8 @@ func TestPolicy_UpdateAccessKey(t *testing.T) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
err := policy.UpdateAccessKey("123")
|
||||
policy.AccessKey = "123"
|
||||
err := policy.SaveAndClearCache()
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
35
models/scripts/reset.go
Normal file
35
models/scripts/reset.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
type ResetAdminPassword int
|
||||
|
||||
func init() {
|
||||
register("ResetAdminPassword", ResetAdminPassword(0))
|
||||
}
|
||||
|
||||
// Run 运行脚本从社区版升级至 Pro 版
|
||||
func (script ResetAdminPassword) Run(ctx context.Context) {
|
||||
// 查找用户
|
||||
user, err := model.GetUserByID(1)
|
||||
if err != nil {
|
||||
util.Log().Panic("初始管理员用户不存在, %s", err)
|
||||
}
|
||||
|
||||
// 生成密码
|
||||
password := util.RandStringRunes(8)
|
||||
|
||||
// 更改为新密码
|
||||
user.SetPassword(password)
|
||||
if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil {
|
||||
util.Log().Panic("密码更改失败, %s", err)
|
||||
}
|
||||
|
||||
c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold)
|
||||
util.Log().Info("初始管理员密码已更改为:" + c.Sprint(password))
|
||||
}
|
||||
50
models/scripts/reset_test.go
Normal file
50
models/scripts/reset_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResetAdminPassword_Run(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
script := ResetAdminPassword(0)
|
||||
|
||||
// 初始用户不存在
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}))
|
||||
asserts.Panics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 密码更新失败
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
asserts.Panics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
asserts.NotPanics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
@@ -139,6 +139,13 @@ func GetActiveUserByOpenID(openid string) (User, error) {
|
||||
|
||||
// GetUserByEmail 用Email获取用户
|
||||
func GetUserByEmail(email string) (User, error) {
|
||||
var user User
|
||||
result := DB.Set("gorm:auto_preload", true).Where("email = ?", email).First(&user)
|
||||
return user, result.Error
|
||||
}
|
||||
|
||||
// GetActiveUserByEmail 用Email获取可登录用户
|
||||
func GetActiveUserByEmail(email string) (User, error) {
|
||||
var user User
|
||||
result := DB.Set("gorm:auto_preload", true).Where("status = ? and email = ?", Active, email).First(&user)
|
||||
return user, result.Error
|
||||
|
||||
@@ -352,10 +352,20 @@ func TestUser_IncreaseStorageWithoutCheck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
func TestGetActiveUserByEmail(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
|
||||
_, err := GetActiveUserByEmail("abslant@foxmail.com")
|
||||
|
||||
asserts.Error(err)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
|
||||
_, err := GetUserByEmail("abslant@foxmail.com")
|
||||
|
||||
asserts.Error(err)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package conf
|
||||
|
||||
// BackendVersion 当前后端版本号
|
||||
var BackendVersion = "3.2.1"
|
||||
var BackendVersion = "3.3.1"
|
||||
|
||||
// RequiredDBVersion 与当前版本匹配的数据库版本
|
||||
var RequiredDBVersion = "3.2.0"
|
||||
var RequiredDBVersion = "3.3.1"
|
||||
|
||||
// RequiredStaticVersion 与当前版本匹配的静态资源版本
|
||||
var RequiredStaticVersion = "3.2.1"
|
||||
var RequiredStaticVersion = "3.3.1"
|
||||
|
||||
// IsPro 是否为Pro版本
|
||||
var IsPro = "false"
|
||||
|
||||
@@ -100,6 +100,14 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
defer file.Close()
|
||||
dst = util.RelativePath(filepath.FromSlash(dst))
|
||||
|
||||
// 如果禁止了 Overwrite,则检查是否有重名冲突
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
if util.Exists(dst) {
|
||||
util.Log().Warning("物理同名文件已存在或不可用: %s", dst)
|
||||
return errors.New("物理同名文件已存在或不可用")
|
||||
}
|
||||
}
|
||||
|
||||
// 如果目标目录不存在,创建
|
||||
basePath := filepath.Dir(dst)
|
||||
if !util.Exists(basePath) {
|
||||
@@ -130,11 +138,14 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
|
||||
var retErr error
|
||||
|
||||
for _, value := range files {
|
||||
err := os.Remove(util.RelativePath(filepath.FromSlash(value)))
|
||||
if err != nil {
|
||||
util.Log().Warning("无法删除文件,%s", err)
|
||||
retErr = err
|
||||
deleteFailed = append(deleteFailed, value)
|
||||
filePath := util.RelativePath(filepath.FromSlash(value))
|
||||
if util.Exists(filePath) {
|
||||
err := os.Remove(filePath)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法删除文件,%s", err)
|
||||
retErr = err
|
||||
deleteFailed = append(deleteFailed, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试删除文件的缩略图(如果有)
|
||||
|
||||
@@ -2,13 +2,6 @@ package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
@@ -16,12 +9,19 @@ import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandler_Put(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
handler := Driver{}
|
||||
ctx := context.Background()
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
os.Remove(util.RelativePath("test/test/txt"))
|
||||
|
||||
testCases := []struct {
|
||||
file io.ReadCloser
|
||||
@@ -33,6 +33,11 @@ func TestHandler_Put(t *testing.T) {
|
||||
dst: "test/test/txt",
|
||||
err: false,
|
||||
},
|
||||
{
|
||||
file: ioutil.NopCloser(strings.NewReader("test input file")),
|
||||
dst: "test/test/txt",
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
file: ioutil.NopCloser(strings.NewReader("test input file")),
|
||||
dst: "/notexist:/S.TXT",
|
||||
@@ -55,24 +60,34 @@ func TestHandler_Delete(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
handler := Driver{}
|
||||
ctx := context.Background()
|
||||
filePath := util.RelativePath("test.file")
|
||||
|
||||
file, err := os.Create(util.RelativePath("test.file"))
|
||||
file, err := os.Create(filePath)
|
||||
asserts.NoError(err)
|
||||
_ = file.Close()
|
||||
list, err := handler.Delete(ctx, []string{"test.file"})
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
|
||||
file, err = os.Create(util.RelativePath("test.file"))
|
||||
asserts.NoError(err)
|
||||
file, err = os.Create(filePath)
|
||||
_ = file.Close()
|
||||
file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0))
|
||||
asserts.NoError(err)
|
||||
list, err = handler.Delete(ctx, []string{"test.file", "test.notexist"})
|
||||
asserts.Equal([]string{"test.notexist"}, list)
|
||||
asserts.Error(err)
|
||||
file.Close()
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
|
||||
list, err = handler.Delete(ctx, []string{"test.notexist"})
|
||||
asserts.Equal([]string{"test.notexist"}, list)
|
||||
asserts.Error(err)
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
|
||||
file, err = os.Create(filePath)
|
||||
asserts.NoError(err)
|
||||
list, err = handler.Delete(ctx, []string{"test.file"})
|
||||
_ = file.Close()
|
||||
asserts.Equal([]string{}, list)
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
func TestHandler_Get(t *testing.T) {
|
||||
|
||||
@@ -53,12 +53,23 @@ func (err RespError) Error() string {
|
||||
return err.APIError.Message
|
||||
}
|
||||
|
||||
func (client *Client) getRequestURL(api string) string {
|
||||
func (client *Client) getRequestURL(api string, opts ...Option) string {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
base, _ := url.Parse(client.Endpoints.EndpointURL)
|
||||
if base == nil {
|
||||
return ""
|
||||
}
|
||||
base.Path = path.Join(base.Path, api)
|
||||
|
||||
if options.useDriverResource {
|
||||
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
|
||||
} else {
|
||||
base.Path = path.Join(base.Path, api)
|
||||
}
|
||||
|
||||
return base.String()
|
||||
}
|
||||
|
||||
@@ -67,9 +78,9 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
var requestURL string
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
if dst == "" {
|
||||
requestURL = client.getRequestURL("me/drive/root/children")
|
||||
requestURL = client.getRequestURL("root/children")
|
||||
} else {
|
||||
requestURL = client.getRequestURL("me/drive/root:/" + dst + ":/children")
|
||||
requestURL = client.getRequestURL("root:/" + dst + ":/children")
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
|
||||
@@ -103,10 +114,10 @@ func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo
|
||||
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
|
||||
var requestURL string
|
||||
if id != "" {
|
||||
requestURL = client.getRequestURL("/me/drive/items/" + id)
|
||||
requestURL = client.getRequestURL("items/" + id)
|
||||
} else {
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
requestURL = client.getRequestURL("me/drive/root:/" + dst)
|
||||
requestURL = client.getRequestURL("root:/" + dst)
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
|
||||
@@ -129,14 +140,13 @@ func (client *Client) Meta(ctx context.Context, id string, path string) (*FileIn
|
||||
|
||||
// CreateUploadSession 创建分片上传会话
|
||||
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
|
||||
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/createUploadSession")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
|
||||
body := map[string]map[string]interface{}{
|
||||
"item": {
|
||||
"@microsoft.graph.conflictBehavior": options.conflictBehavior,
|
||||
@@ -161,6 +171,33 @@ func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts
|
||||
return uploadSession.UploadURL, nil
|
||||
}
|
||||
|
||||
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
|
||||
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
|
||||
siteUrlParsed, err := url.Parse(siteUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hostName := siteUrlParsed.Hostname()
|
||||
relativePath := strings.Trim(siteUrlParsed.Path, "/")
|
||||
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
|
||||
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
if reqErr != nil {
|
||||
return "", reqErr
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
siteInfo Site
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
|
||||
if decodeErr != nil {
|
||||
return "", decodeErr
|
||||
}
|
||||
|
||||
return siteInfo.ID, nil
|
||||
}
|
||||
|
||||
// GetUploadSessionStatus 查询上传会话状态
|
||||
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
|
||||
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
|
||||
@@ -220,15 +257,21 @@ func (client *Client) UploadChunk(ctx context.Context, uploadURL string, chunk *
|
||||
|
||||
// Upload 上传文件
|
||||
func (client *Client) Upload(ctx context.Context, dst string, size int, file io.Reader) error {
|
||||
// 决定是否覆盖文件
|
||||
overwrite := "replace"
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
overwrite = "fail"
|
||||
}
|
||||
|
||||
// 小文件,使用简单上传接口上传
|
||||
if size <= int(SmallFileSize) {
|
||||
_, err := client.SimpleUpload(ctx, dst, file, int64(size))
|
||||
_, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite))
|
||||
return err
|
||||
}
|
||||
|
||||
// 大文件,进行分片
|
||||
// 创建上传会话
|
||||
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior("replace"))
|
||||
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior(overwrite))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -287,9 +330,15 @@ func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string)
|
||||
}
|
||||
|
||||
// SimpleUpload 上传小文件到dst
|
||||
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64) (*UploadResult, error) {
|
||||
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("me/drive/root:/" + dst + ":/content")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/content")
|
||||
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
|
||||
|
||||
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
|
||||
request.WithTimeout(time.Duration(150)*time.Second),
|
||||
@@ -303,7 +352,7 @@ func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Read
|
||||
retried++
|
||||
util.Log().Debug("文件[%s]上传失败[%s],5秒钟后重试", dst, err)
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
return client.SimpleUpload(context.WithValue(ctx, fsctx.RetryCtx, retried), dst, body, size)
|
||||
return client.SimpleUpload(context.WithValue(ctx, fsctx.RetryCtx, retried), dst, body, size, opts...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -345,7 +394,8 @@ func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string,
|
||||
// 由于API限制,最多删除20个
|
||||
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
|
||||
body := client.makeBatchDeleteRequestsBody(dst)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch"), body, 200)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
|
||||
WithDriverResource(false)), body, 200)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
@@ -370,7 +420,7 @@ func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error
|
||||
func getDeleteFailed(res *BatchResponses) []string {
|
||||
var failed = make([]string, 0, len(res.Responses))
|
||||
for _, v := range res.Responses {
|
||||
if v.Status != 204 {
|
||||
if v.Status != 204 && v.Status != 404 {
|
||||
failed = append(failed, v.ID)
|
||||
}
|
||||
}
|
||||
@@ -384,7 +434,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
}
|
||||
for i, v := range files {
|
||||
v = strings.TrimPrefix(v, "/")
|
||||
filePath, _ := url.Parse("/me/drive/root:/")
|
||||
filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
|
||||
filePath.Path = path.Join(filePath.Path, v)
|
||||
req.Requests[i] = BatchRequest{
|
||||
ID: v,
|
||||
@@ -400,17 +450,7 @@ func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
// GetThumbURL 获取给定尺寸的缩略图URL
|
||||
func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) {
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
var (
|
||||
cropOption string
|
||||
requestURL string
|
||||
)
|
||||
if client.Endpoints.isInChina {
|
||||
cropOption = "large"
|
||||
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails/0") + "/" + cropOption
|
||||
} else {
|
||||
cropOption = fmt.Sprintf("c%dx%d_Crop", w, h)
|
||||
requestURL = client.getRequestURL("me/drive/root:/"+dst+":/thumbnails") + "?select=" + cropOption
|
||||
}
|
||||
requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
if err != nil {
|
||||
@@ -431,7 +471,7 @@ func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (s
|
||||
}
|
||||
|
||||
if len(thumbRes.Value) == 1 {
|
||||
if res, ok := thumbRes.Value[0][cropOption]; ok {
|
||||
if res, ok := thumbRes.Value[0]["large"]; ok {
|
||||
return res.(map[string]interface{})["url"].(string), nil
|
||||
}
|
||||
}
|
||||
@@ -456,7 +496,7 @@ func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size ui
|
||||
case <-time.After(time.Duration(ttl) * time.Second):
|
||||
// 上传会话到期,仍未完成上传,创建占位符
|
||||
client.DeleteUploadSession(context.Background(), uploadURL)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
|
||||
if err != nil {
|
||||
util.Log().Debug("无法创建占位文件,%s", err)
|
||||
}
|
||||
@@ -504,7 +544,7 @@ func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size ui
|
||||
// 取消上传会话,实测OneDrive取消上传会话后,客户端还是可以上传,
|
||||
// 所以上传一个空文件占位,阻止客户端上传
|
||||
client.DeleteUploadSession(context.Background(), uploadURL)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
|
||||
if err != nil {
|
||||
util.Log().Debug("无法创建占位文件,%s", err)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
@@ -166,6 +167,82 @@ func TestClient_GetRequestURL(t *testing.T) {
|
||||
client.Endpoints.EndpointURL = string([]byte{0x7f})
|
||||
asserts.Equal("", client.getRequestURL("123"))
|
||||
}
|
||||
|
||||
// 使用DriverResource
|
||||
{
|
||||
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
|
||||
asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123"))
|
||||
}
|
||||
|
||||
// 不使用DriverResource
|
||||
{
|
||||
client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
|
||||
asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetSiteIDByURL(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
client, _ := NewClient(&model.Policy{})
|
||||
client.Credential.AccessToken = "AccessToken"
|
||||
|
||||
// 请求失败
|
||||
{
|
||||
client.Credential.ExpiresIn = 0
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
asserts.Error(err)
|
||||
asserts.Empty(res)
|
||||
|
||||
}
|
||||
|
||||
// 返回未知响应
|
||||
{
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"GET",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`???`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
clientMock.AssertExpectations(t)
|
||||
asserts.Error(err)
|
||||
asserts.Empty(res)
|
||||
}
|
||||
|
||||
// 返回正常
|
||||
{
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
clientMock := ClientMock{}
|
||||
clientMock.On(
|
||||
"Request",
|
||||
"GET",
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
testMock.Anything,
|
||||
).Return(&request.Response{
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
|
||||
clientMock.AssertExpectations(t)
|
||||
asserts.NoError(err)
|
||||
asserts.NotEmpty(res)
|
||||
asserts.Equal("123321", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Meta(t *testing.T) {
|
||||
@@ -499,11 +576,12 @@ func TestClient_Upload(t *testing.T) {
|
||||
client, _ := NewClient(&model.Policy{})
|
||||
client.Credential.AccessToken = "AccessToken"
|
||||
client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
|
||||
// 小文件,简单上传,失败
|
||||
{
|
||||
client.Credential.ExpiresIn = 0
|
||||
err := client.Upload(context.Background(), "123.jpg", 3, strings.NewReader("123"))
|
||||
err := client.Upload(ctx, "123.jpg", 3, strings.NewReader("123"))
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
@@ -888,7 +966,7 @@ func TestClient_GetThumbURL(t *testing.T) {
|
||||
Err: nil,
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"c1x1_Crop":{"url":"thumb"}}]}`)),
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"large":{"url":"thumb"}}]}`)),
|
||||
},
|
||||
})
|
||||
client.Request = clientMock
|
||||
|
||||
@@ -37,14 +37,16 @@ type Endpoints struct {
|
||||
OAuthEndpoints *oauthEndpoint
|
||||
EndpointURL string // 接口请求的基URL
|
||||
isInChina bool // 是否为世纪互联
|
||||
DriverResource string // 要使用的驱动器
|
||||
}
|
||||
|
||||
// NewClient 根据存储策略获取新的client
|
||||
func NewClient(policy *model.Policy) (*Client, error) {
|
||||
client := &Client{
|
||||
Endpoints: &Endpoints{
|
||||
OAuthURL: policy.BaseURL,
|
||||
EndpointURL: policy.Server,
|
||||
OAuthURL: policy.BaseURL,
|
||||
EndpointURL: policy.Server,
|
||||
DriverResource: policy.OptionsSerialized.OdDriver,
|
||||
},
|
||||
Credential: &Credential{
|
||||
RefreshToken: policy.AccessKey,
|
||||
@@ -56,6 +58,10 @@ func NewClient(policy *model.Policy) (*Client, error) {
|
||||
Request: request.HTTPClient{},
|
||||
}
|
||||
|
||||
if client.Endpoints.DriverResource == "" {
|
||||
client.Endpoints.DriverResource = "me/drive"
|
||||
}
|
||||
|
||||
oauthBase := client.getOAuthEndpoint()
|
||||
if oauthBase == nil {
|
||||
return nil, ErrAuthEndpoint
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
@@ -152,8 +153,26 @@ func (handler Driver) Source(
|
||||
isDownload bool,
|
||||
speed int,
|
||||
) (string, error) {
|
||||
cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID)
|
||||
// 如果是永久链接,则返回签名后的中转外链
|
||||
if ttl == 0 {
|
||||
signedURI, err := auth.SignURI(
|
||||
auth.General,
|
||||
fmt.Sprintf("/api/v3/file/source/%d/%s", file.ID, file.Name),
|
||||
ttl,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return baseURL.ResolveReference(signedURI).String(), nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 尝试从缓存中查找
|
||||
if cachedURL, ok := cache.Get(fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)); ok {
|
||||
if cachedURL, ok := cache.Get(cacheKey); ok {
|
||||
return handler.replaceSourceHost(cachedURL.(string))
|
||||
}
|
||||
|
||||
@@ -162,7 +181,7 @@ func (handler Driver) Source(
|
||||
if err == nil {
|
||||
// 写入新的缓存
|
||||
cache.Set(
|
||||
fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path),
|
||||
cacheKey,
|
||||
res.DownloadURL,
|
||||
model.GetIntSetting("onedrive_source_timeout", 1800),
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@ package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -134,7 +136,7 @@ func TestDriver_Source(t *testing.T) {
|
||||
|
||||
// 失败
|
||||
{
|
||||
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0)
|
||||
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0)
|
||||
asserts.Error(err)
|
||||
asserts.Empty(res)
|
||||
}
|
||||
@@ -143,13 +145,28 @@ func TestDriver_Source(t *testing.T) {
|
||||
{
|
||||
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
handler.Client.Credential.AccessToken = "1"
|
||||
cache.Set("onedrive_source_0_123.jpg", "res", 0)
|
||||
cache.Set("onedrive_source_0_123.jpg", "res", 1)
|
||||
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0)
|
||||
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
|
||||
asserts.NoError(err)
|
||||
asserts.Equal("res", res)
|
||||
}
|
||||
|
||||
// 命中缓存 上下文存在文件 成功
|
||||
{
|
||||
file := model.File{}
|
||||
file.ID = 1
|
||||
file.UpdatedAt = time.Now()
|
||||
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
|
||||
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
handler.Client.Credential.AccessToken = "1"
|
||||
cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
|
||||
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 1, true, 0)
|
||||
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
|
||||
asserts.NoError(err)
|
||||
asserts.Equal("res", res)
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
@@ -169,10 +186,25 @@ func TestDriver_Source(t *testing.T) {
|
||||
})
|
||||
handler.Client.Request = clientMock
|
||||
handler.Client.Credential.AccessToken = "1"
|
||||
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0)
|
||||
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0)
|
||||
asserts.NoError(err)
|
||||
asserts.Equal("123321", res)
|
||||
}
|
||||
|
||||
// 成功 永久直链
|
||||
{
|
||||
file := model.File{}
|
||||
file.ID = 1
|
||||
file.Name = "123.jpg"
|
||||
file.UpdatedAt = time.Now()
|
||||
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
|
||||
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
|
||||
auth.General = auth.HMACAuth{}
|
||||
handler.Client.Credential.AccessToken = "1"
|
||||
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 0, true, 0)
|
||||
asserts.NoError(err)
|
||||
asserts.Contains(res, "/api/v3/file/source/1/123.jpg?sign")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriver_List(t *testing.T) {
|
||||
|
||||
@@ -160,7 +160,8 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
|
||||
client.Credential = credential
|
||||
|
||||
// 更新存储策略的 RefreshToken
|
||||
client.Policy.UpdateAccessKey(credential.RefreshToken)
|
||||
client.Policy.AccessKey = credential.RefreshToken
|
||||
client.Policy.SaveAndClearCache()
|
||||
|
||||
// 更新缓存
|
||||
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
|
||||
|
||||
@@ -8,11 +8,12 @@ type Option interface {
|
||||
}
|
||||
|
||||
type options struct {
|
||||
redirect string
|
||||
code string
|
||||
refreshToken string
|
||||
conflictBehavior string
|
||||
expires time.Time
|
||||
redirect string
|
||||
code string
|
||||
refreshToken string
|
||||
conflictBehavior string
|
||||
expires time.Time
|
||||
useDriverResource bool
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
@@ -38,13 +39,21 @@ func WithConflictBehavior(t string) Option {
|
||||
})
|
||||
}
|
||||
|
||||
// WithConflictBehavior 设置文件重名后的处理方式
|
||||
func WithDriverResource(t bool) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.useDriverResource = t
|
||||
})
|
||||
}
|
||||
|
||||
func (f optionFunc) apply(o *options) {
|
||||
f(o)
|
||||
}
|
||||
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
conflictBehavior: "fail",
|
||||
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
|
||||
conflictBehavior: "fail",
|
||||
useDriverResource: true,
|
||||
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,6 +131,15 @@ type OAuthError struct {
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
}
|
||||
|
||||
// Site SharePoint 站点信息
|
||||
type Site struct {
|
||||
Description string `json:"description"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
WebUrl string `json:"webUrl"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Credential{})
|
||||
}
|
||||
|
||||
@@ -235,8 +235,15 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
// 凭证有效期
|
||||
credentialTTL := model.GetIntSetting("upload_credential_timeout", 3600)
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := true
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
overwrite = false
|
||||
}
|
||||
|
||||
options := []oss.Option{
|
||||
oss.Expires(time.Now().Add(time.Duration(credentialTTL) * time.Second)),
|
||||
oss.ForbidOverWrite(!overwrite),
|
||||
}
|
||||
|
||||
// 上传文件
|
||||
|
||||
@@ -265,10 +265,11 @@ func TestDriver_Put(t *testing.T) {
|
||||
},
|
||||
}
|
||||
cache.Set("setting_upload_credential_timeout", "3600", 0)
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
|
||||
// 失败
|
||||
{
|
||||
err := handler.Put(context.Background(), ioutil.NopCloser(strings.NewReader("123")), "/123.txt", 3)
|
||||
err := handler.Put(ctx, ioutil.NopCloser(strings.NewReader("123")), "/123.txt", 3)
|
||||
asserts.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
|
||||
if err != nil {
|
||||
failed := make([]string, 0, len(rets))
|
||||
for k, ret := range rets {
|
||||
if ret.Code != 200 {
|
||||
if ret.Code != 200 && ret.Code != 612 {
|
||||
failed = append(failed, files[k])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,6 +155,13 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 决定是否要禁用文件覆盖
|
||||
overwrite := "true"
|
||||
if ctx.Value(fsctx.DisableOverwrite) != nil {
|
||||
overwrite = "false"
|
||||
}
|
||||
|
||||
// 上传文件
|
||||
resp, err := handler.Client.Request(
|
||||
"POST",
|
||||
@@ -164,6 +171,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
|
||||
"Authorization": {credential.Token},
|
||||
"X-Policy": {credential.Policy},
|
||||
"X-FileName": {fileName},
|
||||
"X-Overwrite": {overwrite},
|
||||
}),
|
||||
request.WithContentLength(int64(size)),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
@@ -321,7 +329,8 @@ func (handler Driver) getUploadCredential(ctx context.Context, policy serializer
|
||||
// 签名上传策略
|
||||
uploadRequest, _ := http.NewRequest("POST", "/api/v3/slave/upload", nil)
|
||||
uploadRequest.Header = map[string][]string{
|
||||
"X-Policy": {policyEncoded},
|
||||
"X-Policy": {policyEncoded},
|
||||
"X-Overwrite": {"false"},
|
||||
}
|
||||
auth.SignRequest(handler.AuthInstance, uploadRequest, TTL)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestHandler_Token(t *testing.T) {
|
||||
},
|
||||
AuthInstance: auth.HMACAuth{},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
auth.General = auth.HMACAuth{SecretKey: []byte("test")}
|
||||
|
||||
// 成功
|
||||
@@ -49,6 +49,7 @@ func TestHandler_Token(t *testing.T) {
|
||||
asserts.Equal(true, policy.AutoRename)
|
||||
asserts.Equal("dir", policy.SavePath)
|
||||
asserts.Equal("file", policy.FileName)
|
||||
asserts.Equal("file", policy.FileName)
|
||||
asserts.Equal([]string{"txt"}, policy.AllowedExtension)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
@@ -60,7 +60,7 @@ func (handler *Driver) InitS3Client() error {
|
||||
Credentials: credentials.NewStaticCredentials(handler.Policy.AccessKey, handler.Policy.SecretKey, ""),
|
||||
Endpoint: &handler.Policy.Server,
|
||||
Region: &handler.Policy.OptionsSerialized.Region,
|
||||
S3ForcePathStyle: aws.Bool(false),
|
||||
S3ForcePathStyle: aws.Bool(true),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -229,53 +229,35 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
|
||||
return files, err
|
||||
}
|
||||
|
||||
var (
|
||||
failed = make([]string, 0, len(files))
|
||||
lastErr error
|
||||
currentIndex = 0
|
||||
indexLock sync.Mutex
|
||||
failedLock sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
routineNum = 4
|
||||
)
|
||||
wg.Add(routineNum)
|
||||
failed := make([]string, 0, len(files))
|
||||
deleted := make([]string, 0, len(files))
|
||||
|
||||
// S3不支持批量操作,这里开四个协程并行操作
|
||||
for i := 0; i < routineNum; i++ {
|
||||
go func() {
|
||||
for {
|
||||
// 取得待删除文件
|
||||
indexLock.Lock()
|
||||
if currentIndex >= len(files) {
|
||||
// 所有文件处理完成
|
||||
wg.Done()
|
||||
indexLock.Unlock()
|
||||
return
|
||||
}
|
||||
path := files[currentIndex]
|
||||
currentIndex++
|
||||
indexLock.Unlock()
|
||||
|
||||
// 发送异步删除请求
|
||||
_, err := handler.svc.DeleteObject(
|
||||
&s3.DeleteObjectInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &path,
|
||||
})
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
failedLock.Lock()
|
||||
lastErr = err
|
||||
failed = append(failed, path)
|
||||
failedLock.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
keys := make([]*s3.ObjectIdentifier, 0, len(files))
|
||||
for _, file := range files {
|
||||
filePath := file
|
||||
keys = append(keys, &s3.ObjectIdentifier{Key: &filePath})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return failed, lastErr
|
||||
// 发送异步删除请求
|
||||
res, err := handler.svc.DeleteObjects(
|
||||
&s3.DeleteObjectsInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Delete: &s3.Delete{
|
||||
Objects: keys,
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
// 统计未删除的文件
|
||||
for _, deleteRes := range res.Deleted {
|
||||
deleted = append(deleted, *deleteRes.Key)
|
||||
}
|
||||
failed = util.SliceDifference(failed, deleted)
|
||||
|
||||
return failed, nil
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ func (fs *FileSystem) Preview(ctx context.Context, id uint, isText bool) (*respo
|
||||
|
||||
// 否则重定向到签名的预览URL
|
||||
ttl := model.GetIntSetting("preview_timeout", 60)
|
||||
previewURL, err := fs.signURL(ctx, &fs.FileTarget[0], int64(ttl), false)
|
||||
previewURL, err := fs.SignURL(ctx, &fs.FileTarget[0], int64(ttl), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -234,7 +234,7 @@ func (fs *FileSystem) GetDownloadURL(ctx context.Context, id uint, timeout strin
|
||||
|
||||
// 生成下載地址
|
||||
ttl := model.GetIntSetting(timeout, 60)
|
||||
source, err := fs.signURL(
|
||||
source, err := fs.SignURL(
|
||||
ctx,
|
||||
fileTarget,
|
||||
int64(ttl),
|
||||
@@ -264,7 +264,7 @@ func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error
|
||||
)
|
||||
}
|
||||
|
||||
source, err := fs.signURL(ctx, &fs.FileTarget[0], 0, false)
|
||||
source, err := fs.SignURL(ctx, &fs.FileTarget[0], 0, false)
|
||||
if err != nil {
|
||||
return "", serializer.NewError(serializer.CodeNotSet, "无法获取外链", err)
|
||||
}
|
||||
@@ -272,7 +272,8 @@ func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error
|
||||
return source, nil
|
||||
}
|
||||
|
||||
func (fs *FileSystem) signURL(ctx context.Context, file *model.File, ttl int64, isDownload bool) (string, error) {
|
||||
// SignURL 签名文件原始 URL
|
||||
func (fs *FileSystem) SignURL(ctx context.Context, file *model.File, ttl int64, isDownload bool) (string, error) {
|
||||
fs.FileTarget = []model.File{*file}
|
||||
ctx = context.WithValue(ctx, fsctx.FileModelCtx, *file)
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ const (
|
||||
ShareKeyCtx
|
||||
// LimitParentCtx 限制父目录
|
||||
LimitParentCtx
|
||||
// IgnoreConflictCtx 忽略重名冲突
|
||||
IgnoreConflictCtx
|
||||
// IgnoreDirectoryConflictCtx 忽略目录重名冲突
|
||||
IgnoreDirectoryConflictCtx
|
||||
// RetryCtx 失败重试次数
|
||||
RetryCtx
|
||||
// ForceUsePublicEndpointCtx 强制使用公网 Endpoint
|
||||
@@ -39,4 +39,6 @@ const (
|
||||
CancelFuncCtx
|
||||
// ValidateCapacityOnceCtx 限定归还容量的操作只執行一次
|
||||
ValidateCapacityOnceCtx
|
||||
// 禁止上传时同名覆盖操作
|
||||
DisableOverwrite
|
||||
)
|
||||
|
||||
@@ -39,8 +39,8 @@ func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentR
|
||||
res.MaxAge = model.GetIntSetting("preview_timeout", 60)
|
||||
}
|
||||
|
||||
// 出错时重新生成缩略图
|
||||
if err != nil {
|
||||
// 本地存储策略出错时重新生成缩略图
|
||||
if err != nil && fs.Policy.Type == "local" {
|
||||
fs.GenerateThumbnail(ctx, &fs.FileTarget[0])
|
||||
}
|
||||
|
||||
|
||||
@@ -403,8 +403,8 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*mo
|
||||
isExist, parent := fs.IsPathExist(base)
|
||||
if !isExist {
|
||||
// 递归创建父目录
|
||||
if _, ok := ctx.Value(fsctx.IgnoreConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreConflictCtx, true)
|
||||
if _, ok := ctx.Value(fsctx.IgnoreDirectoryConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true)
|
||||
}
|
||||
newParent, err := fs.CreateDirectory(ctx, base)
|
||||
if err != nil {
|
||||
@@ -427,7 +427,7 @@ func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*mo
|
||||
_, err := newFolder.Create()
|
||||
|
||||
if err != nil {
|
||||
if _, ok := ctx.Value(fsctx.IgnoreConflictCtx).(bool); !ok {
|
||||
if _, ok := ctx.Value(fsctx.IgnoreDirectoryConflictCtx).(bool); !ok {
|
||||
return nil, ErrFolderExisted
|
||||
}
|
||||
|
||||
|
||||
23
pkg/serializer/explorer.go
Normal file
23
pkg/serializer/explorer.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package serializer
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(ObjectProps{})
|
||||
}
|
||||
|
||||
// ObjectProps 文件、目录对象的详细属性信息
|
||||
type ObjectProps struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Policy string `json:"policy"`
|
||||
Size uint64 `json:"size"`
|
||||
ChildFolderNum int `json:"child_folder_num"`
|
||||
ChildFileNum int `json:"child_file_num"`
|
||||
Path string `json:"path"`
|
||||
|
||||
QueryDate time.Time `json:"query_date"`
|
||||
}
|
||||
@@ -4,20 +4,21 @@ import model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
|
||||
// SiteConfig 站点全局设置序列
|
||||
type SiteConfig struct {
|
||||
SiteName string `json:"title"`
|
||||
SiteICPId string `json:"siteICPId"`
|
||||
LoginCaptcha bool `json:"loginCaptcha"`
|
||||
RegCaptcha bool `json:"regCaptcha"`
|
||||
ForgetCaptcha bool `json:"forgetCaptcha"`
|
||||
EmailActive bool `json:"emailActive"`
|
||||
Themes string `json:"themes"`
|
||||
DefaultTheme string `json:"defaultTheme"`
|
||||
HomepageViewMethod string `json:"home_view_method"`
|
||||
ShareViewMethod string `json:"share_view_method"`
|
||||
Authn bool `json:"authn"`
|
||||
User User `json:"user"`
|
||||
UseReCaptcha bool `json:"captcha_IsUseReCaptcha"`
|
||||
ReCaptchaKey string `json:"captcha_ReCaptchaKey"`
|
||||
SiteName string `json:"title"`
|
||||
SiteICPId string `json:"siteICPId"`
|
||||
LoginCaptcha bool `json:"loginCaptcha"`
|
||||
RegCaptcha bool `json:"regCaptcha"`
|
||||
ForgetCaptcha bool `json:"forgetCaptcha"`
|
||||
EmailActive bool `json:"emailActive"`
|
||||
Themes string `json:"themes"`
|
||||
DefaultTheme string `json:"defaultTheme"`
|
||||
HomepageViewMethod string `json:"home_view_method"`
|
||||
ShareViewMethod string `json:"share_view_method"`
|
||||
Authn bool `json:"authn"`
|
||||
User User `json:"user"`
|
||||
ReCaptchaKey string `json:"captcha_ReCaptchaKey"`
|
||||
CaptchaType string `json:"captcha_type"`
|
||||
TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"`
|
||||
}
|
||||
|
||||
type task struct {
|
||||
@@ -64,20 +65,21 @@ func BuildSiteConfig(settings map[string]string, user *model.User) Response {
|
||||
}
|
||||
res := Response{
|
||||
Data: SiteConfig{
|
||||
SiteName: checkSettingValue(settings, "siteName"),
|
||||
SiteICPId: checkSettingValue(settings, "siteICPId"),
|
||||
LoginCaptcha: model.IsTrueVal(checkSettingValue(settings, "login_captcha")),
|
||||
RegCaptcha: model.IsTrueVal(checkSettingValue(settings, "reg_captcha")),
|
||||
ForgetCaptcha: model.IsTrueVal(checkSettingValue(settings, "forget_captcha")),
|
||||
EmailActive: model.IsTrueVal(checkSettingValue(settings, "email_active")),
|
||||
Themes: checkSettingValue(settings, "themes"),
|
||||
DefaultTheme: checkSettingValue(settings, "defaultTheme"),
|
||||
HomepageViewMethod: checkSettingValue(settings, "home_view_method"),
|
||||
ShareViewMethod: checkSettingValue(settings, "share_view_method"),
|
||||
Authn: model.IsTrueVal(checkSettingValue(settings, "authn_enabled")),
|
||||
User: userRes,
|
||||
UseReCaptcha: model.IsTrueVal(checkSettingValue(settings, "captcha_IsUseReCaptcha")),
|
||||
ReCaptchaKey: checkSettingValue(settings, "captcha_ReCaptchaKey"),
|
||||
SiteName: checkSettingValue(settings, "siteName"),
|
||||
SiteICPId: checkSettingValue(settings, "siteICPId"),
|
||||
LoginCaptcha: model.IsTrueVal(checkSettingValue(settings, "login_captcha")),
|
||||
RegCaptcha: model.IsTrueVal(checkSettingValue(settings, "reg_captcha")),
|
||||
ForgetCaptcha: model.IsTrueVal(checkSettingValue(settings, "forget_captcha")),
|
||||
EmailActive: model.IsTrueVal(checkSettingValue(settings, "email_active")),
|
||||
Themes: checkSettingValue(settings, "themes"),
|
||||
DefaultTheme: checkSettingValue(settings, "defaultTheme"),
|
||||
HomepageViewMethod: checkSettingValue(settings, "home_view_method"),
|
||||
ShareViewMethod: checkSettingValue(settings, "share_view_method"),
|
||||
Authn: model.IsTrueVal(checkSettingValue(settings, "authn_enabled")),
|
||||
User: userRes,
|
||||
ReCaptchaKey: checkSettingValue(settings, "captcha_ReCaptchaKey"),
|
||||
CaptchaType: checkSettingValue(settings, "captcha_type"),
|
||||
TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"),
|
||||
}}
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
)
|
||||
|
||||
// DecompressTask 文件压缩任务
|
||||
@@ -81,7 +82,12 @@ func (job *DecompressTask) Do() {
|
||||
}
|
||||
|
||||
job.TaskModel.SetProgress(DecompressingProgress)
|
||||
err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst)
|
||||
|
||||
// 禁止重名覆盖
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
|
||||
err = fs.Decompress(ctx, job.TaskProps.Src, job.TaskProps.Dst)
|
||||
if err != nil {
|
||||
job.SetErrorMsg("解压缩失败", err)
|
||||
return
|
||||
|
||||
@@ -102,7 +102,7 @@ func (job *ImportTask) Do() {
|
||||
|
||||
// 列取目录、对象
|
||||
job.TaskModel.SetProgress(ListingProgress)
|
||||
coxIgnoreConflict := context.WithValue(context.Background(), fsctx.IgnoreConflictCtx,
|
||||
coxIgnoreConflict := context.WithValue(context.Background(), fsctx.IgnoreDirectoryConflictCtx,
|
||||
true)
|
||||
objects, err := fs.Handler.List(ctx, job.TaskProps.Src, job.TaskProps.Recursive)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
@@ -102,7 +103,8 @@ func (job *TransferTask) Do() {
|
||||
dst = path.Join(job.TaskProps.Dst, strings.TrimPrefix(src, trim))
|
||||
}
|
||||
|
||||
err = fs.UploadFromPath(context.Background(), file, dst)
|
||||
ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true)
|
||||
err = fs.UploadFromPath(ctx, file, dst)
|
||||
if err != nil {
|
||||
job.SetErrorMsg("文件转存失败", err)
|
||||
}
|
||||
|
||||
@@ -371,6 +371,9 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst
|
||||
fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
|
||||
fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity)
|
||||
fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity)
|
||||
|
||||
// 禁止覆盖
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
}
|
||||
|
||||
// 执行上传
|
||||
@@ -407,8 +410,8 @@ func (h *Handler) handleMkcol(w http.ResponseWriter, r *http.Request, fs *filesy
|
||||
return http.StatusUnsupportedMediaType, nil
|
||||
}
|
||||
if strings.Contains(r.UserAgent(), "rclone") {
|
||||
if _, ok := ctx.Value(fsctx.IgnoreConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreConflictCtx, true)
|
||||
if _, ok := ctx.Value(fsctx.IgnoreDirectoryConflictCtx).(bool); !ok {
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true)
|
||||
}
|
||||
}
|
||||
if _, err := fs.CreateDirectory(ctx, reqPath); err != nil {
|
||||
|
||||
@@ -25,7 +25,7 @@ func AdminSummary(c *gin.Context) {
|
||||
// AdminNews 获取社区新闻
|
||||
func AdminNews(c *gin.Context) {
|
||||
r := request.HTTPClient{}
|
||||
res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&", nil)
|
||||
res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil)
|
||||
if res.Err == nil {
|
||||
io.Copy(c.Writer, res.Response.Body)
|
||||
}
|
||||
|
||||
@@ -88,6 +88,29 @@ func AnonymousGetContent(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// AnonymousPermLink 文件签名后的永久链接
|
||||
func AnonymousPermLink(c *gin.Context) {
|
||||
// 创建上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var service explorer.FileAnonymousGetService
|
||||
if err := c.ShouldBindUri(&service); err == nil {
|
||||
res := service.Source(ctx, c)
|
||||
// 是否需要重定向
|
||||
if res.Code == -302 {
|
||||
c.Redirect(302, res.Data.(string))
|
||||
return
|
||||
}
|
||||
// 是否有错误发生
|
||||
if res.Code != 0 {
|
||||
c.JSON(200, res)
|
||||
}
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
// GetSource 获取文件的外链地址
|
||||
func GetSource(c *gin.Context) {
|
||||
// 创建上下文
|
||||
@@ -319,6 +342,7 @@ func FileUploadStream(c *gin.Context) {
|
||||
|
||||
// 执行上传
|
||||
ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{})
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c)
|
||||
err = fs.Upload(uploadCtx, fileData)
|
||||
if err != nil {
|
||||
|
||||
@@ -66,3 +66,19 @@ func Rename(c *gin.Context) {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Rename 重命名文件或目录
|
||||
func GetProperty(c *gin.Context) {
|
||||
// 创建上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var service explorer.ItemPropertyService
|
||||
service.ID = c.Param("id")
|
||||
if err := c.ShouldBindQuery(&service); err == nil {
|
||||
res := service.GetProperty(ctx, c)
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,8 +24,9 @@ func SiteConfig(c *gin.Context) {
|
||||
"home_view_method",
|
||||
"share_view_method",
|
||||
"authn_enabled",
|
||||
"captcha_IsUseReCaptcha",
|
||||
"captcha_ReCaptchaKey",
|
||||
"captcha_type",
|
||||
"captcha_TCaptcha_CaptchaAppId",
|
||||
)
|
||||
|
||||
// 如果已登录,则同时返回用户信息和标签
|
||||
|
||||
@@ -71,6 +71,11 @@ func SlaveUpload(c *gin.Context) {
|
||||
fs.Use("AfterUpload", filesystem.SlaveAfterUpload)
|
||||
fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
|
||||
|
||||
// 是否允许覆盖
|
||||
if c.Request.Header.Get("X-Overwrite") == "false" {
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
}
|
||||
|
||||
// 执行上传
|
||||
err = fs.Upload(ctx, fileData)
|
||||
if err != nil {
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
// StartLoginAuthn 开始注册WebAuthn登录
|
||||
func StartLoginAuthn(c *gin.Context) {
|
||||
userName := c.Param("username")
|
||||
expectedUser, err := model.GetUserByEmail(userName)
|
||||
expectedUser, err := model.GetActiveUserByEmail(userName)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNotFound, "用户不存在", err))
|
||||
return
|
||||
@@ -52,7 +52,7 @@ func StartLoginAuthn(c *gin.Context) {
|
||||
// FinishLoginAuthn 完成注册WebAuthn登录
|
||||
func FinishLoginAuthn(c *gin.Context) {
|
||||
userName := c.Param("username")
|
||||
expectedUser, err := model.GetUserByEmail(userName)
|
||||
expectedUser, err := model.GetActiveUserByEmail(userName)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, "用户邮箱或密码错误", err))
|
||||
return
|
||||
@@ -349,6 +349,8 @@ func UpdateOption(c *gin.Context) {
|
||||
subService = &user.DeleteWebAuthn{}
|
||||
case "theme":
|
||||
subService = &user.ThemeChose{}
|
||||
default:
|
||||
subService = &user.ChangerNick{}
|
||||
}
|
||||
|
||||
subErr = c.ShouldBindJSON(subService)
|
||||
|
||||
@@ -116,16 +116,17 @@ func InitMasterRouter() *gin.Engine {
|
||||
user := v3.Group("user")
|
||||
{
|
||||
// 用户登录
|
||||
user.POST("session", controllers.UserLogin)
|
||||
user.POST("session", middleware.CaptchaRequired("login_captcha"), controllers.UserLogin)
|
||||
// 用户注册
|
||||
user.POST("",
|
||||
middleware.IsFunctionEnabled("register_enabled"),
|
||||
middleware.CaptchaRequired("reg_captcha"),
|
||||
controllers.UserRegister,
|
||||
)
|
||||
// 用二步验证户登录
|
||||
user.POST("2fa", controllers.User2FALogin)
|
||||
// 发送密码重设邮件
|
||||
user.POST("reset", controllers.UserSendReset)
|
||||
user.POST("reset", middleware.CaptchaRequired("forget_captcha"), controllers.UserSendReset)
|
||||
// 通过邮件里的链接重设密码
|
||||
user.PATCH("reset", controllers.UserReset)
|
||||
// 邮件激活
|
||||
@@ -162,8 +163,10 @@ func InitMasterRouter() *gin.Engine {
|
||||
{
|
||||
file := sign.Group("file")
|
||||
{
|
||||
// 文件外链
|
||||
// 文件外链(直接输出文件数据)
|
||||
file.GET("get/:id/:name", controllers.AnonymousGetContent)
|
||||
// 文件外链(301跳转)
|
||||
file.GET("source/:id/:name", controllers.AnonymousPermLink)
|
||||
// 下載已经打包好的文件
|
||||
file.GET("archive/:id/archive.zip", controllers.DownloadArchive)
|
||||
// 下载文件
|
||||
@@ -510,6 +513,8 @@ func InitMasterRouter() *gin.Engine {
|
||||
object.POST("copy", controllers.Copy)
|
||||
// 重命名对象
|
||||
object.POST("rename", controllers.Rename)
|
||||
// 获取对象属性
|
||||
object.GET("property/:id", controllers.GetProperty)
|
||||
}
|
||||
|
||||
// 分享
|
||||
|
||||
@@ -2,6 +2,7 @@ package callback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OneDriveOauthService OneDrive 授权回调服务
|
||||
@@ -41,17 +43,42 @@ func (service *OneDriveOauthService) Auth(c *gin.Context) serializer.Response {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法初始化 OneDrive 客户端", err)
|
||||
}
|
||||
|
||||
credential, err := client.ObtainToken(context.Background(), onedrive.WithCode(service.Code))
|
||||
credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code))
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "AccessToken 获取失败", err)
|
||||
}
|
||||
|
||||
// 更新存储策略的 RefreshToken
|
||||
if err := client.Policy.UpdateAccessKey(credential.RefreshToken); err != nil {
|
||||
client.Policy.AccessKey = credential.RefreshToken
|
||||
if err := client.Policy.SaveAndClearCache(); err != nil {
|
||||
return serializer.DBErr("无法更新 RefreshToken", err)
|
||||
}
|
||||
|
||||
cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_")
|
||||
if client.Policy.OptionsSerialized.OdDriver != "" && strings.Contains(client.Policy.OptionsSerialized.OdDriver, "http") {
|
||||
if err := querySharePointSiteID(c, client.Policy); err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法查询 SharePoint 站点 ID", err)
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.Response{}
|
||||
}
|
||||
|
||||
func querySharePointSiteID(ctx context.Context, policy *model.Policy) error {
|
||||
client, err := onedrive.NewClient(policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id)
|
||||
if err := client.Policy.SaveAndClearCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -207,7 +207,15 @@ func (service *OneDriveCallback) PreProcess(c *gin.Context) serializer.Response
|
||||
|
||||
// 验证与回调会话中是否一致
|
||||
actualPath := strings.TrimPrefix(callbackSession.SavePath, "/")
|
||||
if callbackSession.Size != info.Size || info.GetSourcePath() != actualPath {
|
||||
isSizeCheckFailed := callbackSession.Size != info.Size
|
||||
|
||||
// SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 10 KB 宽容
|
||||
// See: https://github.com/OneDrive/onedrive-api-docs/issues/935
|
||||
if strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") && isSizeCheckFailed && (info.Size > callbackSession.Size) && (info.Size-callbackSession.Size <= 10240) {
|
||||
isSizeCheckFailed = false
|
||||
}
|
||||
|
||||
if isSizeCheckFailed || info.GetSourcePath() != actualPath {
|
||||
fs.Handler.(onedrive.Driver).Client.Delete(context.Background(), []string{info.GetSourcePath()})
|
||||
return serializer.Err(serializer.CodeUploadFailed, "文件信息不一致", err)
|
||||
}
|
||||
|
||||
@@ -78,6 +78,7 @@ func (service *SingleFileService) Create(c *gin.Context) serializer.Response {
|
||||
// 上下文
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx = context.WithValue(ctx, fsctx.DisableOverwrite, true)
|
||||
|
||||
// 给文件系统分配钩子
|
||||
fs.Use("BeforeUpload", filesystem.HookValidateFile)
|
||||
@@ -183,6 +184,33 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
|
||||
}
|
||||
}
|
||||
|
||||
// Source 重定向到文件的有效原始链接
|
||||
func (service *FileAnonymousGetService) Source(ctx context.Context, c *gin.Context) serializer.Response {
|
||||
fs, err := filesystem.NewAnonymousFileSystem()
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeGroupNotAllowed, err.Error(), err)
|
||||
}
|
||||
defer fs.Recycle()
|
||||
|
||||
// 查找文件
|
||||
err = fs.SetTargetFileByIDs([]uint{service.ID})
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||
}
|
||||
|
||||
// 获取文件流
|
||||
res, err := fs.SignURL(ctx, &fs.FileTarget[0],
|
||||
int64(model.GetIntSetting("preview_timeout", 60)), false)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||
}
|
||||
|
||||
return serializer.Response{
|
||||
Code: -302,
|
||||
Data: res,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDocPreviewSession 创建DOC文件预览会话,返回预览地址
|
||||
func (service *FileIDService) CreateDocPreviewSession(ctx context.Context, c *gin.Context) serializer.Response {
|
||||
// 创建文件系统
|
||||
|
||||
@@ -60,6 +60,13 @@ type ItemDecompressService struct {
|
||||
Dst string `json:"dst" binding:"required,min=1,max=65535"`
|
||||
}
|
||||
|
||||
// ItemPropertyService 获取对象属性服务
|
||||
type ItemPropertyService struct {
|
||||
ID string `binding:"required"`
|
||||
TraceRoot bool `form:"trace_root"`
|
||||
IsFolder bool `form:"is_folder"`
|
||||
}
|
||||
|
||||
// Raw 批量解码HashID,获取原始ID
|
||||
func (service *ItemIDService) Raw() *ItemService {
|
||||
if service.Source != nil {
|
||||
@@ -353,3 +360,100 @@ func (service *ItemRenameService) Rename(ctx context.Context, c *gin.Context) se
|
||||
Code: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// GetProperty 获取对象的属性
|
||||
func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Context) serializer.Response {
|
||||
userCtx, _ := c.Get("user")
|
||||
user := userCtx.(*model.User)
|
||||
|
||||
var props serializer.ObjectProps
|
||||
props.QueryDate = time.Now()
|
||||
|
||||
// 如果是文件对象
|
||||
if !service.IsFolder {
|
||||
res, err := hashid.DecodeHashID(service.ID, hashid.FileID)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotFound, "对象不存在", err)
|
||||
}
|
||||
|
||||
file, err := model.GetFilesByIDs([]uint{res}, user.ID)
|
||||
if err != nil {
|
||||
return serializer.DBErr("找不到文件", err)
|
||||
}
|
||||
|
||||
props.CreatedAt = file[0].CreatedAt
|
||||
props.UpdatedAt = file[0].UpdatedAt
|
||||
props.Policy = file[0].GetPolicy().Name
|
||||
props.Size = file[0].Size
|
||||
|
||||
// 查找父目录
|
||||
if service.TraceRoot {
|
||||
parent, err := model.GetFoldersByIDs([]uint{file[0].FolderID}, user.ID)
|
||||
if err != nil {
|
||||
return serializer.DBErr("找不到父目录", err)
|
||||
}
|
||||
|
||||
if err := parent[0].TraceRoot(); err != nil {
|
||||
return serializer.DBErr("无法溯源父目录", err)
|
||||
}
|
||||
|
||||
props.Path = path.Join(parent[0].Position, parent[0].Name)
|
||||
}
|
||||
} else {
|
||||
res, err := hashid.DecodeHashID(service.ID, hashid.FolderID)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotFound, "对象不存在", err)
|
||||
}
|
||||
|
||||
// 如果对象是目录, 先尝试返回缓存结果
|
||||
if cacheRes, ok := cache.Get(fmt.Sprintf("folder_props_%d", res)); ok {
|
||||
return serializer.Response{Data: cacheRes.(serializer.ObjectProps)}
|
||||
}
|
||||
|
||||
folder, err := model.GetFoldersByIDs([]uint{res}, user.ID)
|
||||
if err != nil {
|
||||
return serializer.DBErr("找不到目录", err)
|
||||
}
|
||||
|
||||
props.CreatedAt = folder[0].CreatedAt
|
||||
props.UpdatedAt = folder[0].UpdatedAt
|
||||
|
||||
// 统计子目录
|
||||
childFolders, err := model.GetRecursiveChildFolder([]uint{folder[0].ID},
|
||||
user.ID, true)
|
||||
if err != nil {
|
||||
return serializer.DBErr("无法列取子目录", err)
|
||||
}
|
||||
props.ChildFolderNum = len(childFolders) - 1
|
||||
|
||||
// 统计子文件
|
||||
files, err := model.GetChildFilesOfFolders(&childFolders)
|
||||
if err != nil {
|
||||
return serializer.DBErr("无法列取子文件", err)
|
||||
}
|
||||
|
||||
// 统计子文件个数和大小
|
||||
props.ChildFileNum = len(files)
|
||||
for i := 0; i < len(files); i++ {
|
||||
props.Size += files[i].Size
|
||||
}
|
||||
|
||||
// 查找父目录
|
||||
if service.TraceRoot {
|
||||
if err := folder[0].TraceRoot(); err != nil {
|
||||
return serializer.DBErr("无法溯源父目录", err)
|
||||
}
|
||||
|
||||
props.Path = folder[0].Position
|
||||
}
|
||||
|
||||
// 如果列取对象是目录,则缓存结果
|
||||
cache.Set(fmt.Sprintf("folder_props_%d", res), props,
|
||||
model.GetIntSetting("folder_props_timeout", 300))
|
||||
}
|
||||
|
||||
return serializer.Response{
|
||||
Code: 0,
|
||||
Data: props,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,33 +2,27 @@ package user
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/recaptcha"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// UserLoginService 管理用户登录的服务
|
||||
type UserLoginService struct {
|
||||
//TODO 细致调整验证规则
|
||||
UserName string `form:"userName" json:"userName" binding:"required,email"`
|
||||
Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"`
|
||||
CaptchaCode string `form:"captchaCode" json:"captchaCode"`
|
||||
UserName string `form:"userName" json:"userName" binding:"required,email"`
|
||||
Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"`
|
||||
}
|
||||
|
||||
// UserResetEmailService 发送密码重设邮件服务
|
||||
type UserResetEmailService struct {
|
||||
UserName string `form:"userName" json:"userName" binding:"required,email"`
|
||||
CaptchaCode string `form:"captchaCode" json:"captchaCode"`
|
||||
UserName string `form:"userName" json:"userName" binding:"required,email"`
|
||||
}
|
||||
|
||||
// UserResetService 密码重设服务
|
||||
@@ -69,31 +63,15 @@ func (service *UserResetService) Reset(c *gin.Context) serializer.Response {
|
||||
|
||||
// Reset 发送密码重设邮件
|
||||
func (service *UserResetEmailService) Reset(c *gin.Context) serializer.Response {
|
||||
// 检查验证码
|
||||
isCaptchaRequired := model.IsTrueVal(model.GetSettingByName("forget_captcha"))
|
||||
useRecaptcha := model.IsTrueVal(model.GetSettingByName("captcha_IsUseReCaptcha"))
|
||||
recaptchaSecret := model.GetSettingByName("captcha_ReCaptchaSecret")
|
||||
if isCaptchaRequired && !useRecaptcha {
|
||||
captchaID := util.GetSession(c, "captchaID")
|
||||
util.DeleteSession(c, "captchaID")
|
||||
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
|
||||
return serializer.ParamErr("验证码错误", nil)
|
||||
}
|
||||
} else if isCaptchaRequired && useRecaptcha {
|
||||
captcha, err := recaptcha.NewReCAPTCHA(recaptchaSecret, recaptcha.V2, 10*time.Second)
|
||||
if err != nil {
|
||||
util.Log().Error(err.Error())
|
||||
}
|
||||
err = captcha.Verify(service.CaptchaCode)
|
||||
if err != nil {
|
||||
util.Log().Error(err.Error())
|
||||
return serializer.ParamErr("验证失败,请刷新网页后再次验证", nil)
|
||||
}
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
if user, err := model.GetUserByEmail(service.UserName); err == nil {
|
||||
|
||||
if user.Status == model.Baned || user.Status == model.OveruseBaned {
|
||||
return serializer.Err(403, "该账号已被封禁", nil)
|
||||
}
|
||||
if user.Status == model.NotActivicated {
|
||||
return serializer.Err(403, "该账号未激活", nil)
|
||||
}
|
||||
// 创建密码重设会话
|
||||
secret := util.RandStringRunes(32)
|
||||
cache.Set(fmt.Sprintf("user_reset_%d", user.ID), secret, 3600)
|
||||
@@ -145,30 +123,7 @@ func (service *Enable2FA) Login(c *gin.Context) serializer.Response {
|
||||
|
||||
// Login 用户登录函数
|
||||
func (service *UserLoginService) Login(c *gin.Context) serializer.Response {
|
||||
isCaptchaRequired := model.GetSettingByName("login_captcha")
|
||||
useRecaptcha := model.GetSettingByName("captcha_IsUseReCaptcha")
|
||||
recaptchaSecret := model.GetSettingByName("captcha_ReCaptchaSecret")
|
||||
expectedUser, err := model.GetUserByEmail(service.UserName)
|
||||
|
||||
if (model.IsTrueVal(isCaptchaRequired)) && !(model.IsTrueVal(useRecaptcha)) {
|
||||
// TODO 验证码校验
|
||||
captchaID := util.GetSession(c, "captchaID")
|
||||
util.DeleteSession(c, "captchaID")
|
||||
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
|
||||
return serializer.ParamErr("验证码错误", nil)
|
||||
}
|
||||
} else if (model.IsTrueVal(isCaptchaRequired)) && (model.IsTrueVal(useRecaptcha)) {
|
||||
captcha, err := recaptcha.NewReCAPTCHA(recaptchaSecret, recaptcha.V2, 10*time.Second)
|
||||
if err != nil {
|
||||
util.Log().Error(err.Error())
|
||||
}
|
||||
err = captcha.Verify(service.CaptchaCode)
|
||||
if err != nil {
|
||||
util.Log().Error(err.Error())
|
||||
return serializer.ParamErr("验证失败,请刷新网页后再次验证", nil)
|
||||
}
|
||||
}
|
||||
|
||||
// 一系列校验
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeCredentialInvalid, "用户邮箱或密码错误", err)
|
||||
|
||||
@@ -1,54 +1,27 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/recaptcha"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UserRegisterService 管理用户注册的服务
|
||||
type UserRegisterService struct {
|
||||
//TODO 细致调整验证规则
|
||||
UserName string `form:"userName" json:"userName" binding:"required,email"`
|
||||
Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"`
|
||||
CaptchaCode string `form:"captchaCode" json:"captchaCode"`
|
||||
UserName string `form:"userName" json:"userName" binding:"required,email"`
|
||||
Password string `form:"Password" json:"Password" binding:"required,min=4,max=64"`
|
||||
}
|
||||
|
||||
// Register 新用户注册
|
||||
func (service *UserRegisterService) Register(c *gin.Context) serializer.Response {
|
||||
// 相关设定
|
||||
options := model.GetSettingByNames("email_active", "reg_captcha")
|
||||
// 检查验证码
|
||||
isCaptchaRequired := model.IsTrueVal(options["reg_captcha"])
|
||||
useRecaptcha := model.IsTrueVal(model.GetSettingByName("captcha_IsUseReCaptcha"))
|
||||
recaptchaSecret := model.GetSettingByName("captcha_ReCaptchaSecret")
|
||||
if isCaptchaRequired && !useRecaptcha {
|
||||
captchaID := util.GetSession(c, "captchaID")
|
||||
util.DeleteSession(c, "captchaID")
|
||||
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
|
||||
return serializer.ParamErr("验证码错误", nil)
|
||||
}
|
||||
} else if isCaptchaRequired && useRecaptcha {
|
||||
captcha, err := recaptcha.NewReCAPTCHA(recaptchaSecret, recaptcha.V2, 10*time.Second)
|
||||
if err != nil {
|
||||
util.Log().Error(err.Error())
|
||||
}
|
||||
err = captcha.Verify(service.CaptchaCode)
|
||||
if err != nil {
|
||||
util.Log().Error(err.Error())
|
||||
return serializer.ParamErr("验证失败,请刷新网页后再次验证", nil)
|
||||
}
|
||||
}
|
||||
options := model.GetSettingByNames("email_active")
|
||||
|
||||
// 相关设定
|
||||
isEmailRequired := model.IsTrueVal(options["email_active"])
|
||||
@@ -64,10 +37,17 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response
|
||||
user.Status = model.NotActivicated
|
||||
}
|
||||
user.GroupID = uint(defaultGroup)
|
||||
|
||||
userNotActivated := false
|
||||
// 创建用户
|
||||
if err := model.DB.Create(&user).Error; err != nil {
|
||||
return serializer.DBErr("此邮箱已被使用", err)
|
||||
//检查已存在使用者是否尚未激活
|
||||
expectedUser, err := model.GetUserByEmail(service.UserName)
|
||||
if expectedUser.Status == model.NotActivicated {
|
||||
userNotActivated = true
|
||||
user = expectedUser
|
||||
} else {
|
||||
return serializer.DBErr("此邮箱已被使用", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 发送激活邮件
|
||||
@@ -100,8 +80,12 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response
|
||||
if err := email.Send(user.Email, title, body); err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法发送激活邮件", err)
|
||||
}
|
||||
|
||||
return serializer.Response{Code: 203}
|
||||
if userNotActivated == true {
|
||||
//原本在上面要抛出的DBErr,放来这边抛出
|
||||
return serializer.DBErr("用户未激活,已重新发送激活邮件", nil)
|
||||
} else {
|
||||
return serializer.Response{Code: 203}
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.Response{}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
@@ -200,8 +201,8 @@ func (service *AvatarService) Get(c *gin.Context) serializer.Response {
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法解析 Gravatar 服务器地址", err)
|
||||
}
|
||||
|
||||
has := md5.Sum([]byte(user.Email))
|
||||
email_lowered := strings.ToLower(user.Email)
|
||||
has := md5.Sum([]byte(email_lowered))
|
||||
avatar, _ := url.Parse(fmt.Sprintf("/avatar/%x?d=mm&s=%s", has, sizes[service.Size]))
|
||||
|
||||
return serializer.Response{
|
||||
|
||||
Reference in New Issue
Block a user