Compare commits

...

34 Commits
3.3.2 ... 3.4.0

Author SHA1 Message Date
HFO4
3948ee7f3a Fix: use X-Cr- as custom header prefix 2021-11-23 21:22:23 +08:00
HFO4
865a801fa8 Update version number 2021-11-22 21:08:35 +08:00
HFO4
05941616df Fix: node should not be refreshed when editing node with status=inactive 2021-11-22 20:48:16 +08:00
HFO4
51b1e5b854 Merge remote-tracking branch 'origin/master' 2021-11-22 20:38:21 +08:00
HFO4
4dbe867020 Fix: failed unit test due to import cycle 2021-11-22 20:38:03 +08:00
WeidiDeng
8c8ad3e149 Fix: WebDAV cannot move and rename at the same time (#1056) 2021-11-22 20:29:45 +08:00
HFO4
fce38209bc Merge remote-tracking branch 'origin/master' 2021-11-22 20:27:07 +08:00
HFO4
700e13384e Fix: using url escape instead of unescape in remote handler (#1051) 2021-11-22 20:23:34 +08:00
HFO4
7fd984f95d Feat: support custom office preview service (Fix #1050) 2021-11-22 20:16:24 +08:00
HFO4
9fc08292a0 Feat: migration DB support custom upgrade scripts 2021-11-22 19:53:42 +08:00
milkice
8c5445a26d Fix problems in Dockerfile (#1059)
Fix js heap out of memory
Fix can't find cloudreve for cloudreve has been renamed to Cloudreve
2021-11-20 17:43:31 +08:00
HFO4
96b84bb5e5 Test: tasks pkg 2021-11-20 17:14:45 +08:00
HFO4
9056ef9171 Test: new changes in 3.4.0 2021-11-20 16:59:29 +08:00
HFO4
532bff820a Test: new modifications in filesystem pkg 2021-11-16 20:54:21 +08:00
HFO4
fcd9eddc54 Test: pkg/cluster 2021-11-16 20:14:27 +08:00
HFO4
6c9967b120 Test: cluster/node.go and controller.go 2021-11-15 20:30:25 +08:00
HFO4
416f4c1dd2 Test: balancer / auth / controller in pkg 2021-11-11 20:56:16 +08:00
HFO4
f0089045d7 Test: aria2 task monitor 100% cover 2021-11-11 19:49:02 +08:00
WeidiDeng
4b88eacb6a Remove unnecessary import "C" (#1048)
There are no C binding in this file. And for users to compile themselves, this line will cause compilation to fail for those who don't need sqlite support.

移除不必要的c binding,使用CGO=0时,这行会导致编译失败。禁用CGO会导致sqlite无法使用,但可以方便编译(不需要安装gcc,跨平台编译方便),这点可以在文档中说明。
2021-11-11 17:45:51 +08:00
kikoqiu
54ed7e43ca Feat: improve thumbnails proformance and GC for local policy (#1044)
* thumb generating improvement

Replace "github.com/nfnt/resize" with "golang.org/x/image/draw". Add thumb task queue to avoid oom when batch thumb operation

* thumb improvement

* Add some tests for thumbnail generation
2021-11-11 17:45:22 +08:00
HFO4
4d7b8685b9 Test: aria2 task monitor
Fix: tmp file not deleted after transfer task failed to create
2021-11-09 20:53:42 +08:00
HFO4
eeee43d569 Test: newly added sb models 2021-11-09 19:29:56 +08:00
HFO4
3064ed60f3 Test: new database models and middlewares 2021-11-08 20:49:07 +08:00
HFO4
e41ec9defa Refactor: move slave pkg inside of cluster
Test: middleware for node communication
2021-11-08 19:54:26 +08:00
HFO4
eaa0f6be91 Update version number 2021-11-07 10:14:30 +08:00
HFO4
5db476634a Fix: deadlock and sync issue in node pool 2021-11-03 21:27:53 +08:00
HFO4
1f06ee3af6 Fix: node cannot be reloaded when db model changes 2021-11-01 19:23:19 +08:00
HFO4
22bbfe7da1 Merge remote-tracking branch 'origin/master' 2021-10-31 09:50:55 +08:00
HFO4
f1dc4c4758 Chore: update ubuntu image version 2021-10-31 09:50:07 +08:00
HFO4
5f861b963a Update submodule version 2021-10-31 09:48:31 +08:00
AaronLiu
056de22edb Feat: aria2 download and transfer in slave node (#1040)
* Feat: retrieve nodes from data table

* Feat: master node ping slave node in REST API

* Feat: master send scheduled ping request

* Feat: inactive nodes recover loop

* Modify: remove database operations from aria2 RPC caller implementation

* Feat: init aria2 client in master node

* Feat: Round Robin load balancer

* Feat: create and monitor aria2 task in master node

* Feat: salve receive and handle heartbeat

* Fix: Node ID will be 0 in download record generated in older version

* Feat: sign request headers with all `X-` prefix

* Feat: API call to slave node will carry meta data in headers

* Feat: call slave aria2 rpc method from master

* Feat: get slave aria2 task status
Feat: encode slave response data using gob

* Feat: aria2 callback to master node / cancel or select task to slave node

* Fix: use dummy aria2 client when caller initialize failed in master node

* Feat: slave aria2 status event callback / salve RPC auth

* Feat: prototype for slave driven filesystem

* Feat: retry for init aria2 client in master node

* Feat: init request client with global options

* Feat: slave receive async task from master

* Fix: competition write in request header

* Refactor: dependency initialize order

* Feat: generic message queue implementation

* Feat: message queue implementation

* Feat: master waiting slave transfer result

* Feat: slave transfer file in stateless policy

* Feat: slave transfer file in slave policy

* Feat: slave transfer file in local policy

* Feat: slave transfer file in OneDrive policy

* Fix: failed to initialize update checker http client

* Feat: list slave nodes for dashboard

* Feat: test aria2 rpc connection in slave

* Feat: add and save node

* Feat: add and delete node in node pool

* Fix: temp file cannot be removed when aria2 task fails

* Fix: delete node in admin panel

* Feat: edit node and get node info

* Modify: delete unused settings
2021-10-31 09:41:56 +08:00
想出网名啦
a3b4a22dbc bug fix: can't connect to postgres database (#992)
* bug fix: can't connect to postgres database

* remove useless arg

* remove vscode setting
2021-10-29 20:30:26 +08:00
WeidiDeng
9ff1b47646 fix webdav prop get (#1023)
修复了displayname为空,potplayer可以正常使用webdav功能
2021-09-27 22:28:36 +08:00
AaronLiu
65c4367689 Revert "delete $name policy (#831)" (#961)
This reverts commit e6959a5026.
2021-07-30 11:22:18 +08:00
141 changed files with 6595 additions and 1505 deletions

View File

@@ -7,7 +7,7 @@ on:
jobs:
build:
name: Build
runs-on: ubuntu-16.04
runs-on: ubuntu-18.04
steps:
- name: Set up Go 1.13

View File

@@ -11,7 +11,7 @@ jobs:
test:
name: Test
runs-on: ubuntu-16.04
runs-on: ubuntu-18.04
steps:
- name: Set up Go 1.13

1
.gitignore vendored
View File

@@ -27,3 +27,4 @@ version.lock
*.ini
conf/conf.ini
/statik/
.vscode/

View File

@@ -5,6 +5,9 @@ COPY ./assets /assets
WORKDIR /assets
# If encountered problems like JavaScript heap out of memory, please uncomment the following options
ENV NODE_OPTIONS --max_old_space_size=4096
# yarn repo connection is unstable, adjust the network timeout to 10 min.
RUN set -ex \
&& yarn install --network-timeout 600000 \
@@ -42,7 +45,7 @@ ARG TZ="Asia/Shanghai"
ENV TZ ${TZ}
COPY --from=be-builder /go/bin/cloudreve /cloudreve/cloudreve
COPY --from=be-builder /go/bin/Cloudreve /cloudreve/cloudreve
RUN apk upgrade \
&& apk add bash tzdata \

2
assets

Submodule assets updated: 59890e6b22...691e82868d

View File

@@ -34,7 +34,7 @@ type GitHubRelease struct {
// CheckUpdate 检查更新
func CheckUpdate() {
client := request.HTTPClient{}
client := request.NewClient()
res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse()
if err != nil {
util.Log().Warning("更新检查失败, %s", err)

View File

@@ -2,12 +2,15 @@ package bootstrap
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/models/scripts"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/gin-gonic/gin"
)
@@ -20,14 +23,92 @@ func Init(path string) {
if !conf.SystemConfig.Debug {
gin.SetMode(gin.ReleaseMode)
}
cache.Init()
if conf.SystemConfig.Mode == "master" {
model.Init()
task.Init()
aria2.Init(false)
email.Init()
crontab.Init()
InitStatic()
dependencies := []struct {
mode string
factory func()
}{
{
"both",
func() {
scripts.Init()
},
},
{
"both",
func() {
cache.Init()
},
},
{
"master",
func() {
model.Init()
},
},
{
"both",
func() {
task.Init()
},
},
{
"master",
func() {
cluster.Init()
},
},
{
"master",
func() {
aria2.Init(false, cluster.Default, mq.GlobalMQ)
},
},
{
"master",
func() {
email.Init()
},
},
{
"master",
func() {
crontab.Init()
},
},
{
"master",
func() {
InitStatic()
},
},
{
"slave",
func() {
cluster.InitController()
},
},
{
"both",
func() {
auth.Init()
},
},
}
auth.Init()
for _, dependency := range dependencies {
switch dependency.mode {
case "master":
if conf.SystemConfig.Mode == "master" {
dependency.factory()
}
case "slave":
if conf.SystemConfig.Mode == "slave" {
dependency.factory()
}
default:
dependency.factory()
}
}
}

View File

@@ -2,14 +2,14 @@ package bootstrap
import (
"context"
"github.com/cloudreve/Cloudreve/v3/models/scripts"
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
func RunScript(name string) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := scripts.RunDBScript(name, ctx); err != nil {
if err := invoker.RunDBScript(name, ctx); err != nil {
util.Log().Error("数据库脚本执行失败: %s", err)
return
}

4
go.mod
View File

@@ -16,6 +16,7 @@ require (
github.com/gin-gonic/gin v1.5.0
github.com/go-ini/ini v1.50.0
github.com/go-mail/mail v2.3.1+incompatible
github.com/gofrs/uuid v4.0.0+incompatible
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/go-querystring v1.0.0
github.com/gorilla/websocket v1.4.1
@@ -37,7 +38,8 @@ require (
github.com/tencentcloud/tencentcloud-sdk-go v3.0.125+incompatible
github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac
github.com/upyun/go-sdk v2.1.0+incompatible
golang.org/x/text v0.3.2
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/text v0.3.6
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/go-playground/validator.v9 v9.29.1
gopkg.in/ini.v1 v1.51.0 // indirect

7
go.sum
View File

@@ -12,6 +12,7 @@ github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 h1:w1UutsfOrms1J05zt7I
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412/go.mod h1:WPjqKcmVOxf0XSf3YxCJs6N6AOSrOx3obionmG7T0y0=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/aliyun/aliyun-oss-go-sdk v2.0.0/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
github.com/aliyun/aliyun-oss-go-sdk v2.0.5+incompatible h1:A3oZlWPD/Poa19FvNbw+Zu4yKAurDBTjlRDilYGBiS4=
github.com/aliyun/aliyun-oss-go-sdk v2.0.5+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
@@ -77,6 +78,8 @@ github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
@@ -246,6 +249,8 @@ golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -286,6 +291,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c h1:fqgJT0MGcGpPgpWU7VRdRjuArfcOvC4AoJmILihzhDg=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -22,16 +22,14 @@ import (
)
// SignRequired 验证请求签名
func SignRequired() gin.HandlerFunc {
func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
return func(c *gin.Context) {
var err error
switch c.Request.Method {
case "PUT", "POST":
err = auth.CheckRequest(auth.General, c.Request)
// TODO 生产环境去掉下一行
//err = nil
case "PUT", "POST", "PATCH":
err = auth.CheckRequest(authInstance, c.Request)
default:
err = auth.CheckURI(auth.General, c.Request.URL)
err = auth.CheckURI(authInstance, c.Request.URL)
}
if err != nil {
@@ -39,6 +37,7 @@ func SignRequired() gin.HandlerFunc {
c.Abort()
return
}
c.Next()
}
}

View File

@@ -87,19 +87,30 @@ func TestAuthRequired(t *testing.T) {
func TestSignRequired(t *testing.T) {
asserts := assert.New(t)
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
SignRequiredFunc := SignRequired()
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
SignRequiredFunc := SignRequired(authInstance)
// 鉴权失败
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
// Sign verify success
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
c.Request = auth.SignRequest(authInstance, c.Request, 0)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.False(c.IsAborted())
}
func TestWebDAVAuth(t *testing.T) {
@@ -781,8 +792,6 @@ func TestS3CallbackAuth(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
@@ -790,5 +799,6 @@ func TestS3CallbackAuth(t *testing.T) {
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
AuthFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
}

61
middleware/cluster.go Normal file
View File

@@ -0,0 +1,61 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
"strconv"
)
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
func MasterMetadata() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("MasterSiteID", c.GetHeader("X-Cr-Site-Id"))
c.Set("MasterSiteURL", c.GetHeader("X-Cr-Site-Url"))
c.Set("MasterVersion", c.GetHeader("X-Cr-Cloudreve-Version"))
c.Next()
}
}
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
return func(c *gin.Context) {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := clusterController.GetAria2Instance(siteID.(string))
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
c.Abort()
return
}
c.Set("MasterAria2Instance", caller)
c.Next()
return
}
c.JSON(200, serializer.ParamErr("未知的主机节点ID", nil))
c.Abort()
}
}
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader("X-Cr-Node-Id"), 10, 64)
if err != nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}
slaveNode := nodePool.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}
SignRequired(slaveNode.MasterAuthInstance())(c)
}
}

120
middleware/cluster_test.go Normal file
View File

@@ -0,0 +1,120 @@
package middleware
import (
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestMasterMetadata(t *testing.T) {
a := assert.New(t)
masterMetaDataFunc := MasterMetadata()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header = map[string][]string{
"X-Cr-Site-Id": {"expectedSiteID"},
"X-Cr-Site-Url": {"expectedSiteURL"},
"X-Cr-Cloudreve-Version": {"expectedMasterVersion"},
}
masterMetaDataFunc(c)
siteID, _ := c.Get("MasterSiteID")
siteURL, _ := c.Get("MasterSiteURL")
siteVersion, _ := c.Get("MasterVersion")
a.Equal("expectedSiteID", siteID.(string))
a.Equal("expectedSiteURL", siteURL.(string))
a.Equal("expectedMasterVersion", siteVersion.(string))
}
func TestSlaveRPCSignRequired(t *testing.T) {
a := assert.New(t)
np := &cluster.NodePool{}
np.Init()
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
rec := httptest.NewRecorder()
// id parse failed
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "unknown")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// node id not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "38")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// success
{
authInstance := auth.HMACAuth{SecretKey: []byte("")}
np.Add(&model.Node{Model: gorm.Model{
ID: 38,
}})
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("POST", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "38")
c.Request = auth.SignRequest(authInstance, c.Request, 0)
slaveRPCSignRequiredFunc(c)
a.False(c.IsAborted())
}
}
func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t)
// MasterSiteID not set
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
useSlaveAria2InstanceFunc(c)
a.True(c.IsAborted())
}
// Cannot get aria2 instances
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("MasterSiteID", "expectedSiteID")
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error"))
useSlaveAria2InstanceFunc(c)
a.True(c.IsAborted())
testController.AssertExpectations(t)
}
// Success
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("MasterSiteID", "expectedSiteID")
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil)
useSlaveAria2InstanceFunc(c)
a.False(c.IsAborted())
res, _ := c.Get("MasterAria2Instance")
a.NotNil(res)
testController.AssertExpectations(t)
}
}

View File

@@ -24,6 +24,7 @@ type Download struct {
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
UserID uint // 发起者UID
TaskID uint // 对应的转存任务ID
NodeID uint // 处理任务的节点ID
// 关联模型
User *User `gorm:"PRELOAD:false,association_autoupdate:false"`
@@ -114,3 +115,13 @@ func (task *Download) GetOwner() *User {
func (download *Download) Delete() error {
return DB.Model(download).Delete(download).Error
}
// GetNodeID 返回任务所属节点ID
func (task *Download) GetNodeID() uint {
// 兼容3.4版本之前生成的下载记录
if task.NodeID == 0 {
return 1
}
return task.NodeID
}

View File

@@ -177,3 +177,14 @@ func TestDownload_Delete(t *testing.T) {
}
}
func TestDownload_GetNodeID(t *testing.T) {
a := assert.New(t)
record := Download{}
// compatible with 3.4
a.EqualValues(1, record.GetNodeID())
record.NodeID = 5
a.EqualValues(5, record.GetNodeID())
}

View File

@@ -35,7 +35,14 @@ func Init() {
case "UNSET", "sqlite", "sqlite3":
// 未指定数据库或者明确指定为 sqlite 时,使用 SQLite3 数据库
db, err = gorm.Open("sqlite3", util.RelativePath(conf.DatabaseConfig.DBFile))
case "mysql", "postgres", "mssql":
case "postgres":
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
conf.DatabaseConfig.Host,
conf.DatabaseConfig.User,
conf.DatabaseConfig.Password,
conf.DatabaseConfig.Name,
conf.DatabaseConfig.Port))
case "mysql", "mssql":
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("%s:%s@(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
conf.DatabaseConfig.User,
conf.DatabaseConfig.Password,

View File

@@ -1,11 +1,17 @@
package model
import (
"context"
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/fatih/color"
"github.com/gofrs/uuid"
"github.com/hashicorp/go-version"
"github.com/jinzhu/gorm"
"sort"
"strings"
)
// 是否需要迁移
@@ -34,8 +40,9 @@ func migration() {
if conf.DatabaseConfig.Type == "mysql" {
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
}
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
&Task{}, &Download{}, &Tag{}, &Webdav{})
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{})
// 创建初始存储策略
addDefaultPolicy()
@@ -46,9 +53,15 @@ func migration() {
// 创建初始管理员账户
addDefaultUser()
// 创建初始节点
addDefaultNode()
// 向设置数据表添加初始设置
addDefaultSettings()
// 执行数据库升级脚本
execUpgradeScripts()
util.Log().Info("数据库初始化结束")
}
@@ -73,6 +86,8 @@ func addDefaultPolicy() {
}
func addDefaultSettings() {
siteID, _ := uuid.NewV4()
defaultSettings := []Setting{
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
@@ -83,6 +98,7 @@ func addDefaultSettings() {
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
{Name: "siteTitle", Value: `平步云端`, Type: "basic"},
{Name: "siteScript", Value: ``, Type: "basic"},
{Name: "siteID", Value: siteID.String(), Type: "basic"},
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},
@@ -100,10 +116,13 @@ func addDefaultSettings() {
{Name: "upload_credential_timeout", Value: `1800`, Type: "timeout"},
{Name: "upload_session_timeout", Value: `86400`, Type: "timeout"},
{Name: "slave_api_timeout", Value: `60`, Type: "timeout"},
{Name: "slave_node_retry", Value: `3`, Type: "slave"},
{Name: "slave_ping_interval", Value: `60`, Type: "slave"},
{Name: "slave_recover_interval", Value: `120`, Type: "slave"},
{Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"},
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
{Name: "aria2_call_timeout", Value: `5`, Type: "timeout"},
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
@@ -131,11 +150,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
{Name: "aria2_token", Value: ``, Type: "aria2"},
{Name: "aria2_rpcurl", Value: ``, Type: "aria2"},
{Name: "aria2_temp_path", Value: ``, Type: "aria2"},
{Name: "aria2_options", Value: `{}`, Type: "aria2"},
{Name: "aria2_interval", Value: `60`, Type: "aria2"},
{Name: "max_worker_num", Value: `10`, Type: "task"},
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
@@ -175,6 +189,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "pwa_display", Value: "standalone", Type: "pwa"},
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
}
for _, value := range defaultSettings {
@@ -265,3 +280,36 @@ func addDefaultUser() {
util.Log().Info("初始管理员密码:" + c.Sprint(password))
}
}
func addDefaultNode() {
_, err := GetNodeByID(1)
if gorm.IsRecordNotFoundError(err) {
defaultAdminGroup := Node{
Name: "主机(本机)",
Status: NodeActive,
Type: MasterNodeType,
Aria2OptionsSerialized: Aria2Option{
Interval: 10,
Timeout: 10,
},
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
util.Log().Panic("无法创建初始节点记录, %s", err)
}
}
}
func execUpgradeScripts() {
s := invoker.ListPrefix("UpgradeTo")
versions := make([]*version.Version, len(s))
for i, raw := range s {
v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo"))
versions[i] = v
}
sort.Sort(version.Collection(versions))
for i := 0; i < len(versions); i++ {
invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background())
}
}

91
models/node.go Normal file
View File

@@ -0,0 +1,91 @@
package model
import (
"encoding/json"
"github.com/jinzhu/gorm"
)
// Node 从机节点信息模型
type Node struct {
gorm.Model
Status NodeStatus // 节点状态
Name string // 节点别名
Type ModelType // 节点状态
Server string // 服务器地址
SlaveKey string `gorm:"type:text"` // 主->从 通信密钥
MasterKey string `gorm:"type:text"` // 从->主 通信密钥
Aria2Enabled bool // 是否支持用作离线下载节点
Aria2Options string `gorm:"type:text"` // 离线下载配置
Rank int // 负载均衡权重
// 数据库忽略字段
Aria2OptionsSerialized Aria2Option `gorm:"-"`
}
// Aria2Option 非公有的Aria2配置属性
type Aria2Option struct {
// RPC 服务器地址
Server string `json:"server,omitempty"`
// RPC 密钥
Token string `json:"token,omitempty"`
// 临时下载目录
TempPath string `json:"temp_path,omitempty"`
// 附加下载配置
Options string `json:"options,omitempty"`
// 下载监控间隔
Interval int `json:"interval,omitempty"`
// RPC API 请求超时
Timeout int `json:"timeout,omitempty"`
}
type NodeStatus int
type ModelType int
const (
NodeActive NodeStatus = iota
NodeSuspend
)
const (
SlaveNodeType ModelType = iota
MasterNodeType
)
// GetNodeByID 用ID获取节点
func GetNodeByID(ID interface{}) (Node, error) {
var node Node
result := DB.First(&node, ID)
return node, result.Error
}
// GetNodesByStatus 根据给定状态获取节点
func GetNodesByStatus(status ...NodeStatus) ([]Node, error) {
var nodes []Node
result := DB.Where("status in (?)", status).Find(&nodes)
return nodes, result.Error
}
// AfterFind 找到节点后的钩子
func (node *Node) AfterFind() (err error) {
// 解析离线下载设置到 Aria2OptionsSerialized
if node.Aria2Options != "" {
err = json.Unmarshal([]byte(node.Aria2Options), &node.Aria2OptionsSerialized)
}
return err
}
// BeforeSave Save策略前的钩子
func (node *Node) BeforeSave() (err error) {
optionsValue, err := json.Marshal(&node.Aria2OptionsSerialized)
node.Aria2Options = string(optionsValue)
return err
}
// SetStatus 设置节点启用状态
func (node *Node) SetStatus(status NodeStatus) error {
node.Status = status
return DB.Model(node).Updates(map[string]interface{}{
"status": status,
}).Error
}

64
models/node_test.go Normal file
View File

@@ -0,0 +1,64 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestGetNodeByID(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetNodeByID(1)
a.NoError(err)
a.EqualValues(1, res.ID)
a.NoError(mock.ExpectationsWereMet())
}
func TestGetNodesByStatus(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive))
res, err := GetNodesByStatus(NodeActive)
a.NoError(err)
a.Len(res, 1)
a.EqualValues(NodeActive, res[0].Status)
a.NoError(mock.ExpectationsWereMet())
}
func TestNode_AfterFind(t *testing.T) {
a := assert.New(t)
node := &Node{}
// No aria2 options
{
a.NoError(node.AfterFind())
}
// with aria2 options
{
node.Aria2Options = `{"timeout":1}`
a.NoError(node.AfterFind())
a.Equal(1, node.Aria2OptionsSerialized.Timeout)
}
}
func TestNode_BeforeSave(t *testing.T) {
a := assert.New(t)
node := &Node{}
node.Aria2OptionsSerialized.Timeout = 1
a.NoError(node.BeforeSave())
a.Contains(node.Aria2Options, "1")
}
func TestNode_SetStatus(t *testing.T) {
a := assert.New(t)
node := &Node{}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(node.SetStatus(NodeActive))
a.Equal(NodeActive, node.Status)
a.NoError(mock.ExpectationsWereMet())
}

View File

@@ -37,6 +37,7 @@ type Policy struct {
// 数据库忽略字段
OptionsSerialized PolicyOption `gorm:"-"`
MasterID string `gorm:"-"`
}
// PolicyOption 非公有的存储策略属性
@@ -277,6 +278,13 @@ func (policy *Policy) SaveAndClearCache() error {
return err
}
// SaveAndClearCache 更新并清理缓存
func (policy *Policy) UpdateAccessKeyAndClearCache(s string) error {
err := DB.Model(policy).UpdateColumn("access_key", s).Error
policy.ClearCache()
return err
}
// ClearCache 清空policy缓存
func (policy *Policy) ClearCache() {
cache.Deletes([]string{strconv.FormatUint(uint64(policy.ID), 10)}, "policy_")

View File

@@ -327,3 +327,18 @@ func TestPolicy_IsThumbExist(t *testing.T) {
asserts.Equal(testCase.expect, policy.IsThumbExist(testCase.name))
}
}
func TestPolicy_UpdateAccessKeyAndClearCache(t *testing.T) {
a := assert.New(t)
cache.Set("policy_1331", Policy{}, 3600)
p := &Policy{}
p.ID = 1331
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs("ak", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(p.UpdateAccessKeyAndClearCache("ak"))
a.NoError(mock.ExpectationsWereMet())
_, ok := cache.Get("policy_1331")
a.False(ok)
}

9
models/scripts/init.go Normal file
View File

@@ -0,0 +1,9 @@
package scripts
import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
func Init() {
invoker.Register("ResetAdminPassword", ResetAdminPassword(0))
invoker.Register("CalibrateUserStorage", UserStorageCalibration(0))
invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0))
}

View File

@@ -1,25 +0,0 @@
package scripts
import (
"context"
"fmt"
)
type DBScript interface {
Run(ctx context.Context)
}
var availableScripts = make(map[string]DBScript)
func RunDBScript(name string, ctx context.Context) error {
if script, ok := availableScripts[name]; ok {
script.Run(ctx)
return nil
}
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
}
func register(name string, script DBScript) {
availableScripts[name] = script
}

View File

@@ -0,0 +1,38 @@
package invoker
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"strings"
)
type DBScript interface {
Run(ctx context.Context)
}
var availableScripts = make(map[string]DBScript)
func RunDBScript(name string, ctx context.Context) error {
if script, ok := availableScripts[name]; ok {
util.Log().Info("开始执行数据库脚本 [%s]", name)
script.Run(ctx)
return nil
}
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
}
func Register(name string, script DBScript) {
availableScripts[name] = script
}
func ListPrefix(prefix string) []string {
var scripts []string
for name := range availableScripts {
if strings.HasPrefix(name, prefix) {
scripts = append(scripts, name)
}
}
return scripts
}

View File

@@ -0,0 +1,39 @@
package invoker
import (
"context"
"github.com/stretchr/testify/assert"
"testing"
)
type TestScript int
func (script TestScript) Run(ctx context.Context) {
}
func TestRunDBScript(t *testing.T) {
asserts := assert.New(t)
Register("test", TestScript(0))
// 不存在
{
asserts.Error(RunDBScript("else", context.Background()))
}
// 存在
{
asserts.NoError(RunDBScript("test", context.Background()))
}
}
func TestListPrefix(t *testing.T) {
asserts := assert.New(t)
Register("U1", TestScript(0))
Register("U2", TestScript(0))
Register("U3", TestScript(0))
Register("P1", TestScript(0))
res := ListPrefix("U")
asserts.Len(res, 3)
}

View File

@@ -1,49 +0,0 @@
package scripts
import (
"context"
"database/sql"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
var mock sqlmock.Sqlmock
var mockDB *gorm.DB
type TestScript int
func (script TestScript) Run(ctx context.Context) {
}
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
mockDB = model.DB
defer db.Close()
m.Run()
}
func TestRunDBScript(t *testing.T) {
asserts := assert.New(t)
register("test", TestScript(0))
// 不存在
{
asserts.Error(RunDBScript("else", context.Background()))
}
// 存在
{
asserts.NoError(RunDBScript("test", context.Background()))
}
}

View File

@@ -9,10 +9,6 @@ import (
type ResetAdminPassword int
func init() {
register("ResetAdminPassword", ResetAdminPassword(0))
}
// Run 运行脚本从社区版升级至 Pro 版
func (script ResetAdminPassword) Run(ctx context.Context) {
// 查找用户

View File

@@ -8,10 +8,6 @@ import (
type UserStorageCalibration int
func init() {
register("CalibrateUserStorage", UserStorageCalibration(0))
}
type storageResult struct {
Total uint64
}

View File

@@ -2,11 +2,31 @@ package scripts
import (
"context"
"database/sql"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
var mock sqlmock.Sqlmock
var mockDB *gorm.DB
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
mockDB = model.DB
defer db.Close()
m.Run()
}
func TestUserStorageCalibration_Run(t *testing.T) {
asserts := assert.New(t)
script := UserStorageCalibration(0)

43
models/scripts/upgrade.go Normal file
View File

@@ -0,0 +1,43 @@
package scripts
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"strconv"
)
type UpgradeTo340 int
// Run upgrade from older version to 3.4.0
func (script UpgradeTo340) Run(ctx context.Context) {
// 取回老版本 aria2 设定
old := model.GetSettingByType([]string{"aria2"})
if len(old) == 0 {
return
}
// 写入到新版本的节点设定
n, err := model.GetNodeByID(1)
if err != nil {
util.Log().Error("找不到主机节点, %s", err)
}
n.Aria2Enabled = old["aria2_rpcurl"] != ""
n.Aria2OptionsSerialized.Options = old["aria2_options"]
n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"]
interval, err := strconv.Atoi(old["aria2_interval"])
if err != nil {
interval = 10
}
n.Aria2OptionsSerialized.Interval = interval
n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"]
n.Aria2OptionsSerialized.Token = old["aria2_token"]
if err := model.DB.Save(&n).Error; err != nil {
util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err)
} else {
model.DB.Where("type = ?", "aria2").Delete(model.Setting{})
util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式")
}
}

View File

@@ -0,0 +1,66 @@
package scripts
import (
"context"
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestUpgradeTo340_Run(t *testing.T) {
a := assert.New(t)
script := UpgradeTo340(0)
// skip
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}))
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// node not found
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}))
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// success
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
AddRow("aria2_interval", "expected_aria2_interval").
AddRow("aria2_temp_path", "expected_aria2_temp_path").
AddRow("aria2_token", "expected_aria2_token").
AddRow("aria2_options", "{}"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// failed
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
AddRow("aria2_interval", "expected_aria2_interval").
AddRow("aria2_temp_path", "expected_aria2_temp_path").
AddRow("aria2_token", "expected_aria2_token").
AddRow("aria2_options", "{}"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
}

View File

@@ -30,12 +30,16 @@ func GetSettingByName(name string) string {
if optionValue, ok := cache.Get(cacheKey); ok {
return optionValue.(string)
}
// 尝试数据库中查找
result := DB.Where("name = ?", name).First(&setting)
if result.Error == nil {
_ = cache.Set(cacheKey, setting.Value, -1)
return setting.Value
if DB != nil {
result := DB.Where("name = ?", name).First(&setting)
if result.Error == nil {
_ = cache.Set(cacheKey, setting.Value, -1)
return setting.Value
}
}
return ""
}

View File

@@ -188,18 +188,6 @@ func TestShare_CanBeDownloadBy(t *testing.T) {
asserts.Error(share.CanBeDownloadBy(user))
}
// 未登录,需要积分
{
user := &User{
Group: Group{
OptionsSerialized: GroupOption{
ShareDownload: true,
},
},
}
asserts.Error(share.CanBeDownloadBy(user))
}
// 成功
{
user := &User{

View File

@@ -55,9 +55,9 @@ func TestGetTagsByUID(t *testing.T) {
func TestGetTagsByID(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetTasksByID(1)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("tag"))
res, err := GetTagsByUID(1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues(1, res.ID)
asserts.EqualValues("tag", res[0].Name)
}

View File

@@ -91,3 +91,14 @@ func TestListTasks(t *testing.T) {
asserts.EqualValues(5, total)
asserts.Len(res, 1)
}
func TestGetTasksByStatus(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res := GetTasksByStatus(1, 2)
a.NoError(mock.ExpectationsWereMet())
a.Len(res, 1)
}

View File

@@ -35,7 +35,7 @@ type User struct {
Storage uint64
TwoFactor string
Avatar string
Options string `json:"-",gorm:"type:text"`
Options string `json:"-" gorm:"type:text"`
Authn string `gorm:"type:text"`
// 关联模型

View File

@@ -177,10 +177,10 @@ func TestNewUser(t *testing.T) {
func TestUser_AfterFind(t *testing.T) {
asserts := assert.New(t)
cache.Deletes([]string{"1"}, "policy_")
cache.Deletes([]string{"0"}, "policy_")
policyRows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "默认存储策略")
AddRow(144, "默认存储策略")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
newUser := NewUser()
@@ -240,11 +240,6 @@ func TestUser_GetRemainingCapacity(t *testing.T) {
newUser.Group.MaxStorage = 100
newUser.Storage = 200
asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
cache.Set("pack_size_0", uint64(10), 0)
newUser.Group.MaxStorage = 100
newUser.Storage = 101
asserts.Equal(uint64(9), newUser.GetRemainingCapacity())
}
func TestUser_DeductionCapacity(t *testing.T) {
@@ -280,10 +275,6 @@ func TestUser_DeductionCapacity(t *testing.T) {
asserts.Equal(false, newUser.IncreaseStorage(1))
asserts.Equal(uint64(100), newUser.Storage)
cache.Set("pack_size_1", uint64(1), 0)
asserts.Equal(true, newUser.IncreaseStorage(1))
asserts.Equal(uint64(101), newUser.Storage)
asserts.True(newUser.IncreaseStorage(0))
}

View File

@@ -55,6 +55,6 @@ func TestDeleteWebDAVAccountByID(t *testing.T) {
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.NoError(DeleteTagByID(1, 1))
DeleteWebDAVAccountByID(1, 1)
asserts.NoError(mock.ExpectationsWereMet())
}

View File

@@ -1,169 +1,67 @@
package aria2
import (
"encoding/json"
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"net/url"
"sync"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
)
// Instance 默认使用的Aria2处理实例
var Instance Aria2 = &DummyAria2{}
var Instance common.Aria2 = &common.DummyAria2{}
// LB 获取 Aria2 节点的负载均衡器
var LB balancer.Balancer
// Lock Instance的读写锁
var Lock sync.RWMutex
// EventNotifier 任务状态更新通知处理
var EventNotifier = &Notifier{}
// Aria2 离线下载处理接口
type Aria2 interface {
// CreateTask 创建新的任务
CreateTask(task *model.Download, options map[string]interface{}) error
// 返回状态信息
Status(task *model.Download) (rpc.StatusInfo, error)
// 取消任务
Cancel(task *model.Download) error
// 选择要下载的文件
Select(task *model.Download, files []int) error
}
const (
// URLTask 从URL添加的任务
URLTask = iota
// TorrentTask 种子任务
TorrentTask
)
const (
// Ready 准备就绪
Ready = iota
// Downloading 下载中
Downloading
// Paused 暂停中
Paused
// Error 出错
Error
// Complete 完成
Complete
// Canceled 取消/停止
Canceled
// Unknown 未知状态
Unknown
)
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
type DummyAria2 struct {
}
// CreateTask 创建新任务,此处直接返回未开启错误
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) error {
return ErrNotEnabled
}
// Status 返回未开启错误
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel 返回未开启错误
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Select 返回未开启错误
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
return ErrNotEnabled
// GetLoadBalancer 返回供Aria2使用的负载均衡
func GetLoadBalancer() balancer.Balancer {
Lock.RLock()
defer Lock.RUnlock()
return LB
}
// Init 初始化
func Init(isReload bool) {
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
Lock.Lock()
defer Lock.Unlock()
// 关闭上个初始连接
if previousClient, ok := Instance.(*RPCService); ok {
if previousClient.Caller != nil {
util.Log().Debug("关闭上个 aria2 连接")
previousClient.Caller.Close()
}
}
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
timeout := model.GetIntSetting("aria2_call_timeout", 5)
if options["aria2_rpcurl"] == "" {
Instance = &DummyAria2{}
return
}
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
client := &RPCService{}
// 解析RPC服务地址
server, err := url.Parse(options["aria2_rpcurl"])
if err != nil {
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
Instance = &DummyAria2{}
return
}
server.Path = "/jsonrpc"
// 加载自定义下载配置
var globalOptions map[string]interface{}
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
if err != nil {
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
Instance = &DummyAria2{}
return
}
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
Instance = &DummyAria2{}
return
}
Instance = client
LB = balancer.NewBalancer("RoundRobin")
Lock.Unlock()
if !isReload {
// 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading)
for i := 0; i < len(unfinished); i++ {
// 创建任务监控
NewMonitor(&unfinished[i])
monitor.NewMonitor(&unfinished[i], pool, mqClient)
}
}
}
// getStatus 将给定的状态字符串转换为状态标识数字
func getStatus(status string) int {
switch status {
case "complete":
return Complete
case "active":
return Downloading
case "waiting":
return Ready
case "paused":
return Paused
case "error":
return Error
case "removed":
return Canceled
default:
return Unknown
// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性
func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) {
// 解析RPC服务地址
rpcServer, err := url.Parse(server)
if err != nil {
return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err)
}
rpcServer.Path = "/jsonrpc"
caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil)
if err != nil {
return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err)
}
return caller.GetVersion()
}

View File

@@ -2,13 +2,15 @@ package aria2
import (
"database/sql"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
var mock sqlmock.Sqlmock
@@ -26,66 +28,39 @@ func TestMain(m *testing.M) {
m.Run()
}
func TestDummyAria2(t *testing.T) {
asserts := assert.New(t)
instance := DummyAria2{}
asserts.Error(instance.CreateTask(nil, nil))
_, err := instance.Status(nil)
asserts.Error(err)
asserts.Error(instance.Cancel(nil))
asserts.Error(instance.Select(nil, nil))
}
func TestInit(t *testing.T) {
MAX_RETRY = 0
asserts := assert.New(t)
cache.Set("setting_aria2_token", "1", 0)
cache.Set("setting_aria2_call_timeout", "5", 0)
cache.Set("setting_aria2_options", `[]`, 0)
a := assert.New(t)
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
mockQueue := mq.NewMQ()
// 未指定RPC地址跳过
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
Init(false, mockPool, mockQueue)
a.NoError(mock.ExpectationsWereMet())
mockPool.AssertExpectations(t)
}
func TestTestRPCConnection(t *testing.T) {
a := assert.New(t)
// url not legal
{
cache.Set("setting_aria2_rpcurl", "", 0)
Init(false)
asserts.IsType(&DummyAria2{}, Instance)
res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
a.Error(err)
a.Empty(res.Version)
}
// 无法解析服务器地址
// rpc failed
{
cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0)
Init(false)
asserts.IsType(&DummyAria2{}, Instance)
}
// 无法解析全局配置
{
Instance = &RPCService{}
cache.Set("setting_aria2_options", "?", 0)
cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0)
Init(false)
asserts.IsType(&DummyAria2{}, Instance)
}
// 连接失败
{
cache.Set("setting_aria2_options", "{}", 0)
cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0)
cache.Set("setting_aria2_call_timeout", "1", 0)
cache.Set("setting_aria2_interval", "100", 0)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
Init(false)
asserts.NoError(mock.ExpectationsWereMet())
asserts.IsType(&RPCService{}, Instance)
res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
a.Error(err)
a.Empty(res.Version)
}
}
func TestGetStatus(t *testing.T) {
asserts := assert.New(t)
asserts.Equal(4, getStatus("complete"))
asserts.Equal(1, getStatus("active"))
asserts.Equal(0, getStatus("waiting"))
asserts.Equal(2, getStatus("paused"))
asserts.Equal(3, getStatus("error"))
asserts.Equal(5, getStatus("removed"))
asserts.Equal(6, getStatus("?"))
func TestGetLoadBalancer(t *testing.T) {
a := assert.New(t)
a.NotPanics(func() {
GetLoadBalancer()
})
}

View File

@@ -1,123 +0,0 @@
package aria2
import (
"context"
"path/filepath"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// RPCService 通过RPC服务的Aria2任务管理器
type RPCService struct {
options *clientOptions
Caller rpc.Client
}
type clientOptions struct {
Options map[string]interface{} // 创建下载时额外添加的设置
}
// Init 初始化
func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
// 客户端已存在,则关闭先前连接
if client.Caller != nil {
client.Caller.Close()
}
client.options = &clientOptions{
Options: options,
}
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
EventNotifier)
client.Caller = caller
return err
}
// Status 查询下载状态
func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := client.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("无法获取离线下载状态,%s10秒钟后重试", err)
time.Sleep(time.Duration(10) * time.Second)
res, err = client.Caller.TellStatus(task.GID)
}
return res, err
}
// Cancel 取消下载
func (client *RPCService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := client.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
}
//// 删除临时文件
//util.Log().Debug("离线下载任务[%s]已取消1 分钟后删除临时文件", task.GID)
//go func(task *model.Download) {
// select {
// case <-time.After(time.Duration(60) * time.Second):
// err := os.RemoveAll(task.Parent)
// if err != nil {
// util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err)
// }
// }
//}(task)
return err
}
// Select 选取要下载的文件
func (client *RPCService) Select(task *model.Download, files []int) error {
var selected = make([]string, len(files))
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
// CreateTask 创建新任务
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) error {
// 生成存储路径
path := filepath.Join(
model.GetSettingByName("aria2_temp_path"),
"aria2",
strconv.FormatInt(time.Now().UnixNano(), 10),
)
// 创建下载任务
options := map[string]interface{}{
"dir": path,
}
for k, v := range client.options.Options {
options[k] = v
}
for k, v := range groupOptions {
options[k] = v
}
gid, err := client.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return err
}
// 保存到数据库
task.GID = gid
_, err = task.Create()
if err != nil {
return err
}
// 创建任务监控
NewMonitor(task)
return nil
}

View File

@@ -1,52 +0,0 @@
package aria2
import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/stretchr/testify/assert"
)
func TestRPCService_Init(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.Error(caller.Init("ws://", "", 1, nil))
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
}
func TestRPCService_Status(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
_, err := caller.Status(&model.Download{})
asserts.Error(err)
}
func TestRPCService_Cancel(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
err := caller.Cancel(&model.Download{Parent: "test"})
asserts.Error(err)
}
func TestRPCService_Select(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3})
asserts.Error(err)
}
func TestRPCService_CreateTask(t *testing.T) {
asserts := assert.New(t)
caller := &RPCService{}
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
cache.Set("setting_aria2_temp_path", "test", 0)
err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"})
asserts.Error(err)
}

114
pkg/aria2/common/common.go Normal file
View File

@@ -0,0 +1,114 @@
package common
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Aria2 离线下载处理接口
type Aria2 interface {
// Init 初始化客户端连接
Init() error
// CreateTask 创建新的任务
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
// 返回状态信息
Status(task *model.Download) (rpc.StatusInfo, error)
// 取消任务
Cancel(task *model.Download) error
// 选择要下载的文件
Select(task *model.Download, files []int) error
// 获取离线下载配置
GetConfig() model.Aria2Option
// 删除临时下载文件
DeleteTempFile(*model.Download) error
}
const (
// URLTask 从URL添加的任务
URLTask = iota
// TorrentTask 种子任务
TorrentTask
)
const (
// Ready 准备就绪
Ready = iota
// Downloading 下载中
Downloading
// Paused 暂停中
Paused
// Error 出错
Error
// Complete 完成
Complete
// Canceled 取消/停止
Canceled
// Unknown 未知状态
Unknown
)
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
type DummyAria2 struct {
}
func (instance *DummyAria2) Init() error {
return nil
}
// CreateTask 创建新任务,此处直接返回未开启错误
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
return "", ErrNotEnabled
}
// Status 返回未开启错误
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel 返回未开启错误
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Select 返回未开启错误
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
return ErrNotEnabled
}
// GetConfig 返回空的
func (instance *DummyAria2) GetConfig() model.Aria2Option {
return model.Aria2Option{}
}
// GetConfig 返回空的
func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
return ErrNotEnabled
}
// GetStatus 将给定的状态字符串转换为状态标识数字
func GetStatus(status string) int {
switch status {
case "complete":
return Complete
case "active":
return Downloading
case "waiting":
return Ready
case "paused":
return Paused
case "error":
return Error
case "removed":
return Canceled
default:
return Unknown
}
}

View File

@@ -0,0 +1,45 @@
package common
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDummyAria2(t *testing.T) {
a := assert.New(t)
d := &DummyAria2{}
a.NoError(d.Init())
res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
a.Empty(res)
a.Error(err)
_, err = d.Status(&model.Download{})
a.Error(err)
err = d.Cancel(&model.Download{})
a.Error(err)
err = d.Select(&model.Download{}, []int{})
a.Error(err)
configRes := d.GetConfig()
a.NotNil(configRes)
err = d.DeleteTempFile(&model.Download{})
a.Error(err)
}
func TestGetStatus(t *testing.T) {
a := assert.New(t)
a.Equal(GetStatus("complete"), Complete)
a.Equal(GetStatus("active"), Downloading)
a.Equal(GetStatus("waiting"), Ready)
a.Equal(GetStatus("paused"), Paused)
a.Equal(GetStatus("error"), Error)
a.Equal(GetStatus("removed"), Canceled)
a.Equal(GetStatus("unknown"), Unknown)
}

View File

@@ -1,19 +1,21 @@
package aria2
package monitor
import (
"context"
"encoding/json"
"errors"
"os"
"path/filepath"
"strconv"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
@@ -23,35 +25,37 @@ type Monitor struct {
Task *model.Download
Interval time.Duration
notifier chan StatusEvent
notifier <-chan mq.Message
node cluster.Node
retried int
}
// StatusEvent 状态改变事件
type StatusEvent struct {
GID string
Status int
}
var MAX_RETRY = 10
// NewMonitor 新建上传状态监控
func NewMonitor(task *model.Download) {
// NewMonitor 新建离线下载状态监控
func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
monitor := &Monitor{
Task: task,
Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second,
notifier: make(chan StatusEvent),
notifier: make(chan mq.Message),
node: pool.GetNodeByID(task.GetNodeID()),
}
if monitor.node != nil {
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
go monitor.Loop(mqClient)
monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
} else {
monitor.setErrorStatus(errors.New("节点不可用"))
}
go monitor.Loop()
EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
}
// Loop 开启监控循环
func (monitor *Monitor) Loop() {
defer EventNotifier.Unsubscribe(monitor.Task.GID)
func (monitor *Monitor) Loop(mqClient mq.MQ) {
defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
// 首次循环立即更新
interval := time.Duration(0)
interval := 50 * time.Millisecond
for {
select {
@@ -70,9 +74,7 @@ func (monitor *Monitor) Loop() {
// Update 更新状态,返回值表示是否退出监控
func (monitor *Monitor) Update() bool {
Lock.RLock()
status, err := Instance.Status(monitor.Task)
Lock.RUnlock()
status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
if err != nil {
monitor.retried++
@@ -102,6 +104,7 @@ func (monitor *Monitor) Update() bool {
if err := monitor.UpdateTaskInfo(status); err != nil {
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s]", monitor.Task.GID, err)
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
@@ -109,13 +112,13 @@ func (monitor *Monitor) Update() bool {
switch status.Status {
case "complete":
return monitor.Complete(status)
return monitor.Complete(task.TaskPoll)
case "error":
return monitor.Error(status)
case "active", "waiting", "paused":
return false
case "removed":
monitor.Task.Status = Canceled
monitor.Task.Status = common.Canceled
monitor.Task.Save()
monitor.RemoveTempFolder()
return true
@@ -130,7 +133,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize
monitor.Task.GID = status.Gid
monitor.Task.Status = getStatus(status.Status)
monitor.Task.Status = common.GetStatus(status.Status)
// 文件大小、已下载大小
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
@@ -164,9 +167,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
// 文件大小更新后,对文件限制等进行校验
if err := monitor.ValidateFile(); err != nil {
// 验证失败时取消任务
Lock.RLock()
Instance.Cancel(monitor.Task)
Lock.RUnlock()
monitor.node.GetAria2Instance().Cancel(monitor.Task)
return err
}
}
@@ -179,7 +180,7 @@ func (monitor *Monitor) ValidateFile() error {
// 找到任务创建者
user := monitor.Task.GetOwner()
if user == nil {
return ErrUserNotFound
return common.ErrUserNotFound
}
// 创建文件系统
@@ -230,36 +231,40 @@ func (monitor *Monitor) Error(status rpc.StatusInfo) bool {
// RemoveTempFolder 清理下载临时目录
func (monitor *Monitor) RemoveTempFolder() {
err := os.RemoveAll(monitor.Task.Parent)
if err != nil {
util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err)
}
monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task)
}
// Complete 完成下载,返回是否中断监控
func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
func (monitor *Monitor) Complete(pool task.Pool) bool {
// 创建中转任务
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
if monitor.Task.StatusInfo.Files[i].Selected == "true" {
file = append(file, monitor.Task.StatusInfo.Files[i].Path)
fileInfo := monitor.Task.StatusInfo.Files[i]
if fileInfo.Selected == "true" {
file = append(file, fileInfo.Path)
size, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
sizes[fileInfo.Path] = size
}
}
job, err := task.NewTransferTask(
monitor.Task.UserID,
file,
monitor.Task.Dst,
monitor.Task.Parent,
true,
monitor.node.ID(),
sizes,
)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// 提交中转任务
task.TaskPoll.Submit(job)
pool.Submit(job)
// 更新任务ID
monitor.Task.TaskID = job.Model().ID
@@ -269,7 +274,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
}
func (monitor *Monitor) setErrorStatus(err error) {
monitor.Task.Status = Error
monitor.Task.Status = common.Error
monitor.Task.Error = err.Error()
monitor.Task.Save()
}

View File

@@ -0,0 +1,438 @@
package monitor
import (
"database/sql"
"errors"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"testing"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestNewMonitor(t *testing.T) {
a := assert.New(t)
mockMQ := mq.NewMQ()
// node not available
{
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", uint(1)).Return(nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task := &model.Download{
Model: gorm.Model{ID: 1},
}
NewMonitor(task, mockPool, mockMQ)
mockPool.AssertExpectations(t)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(task.Error)
}
// success
{
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", uint(1)).Return(mockNode)
task := &model.Download{
Model: gorm.Model{ID: 1},
}
NewMonitor(task, mockPool, mockMQ)
mockNode.AssertExpectations(t)
mockPool.AssertExpectations(t)
}
}
func TestMonitor_Loop(t *testing.T) {
a := assert.New(t)
mockMQ := mq.NewMQ()
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
retried: MAX_RETRY,
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
notifier: mockMQ.Subscribe("test", 1),
}
// into interval loop
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
m.Loop(mockMQ)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
// into notifier loop
{
m.Task.Error = ""
mockMQ.Publish("test", mq.Message{})
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
m.Loop(mockMQ)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
}
func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
a := assert.New(t)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
for i := 0; i < MAX_RETRY; i++ {
a.False(m.Update())
}
mockNode.AssertExpectations(t)
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateMagentoFollow(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"next"},
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.False(m.Update())
a.NoError(mock.ExpectationsWereMet())
a.Equal("next", m.Task.GID)
mockAria2.AssertExpectations(t)
}
func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateCompleted(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "complete",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
mockNode.On("ID").Return(uint(1))
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateError(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "error",
ErrorMessage: "error",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateActive(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "active",
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.False(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
}
func TestMonitor_UpdateRemoved(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "removed",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.Equal(common.Canceled, m.Task.Status)
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
}
func TestMonitor_UpdateUnknown(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "unknown",
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
}
func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) {
a := assert.New(t)
status := rpc.StatusInfo{
Status: "completed",
TotalLength: "100",
CompletedLength: "50",
DownloadSpeed: "20",
}
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := m.UpdateTaskInfo(status)
a.Error(err)
a.NoError(mock.ExpectationsWereMet())
mockNode.AssertExpectations(t)
}
func TestMonitor_ValidateFile(t *testing.T) {
a := assert.New(t)
m := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
TotalSize: 100,
},
}
// failed to create filesystem
{
m.Task.User = &model.User{
Policy: model.Policy{
Type: "random",
},
}
a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile())
}
// User capacity not enough
{
m.Task.User = &model.User{
Group: model.Group{
MaxStorage: 99,
},
Policy: model.Policy{
Type: "local",
},
}
a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile())
}
// single file too big
{
m.Task.StatusInfo.Files = []rpc.FileInfo{
{
Length: "100",
Selected: "true",
},
}
m.Task.User = &model.User{
Group: model.Group{
MaxStorage: 100,
},
Policy: model.Policy{
Type: "local",
MaxSize: 99,
},
}
a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile())
}
// all pass
{
m.Task.StatusInfo.Files = []rpc.FileInfo{
{
Length: "100",
Selected: "true",
},
}
m.Task.User = &model.User{
Group: model.Group{
MaxStorage: 100,
},
Policy: model.Policy{
Type: "local",
MaxSize: 100,
},
}
a.NoError(m.ValidateFile())
}
}
func TestMonitor_Complete(t *testing.T) {
a := assert.New(t)
mockNode := &mocks.NodeMock{}
mockNode.On("ID").Return(uint(1))
mockPool := &mocks.TaskPoolMock{}
mockPool.On("Submit", testMock.Anything)
m := &Monitor{
node: mockNode,
Task: &model.Download{
Model: gorm.Model{ID: 1},
TotalSize: 100,
UserID: 9414,
},
}
m.Task.StatusInfo.Files = []rpc.FileInfo{
{
Length: "100",
Selected: "true",
},
}
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Complete(mockPool))
a.NoError(mock.ExpectationsWereMet())
mockNode.AssertExpectations(t)
mockPool.AssertExpectations(t)
}

View File

@@ -1,324 +0,0 @@
package aria2
import (
"errors"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
type InstanceMock struct {
testMock.Mock
}
func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error {
args := m.Called(task, options)
return args.Error(0)
}
func (m InstanceMock) Status(task *model.Download) (rpc.StatusInfo, error) {
args := m.Called(task)
return args.Get(0).(rpc.StatusInfo), args.Error(1)
}
func (m InstanceMock) Cancel(task *model.Download) error {
args := m.Called(task)
return args.Error(0)
}
func (m InstanceMock) Select(task *model.Download, files []int) error {
args := m.Called(task, files)
return args.Error(0)
}
func TestNewMonitor(t *testing.T) {
asserts := assert.New(t)
NewMonitor(&model.Download{GID: "gid"})
_, ok := EventNotifier.Subscribes.Load("gid")
asserts.True(ok)
}
func TestMonitor_Loop(t *testing.T) {
asserts := assert.New(t)
notifier := make(chan StatusEvent)
MAX_RETRY = 0
monitor := &Monitor{
Task: &model.Download{GID: "gid"},
Interval: time.Duration(1) * time.Second,
notifier: notifier,
}
asserts.NotPanics(func() {
monitor.Loop()
})
}
func TestMonitor_Update(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
GID: "gid",
Parent: "TestMonitor_Update",
},
Interval: time.Duration(1) * time.Second,
}
// 无法获取状态
{
MAX_RETRY = 1
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
file.Close()
Instance = testInstance
asserts.False(monitor.Update())
asserts.True(monitor.Update())
testInstance.AssertExpectations(t)
asserts.False(util.Exists("TestMonitor_Update"))
}
// 磁力链下载重定向
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"1"},
}, nil)
monitor.Task.ID = 1
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
asserts.EqualValues("1", monitor.Task.GID)
}
// 无法更新任务信息
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
monitor.Task.ID = 1
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回未知状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回被取消状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回活跃状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
asserts.False(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回错误状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
// 返回完成
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
Instance = testInstance
asserts.True(monitor.Update())
asserts.NoError(mock.ExpectationsWereMet())
testInstance.AssertExpectations(t)
}
}
func TestMonitor_UpdateTaskInfo(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
GID: "gid",
Parent: "TestMonitor_UpdateTaskInfo",
},
}
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
// 更新成功,无需校验
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// 更新成功,大小改变,需要校验,校验失败
{
testInstance := new(InstanceMock)
testInstance.On("Cancel", testMock.Anything).Return(nil)
Instance = testInstance
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
testInstance.AssertExpectations(t)
}
}
func TestMonitor_ValidateFile(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
GID: "gid",
Parent: "TestMonitor_ValidateFile",
},
}
// 无法创建文件系统
{
monitor.Task.User = &model.User{
Policy: model.Policy{
Type: "unknown",
},
}
asserts.Error(monitor.ValidateFile())
}
// 文件大小超出容量配额
{
cache.Set("pack_size_0", uint64(0), 0)
monitor.Task.TotalSize = 11
monitor.Task.User = &model.User{
Policy: model.Policy{
Type: "mock",
},
Group: model.Group{
MaxStorage: 10,
},
}
asserts.Equal(filesystem.ErrInsufficientCapacity, monitor.ValidateFile())
}
// 单文件大小超出容量配额
{
cache.Set("pack_size_0", uint64(0), 0)
monitor.Task.TotalSize = 10
monitor.Task.StatusInfo.Files = []rpc.FileInfo{
{
Selected: "true",
Length: "6",
},
}
monitor.Task.User = &model.User{
Policy: model.Policy{
Type: "mock",
MaxSize: 5,
},
Group: model.Group{
MaxStorage: 10,
},
}
asserts.Equal(filesystem.ErrFileSizeTooBig, monitor.ValidateFile())
}
}
func TestMonitor_Complete(t *testing.T) {
asserts := assert.New(t)
monitor := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
GID: "gid",
Parent: "TestMonitor_Complete",
StatusInfo: rpc.StatusInfo{
Files: []rpc.FileInfo{
{
Selected: "true",
Path: "TestMonitor_Complete",
},
},
},
},
}
cache.Set("setting_max_worker_num", "1", 0)
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
task.Init()
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.True(monitor.Complete(rpc.StatusInfo{}))
asserts.NoError(mock.ExpectationsWereMet())
}

View File

@@ -1,64 +0,0 @@
package aria2
import (
"sync"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
)
// Notifier aria2实践通知处理
type Notifier struct {
Subscribes sync.Map
}
// Subscribe 订阅事件通知
func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) {
notifier.Subscribes.Store(gid, target)
}
// Unsubscribe 取消订阅事件通知
func (notifier *Notifier) Unsubscribe(gid string) {
notifier.Subscribes.Delete(gid)
}
// Notify 发送通知
func (notifier *Notifier) Notify(events []rpc.Event, status int) {
for _, event := range events {
if target, ok := notifier.Subscribes.Load(event.Gid); ok {
target.(chan StatusEvent) <- StatusEvent{
GID: event.Gid,
Status: status,
}
}
}
}
// OnDownloadStart 下载开始
func (notifier *Notifier) OnDownloadStart(events []rpc.Event) {
notifier.Notify(events, Downloading)
}
// OnDownloadPause 下载暂停
func (notifier *Notifier) OnDownloadPause(events []rpc.Event) {
notifier.Notify(events, Paused)
}
// OnDownloadStop 下载停止
func (notifier *Notifier) OnDownloadStop(events []rpc.Event) {
notifier.Notify(events, Canceled)
}
// OnDownloadComplete 下载完成
func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) {
notifier.Notify(events, Complete)
}
// OnDownloadError 下载出错
func (notifier *Notifier) OnDownloadError(events []rpc.Event) {
notifier.Notify(events, Error)
}
// OnBtDownloadComplete BT下载完成
func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) {
notifier.Notify(events, Complete)
}

View File

@@ -1,52 +0,0 @@
package aria2
import (
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/stretchr/testify/assert"
)
func TestNotifier_Notify(t *testing.T) {
asserts := assert.New(t)
notifier2 := &Notifier{}
notifyChan := make(chan StatusEvent, 10)
notifier2.Subscribe(notifyChan, "1")
// 未订阅
{
notifier2.Notify([]rpc.Event{rpc.Event{Gid: ""}}, 1)
asserts.Len(notifyChan, 0)
}
// 订阅
{
notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1)
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}})
asserts.Len(notifyChan, 1)
<-notifyChan
}
}

View File

@@ -2,9 +2,11 @@ package auth
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"sort"
"strings"
"time"
@@ -15,8 +17,10 @@ import (
)
var (
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
)
// General 通用的认证接口
@@ -30,9 +34,8 @@ type Auth interface {
Check(body string, sign string) error
}
// SignRequest 对PUT\POST等复杂HTTP请求签名如果请求Header中
// 包含 X-Policy 则此请求会被认定为上传请求只会对URI部分和
// Policy部分进行签名。其他请求则会对URI和Body部分进行签名。
// SignRequest 对PUT\POST等复杂HTTP请求签名只会对URI部分、
// 请求正文、`X-Cr-`开头的header进行签名
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
// 处理有效期
if expires > 0 {
@@ -54,27 +57,38 @@ func CheckRequest(instance Auth, r *http.Request) error {
ok bool
)
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
return ErrAuthFailed
return ErrAuthHeaderMissing
}
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
return instance.Check(getSignContent(r), sign[0])
}
// getSignContent 根据请求Header中是否包含X-Policy判断是否为上传请求
// 返回待签名/验证的字符串
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果 Header 中包含 `X-Policy`
// 则不对正文签名。返回待签名/验证的字符串
func getSignContent(r *http.Request) (rawSignString string) {
if policy, ok := r.Header["X-Policy"]; ok {
rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "")
} else {
var body = []byte{}
// 读取所有body正文
var body = []byte{}
if _, ok := r.Header["X-Cr-Policy"]; !ok {
if r.Body != nil {
body, _ = ioutil.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = ioutil.NopCloser(bytes.NewReader(body))
}
rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body))
}
// 决定要签名的header
var signedHeader []string
for k, _ := range r.Header {
if strings.HasPrefix(k, "X-Cr-") && k != "X-Cr-Filename" {
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
}
}
sort.Strings(signedHeader)
// 读取所有待签名Header
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
return rawSignString
}

View File

@@ -70,7 +70,7 @@ func TestSignRequest(t *testing.T) {
strings.NewReader("I am body."),
)
asserts.NoError(err)
req.Header["X-Policy"] = []string{"I am Policy"}
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
req = SignRequest(General, req, 10)
asserts.NotEmpty(req.Header["Authorization"])
}
@@ -80,6 +80,19 @@ func TestCheckRequest(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 缺少请求头
{
req, err := http.NewRequest(
"POST",
"http://127.0.0.1/api/v3/upload",
strings.NewReader("I am body."),
)
asserts.NoError(err)
err = CheckRequest(General, req)
asserts.Error(err)
asserts.Equal(ErrAuthHeaderMissing, err)
}
// 非上传请求 验证成功
{
req, err := http.NewRequest(
@@ -101,7 +114,7 @@ func TestCheckRequest(t *testing.T) {
strings.NewReader("I am body."),
)
asserts.NoError(err)
req.Header["X-Policy"] = []string{"I am Policy"}
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
req = SignRequest(General, req, 0)
err = CheckRequest(General, req)
asserts.NoError(err)

View File

@@ -33,7 +33,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
signSlice := strings.Split(sign, ":")
// 如果未携带expires字段
if signSlice[len(signSlice)-1] == "" {
return ErrAuthFailed
return ErrExpiresMissing
}
// 验证是否过期

15
pkg/balancer/balancer.go Normal file
View File

@@ -0,0 +1,15 @@
package balancer
type Balancer interface {
NextPeer(nodes interface{}) (error, interface{})
}
// NewBalancer 根据策略标识返回新的负载均衡器
func NewBalancer(strategy string) Balancer {
switch strategy {
case "RoundRobin":
return &RoundRobin{}
default:
return &RoundRobin{}
}
}

View File

@@ -0,0 +1,12 @@
package balancer
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewBalancer(t *testing.T) {
a := assert.New(t)
a.NotNil(NewBalancer(""))
a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
}

8
pkg/balancer/errors.go Normal file
View File

@@ -0,0 +1,8 @@
package balancer
import "errors"
var (
ErrInputNotSlice = errors.New("Input value is not silice")
ErrNoAvaliableNode = errors.New("No nodes avaliable")
)

View File

@@ -0,0 +1,30 @@
package balancer
import (
"reflect"
"sync/atomic"
)
type RoundRobin struct {
current uint64
}
// NextPeer 返回轮盘的下一节点
func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
v := reflect.ValueOf(nodes)
if v.Kind() != reflect.Slice {
return ErrInputNotSlice, nil
}
if v.Len() == 0 {
return ErrNoAvaliableNode, nil
}
next := r.NextIndex(v.Len())
return nil, v.Index(next).Interface()
}
// NextIndex 返回下一个节点下标
func (r *RoundRobin) NextIndex(total int) int {
return int(atomic.AddUint64(&r.current, uint64(1)) % uint64(total))
}

View File

@@ -0,0 +1,42 @@
package balancer
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestRoundRobin_NextIndex(t *testing.T) {
a := assert.New(t)
r := &RoundRobin{}
total := 5
for i := 1; i < total; i++ {
a.Equal(i, r.NextIndex(total))
}
for i := 0; i < total; i++ {
a.Equal(i, r.NextIndex(total))
}
}
func TestRoundRobin_NextPeer(t *testing.T) {
a := assert.New(t)
r := &RoundRobin{}
// not slice
{
err, _ := r.NextPeer("s")
a.Equal(ErrInputNotSlice, err)
}
// no nodes
{
err, _ := r.NextPeer([]string{})
a.Equal(ErrNoAvaliableNode, err)
}
// pass
{
err, res := r.NextPeer([]string{"a"})
a.NoError(err)
a.Equal("a", res.(string))
}
}

209
pkg/cluster/controller.go Normal file
View File

@@ -0,0 +1,209 @@
package cluster
import (
"bytes"
"encoding/gob"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm"
"net/url"
"sync"
)
var DefaultController Controller
// Controller controls communications between master and slave
type Controller interface {
// Handle heartbeat sent from master
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
// Get Aria2 Instance by master node ID
GetAria2Instance(string) (common.Aria2, error)
// Send event change message to master node
SendNotification(string, string, mq.Message) error
// Submit async task into task pool
SubmitTask(string, interface{}, string, func(interface{})) error
// Get master node info
GetMasterInfo(string) (*MasterInfo, error)
// Get master OneDrive policy credential
GetOneDriveToken(string, uint) (string, error)
}
type slaveController struct {
masters map[string]MasterInfo
lock sync.RWMutex
}
// info of master node
type MasterInfo struct {
ID string
TTL int
URL *url.URL
// used to invoke aria2 rpc calls
Instance Node
Client request.Client
jobTracker map[string]bool
}
func InitController() {
DefaultController = &slaveController{
masters: make(map[string]MasterInfo),
}
gob.Register(rpc.StatusInfo{})
}
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
c.lock.Lock()
defer c.lock.Unlock()
req.Node.AfterFind()
// close old node if exist
origin, ok := c.masters[req.SiteID]
if (ok && req.IsUpdate) || !ok {
if ok {
origin.Instance.Kill()
}
masterUrl, err := url.Parse(req.SiteURL)
if err != nil {
return serializer.NodePingResp{}, err
}
c.masters[req.SiteID] = MasterInfo{
ID: req.SiteID,
URL: masterUrl,
TTL: req.CredentialTTL,
Client: request.NewClient(
request.WithEndpoint(masterUrl.String()),
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
request.WithCredential(auth.HMACAuth{
SecretKey: []byte(req.Node.MasterKey),
}, int64(req.CredentialTTL)),
),
jobTracker: make(map[string]bool),
Instance: NewNodeFromDBModel(&model.Node{
Model: gorm.Model{ID: req.Node.ID},
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType,
Aria2Enabled: req.Node.Aria2Enabled,
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
}),
}
}
return serializer.NodePingResp{}, nil
}
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
return node.Instance.GetAria2Instance(), nil
}
return nil, ErrMasterNotFound
}
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
c.lock.RLock()
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
body := bytes.Buffer{}
enc := gob.NewEncoder(&body)
if err := enc.Encode(&msg); err != nil {
return err
}
res, err := node.Client.Request(
"PUT",
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
&body,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
c.lock.RUnlock()
return ErrMasterNotFound
}
// SubmitTask 提交异步任务
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
if _, ok := node.jobTracker[hash]; ok {
// 任务已存在,直接返回
return nil
}
node.jobTracker[hash] = true
submitter(job)
return nil
}
return ErrMasterNotFound
}
// GetMasterInfo 获取主机节点信息
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
return &node, nil
}
return nil, ErrMasterNotFound
}
// GetOneDriveToken 获取主机OneDrive凭证
func (c *slaveController) GetOneDriveToken(id string, policyID uint) (string, error) {
c.lock.RLock()
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
res, err := node.Client.Request(
"GET",
fmt.Sprintf("/api/v3/slave/credential/onedrive/%d", policyID),
nil,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return "", err
}
if res.Code != 0 {
return "", serializer.NewErrorFromResponse(res)
}
return res.Data.(string), nil
}
c.lock.RUnlock()
return "", ErrMasterNotFound
}

View File

@@ -0,0 +1,385 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
)
func TestInitController(t *testing.T) {
assert.NotPanics(t, func() {
InitController()
})
}
func TestSlaveController_HandleHeartBeat(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: make(map[string]MasterInfo),
}
// first heart beat
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
Node: &model.Node{},
})
a.NoError(err)
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "2",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
}
// second heart beat, no fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Empty(c.masters["1"].URL)
}
// second heart beat, fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
// second heart beat, fresh, url illegal
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: string([]byte{0x7f}),
Node: &model.Node{},
})
a.Error(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
}
type nodeMock struct {
testMock.Mock
}
func (n nodeMock) Init(node *model.Node) {
n.Called(node)
}
func (n nodeMock) IsFeatureEnabled(feature string) bool {
args := n.Called(feature)
return args.Bool(0)
}
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
n.Called(callback)
}
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
args := n.Called(req)
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
}
func (n nodeMock) IsActive() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) GetAria2Instance() common.Aria2 {
args := n.Called()
return args.Get(0).(common.Aria2)
}
func (n nodeMock) ID() uint {
args := n.Called()
return args.Get(0).(uint)
}
func (n nodeMock) Kill() {
n.Called()
}
func (n nodeMock) IsMater() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) MasterAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) SlaveAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) DBModel() *model.Node {
args := n.Called()
return args.Get(0).(*model.Node)
}
func TestSlaveController_GetAria2Instance(t *testing.T) {
a := assert.New(t)
mockNode := &nodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Instance: mockNode},
},
}
// node node found
{
res, err := c.GetAria2Instance("2")
a.Nil(res)
a.Equal(ErrMasterNotFound, err)
}
// node found
{
res, err := c.GetAria2Instance("1")
a.NotNil(res)
a.NoError(err)
mockNode.AssertExpectations(t)
}
}
type requestMock struct {
testMock.Mock
}
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
return r.Called(method, target, body, opts).Get(0).(*request.Response)
}
func TestSlaveController_SendNotification(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
}
// gob encode error
{
type randomType struct{}
a.Error(c.SendNotification("1", "", mq.Message{
Content: randomType{},
}))
}
// return none 200
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{StatusCode: http.StatusConflict},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Error(c.SendNotification("1", "s1", mq.Message{}))
mockRequest.AssertExpectations(t)
}
// master return error
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
mockRequest.AssertExpectations(t)
}
// success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
mockRequest.AssertExpectations(t)
}
}
func TestSlaveController_SubmitTask(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {
jobTracker: map[string]bool{},
},
},
}
// node not exit
{
a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
}
// success
{
submitted := false
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
submitted = true
}))
a.True(submitted)
}
// job already submitted
{
submitted := false
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
submitted = true
}))
a.False(submitted)
}
}
func TestSlaveController_GetMasterInfo(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
res, err := c.GetMasterInfo("2")
a.Equal(ErrMasterNotFound, err)
a.Nil(res)
}
// success
{
res, err := c.GetMasterInfo("1")
a.NoError(err)
a.NotNil(res)
}
}
func TestSlaveController_GetOneDriveToken(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
res, err := c.GetOneDriveToken("2", 1)
a.Equal(ErrMasterNotFound, err)
a.Empty(res)
}
// return none 200
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{StatusCode: http.StatusConflict},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetOneDriveToken("1", 1)
a.Error(err)
a.Empty(res)
mockRequest.AssertExpectations(t)
}
// master return error
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetOneDriveToken("1", 1)
a.Equal(1, err.(serializer.AppError).Code)
a.Empty(res)
mockRequest.AssertExpectations(t)
}
// success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetOneDriveToken("1", 1)
a.NoError(err)
a.Equal("expected", res)
mockRequest.AssertExpectations(t)
}
}

12
pkg/cluster/errors.go Normal file
View File

@@ -0,0 +1,12 @@
package cluster
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
)

272
pkg/cluster/master.go Normal file
View File

@@ -0,0 +1,272 @@
package cluster
import (
"context"
"encoding/json"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
const (
deleteTempFileDuration = 60 * time.Second
statusRetryDuration = 10 * time.Second
)
type MasterNode struct {
Model *model.Node
aria2RPC rpcService
lock sync.RWMutex
}
// RPCService 通过RPC服务的Aria2任务管理器
type rpcService struct {
Caller rpc.Client
Initialized bool
retryDuration time.Duration
deletePaddingDuration time.Duration
parent *MasterNode
options *clientOptions
}
type clientOptions struct {
Options map[string]interface{} // 创建下载时额外添加的设置
}
// Init 初始化节点
func (node *MasterNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
node.aria2RPC.parent = node
node.aria2RPC.retryDuration = statusRetryDuration
node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
node.lock.Unlock()
node.lock.RLock()
if node.Model.Aria2Enabled {
node.lock.RUnlock()
node.aria2RPC.Init()
return
}
node.lock.RUnlock()
}
func (node *MasterNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
return &serializer.NodePingResp{}, nil
}
// IsFeatureEnabled 查询节点的某项功能是否启用
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
node.lock.RLock()
defer node.lock.RUnlock()
switch feature {
case "aria2":
return node.Model.Aria2Enabled
default:
return false
}
}
func (node *MasterNode) MasterAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
}
// SubscribeStatusChange 订阅节点状态更改
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
}
// IsActive 返回节点是否在线
func (node *MasterNode) IsActive() bool {
return true
}
// Kill 结束aria2请求
func (node *MasterNode) Kill() {
if node.aria2RPC.Caller != nil {
node.aria2RPC.Caller.Close()
}
}
// GetAria2Instance 获取主机Aria2实例
func (node *MasterNode) GetAria2Instance() common.Aria2 {
node.lock.RLock()
if !node.Model.Aria2Enabled {
node.lock.RUnlock()
return &common.DummyAria2{}
}
if !node.aria2RPC.Initialized {
node.lock.RUnlock()
node.aria2RPC.Init()
return &common.DummyAria2{}
}
defer node.lock.RUnlock()
return &node.aria2RPC
}
func (node *MasterNode) IsMater() bool {
return true
}
func (node *MasterNode) DBModel() *model.Node {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model
}
func (r *rpcService) Init() error {
r.parent.lock.Lock()
defer r.parent.lock.Unlock()
r.Initialized = false
// 客户端已存在,则关闭先前连接
if r.Caller != nil {
r.Caller.Close()
}
// 解析RPC服务地址
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
if err != nil {
util.Log().Warning("无法解析主机 Aria2 RPC 服务地址,%s", err)
return err
}
server.Path = "/jsonrpc"
// 加载自定义下载配置
var globalOptions map[string]interface{}
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
if err != nil {
util.Log().Warning("无法解析主机 Aria2 配置,%s", err)
return err
}
}
r.options = &clientOptions{
Options: globalOptions,
}
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
r.Caller = caller
r.Initialized = err == nil
return err
}
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
r.parent.lock.RLock()
// 生成存储路径
guid, _ := uuid.NewV4()
path := filepath.Join(
r.parent.Model.Aria2OptionsSerialized.TempPath,
"aria2",
guid.String(),
)
r.parent.lock.RUnlock()
// 创建下载任务
options := map[string]interface{}{
"dir": path,
}
for k, v := range r.options.Options {
options[k] = v
}
for k, v := range groupOptions {
options[k] = v
}
gid, err := r.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return "", err
}
return gid, nil
}
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := r.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("无法获取离线下载状态,%s稍后重试", err)
time.Sleep(r.retryDuration)
res, err = r.Caller.TellStatus(task.GID)
}
return res, err
}
func (r *rpcService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := r.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
}
return err
}
func (r *rpcService) Select(task *model.Download, files []int) error {
var selected = make([]string, len(files))
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
func (r *rpcService) GetConfig() model.Aria2Option {
r.parent.lock.RLock()
defer r.parent.lock.RUnlock()
return r.parent.Model.Aria2OptionsSerialized
}
func (s *rpcService) DeleteTempFile(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
// 避免被aria2占用异步执行删除
go func(d time.Duration, src string) {
time.Sleep(d)
err := os.RemoveAll(src)
if err != nil {
util.Log().Warning("无法删除离线下载临时目录[%s], %s", src, err)
}
}(s.deletePaddingDuration, task.Parent)
return nil
}

186
pkg/cluster/master_test.go Normal file
View File

@@ -0,0 +1,186 @@
package cluster
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/stretchr/testify/assert"
"os"
"testing"
"time"
)
func TestMasterNode_Init(t *testing.T) {
a := assert.New(t)
m := &MasterNode{}
m.Init(&model.Node{Status: model.NodeSuspend})
a.Equal(model.NodeSuspend, m.DBModel().Status)
m.Init(&model.Node{Aria2Enabled: true})
}
func TestMasterNode_DummyMethods(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
m.Model.ID = 5
a.Equal(m.Model.ID, m.ID())
res, err := m.Ping(&serializer.NodePingReq{})
a.NoError(err)
a.NotNil(res)
a.True(m.IsActive())
a.True(m.IsMater())
m.SubscribeStatusChange(func(isActive bool, id uint) {})
}
func TestMasterNode_IsFeatureEnabled(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
a.False(m.IsFeatureEnabled("aria2"))
a.False(m.IsFeatureEnabled("random"))
m.Model.Aria2Enabled = true
a.True(m.IsFeatureEnabled("aria2"))
}
func TestMasterNode_AuthInstance(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
a.NotNil(m.MasterAuthInstance())
a.NotNil(m.SlaveAuthInstance())
}
func TestMasterNode_Kill(t *testing.T) {
m := &MasterNode{
Model: &model.Node{},
}
m.Kill()
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
m.Kill()
}
func TestMasterNode_GetAria2Instance(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
aria2RPC: rpcService{},
}
m.aria2RPC.parent = m
a.NotNil(m.GetAria2Instance())
m.Model.Aria2Enabled = true
a.NotNil(m.GetAria2Instance())
m.aria2RPC.Initialized = true
a.NotNil(m.GetAria2Instance())
}
func TestRpcService_Init(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{
Aria2OptionsSerialized: model.Aria2Option{
Options: "{",
},
},
aria2RPC: rpcService{},
}
m.aria2RPC.parent = m
// failed to decode address
{
m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f})
a.Error(m.aria2RPC.Init())
}
// failed to decode options
{
m.Model.Aria2OptionsSerialized.Server = ""
a.Error(m.aria2RPC.Init())
}
// failed to initialized
{
m.Model.Aria2OptionsSerialized.Server = ""
m.Model.Aria2OptionsSerialized.Options = "{}"
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
a.Error(m.aria2RPC.Init())
a.False(m.aria2RPC.Initialized)
}
}
func getTestRPCNode() *MasterNode {
m := &MasterNode{
Model: &model.Node{
Aria2OptionsSerialized: model.Aria2Option{},
},
aria2RPC: rpcService{
options: &clientOptions{
Options: map[string]interface{}{"1": "1"},
},
},
}
m.aria2RPC.parent = m
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
return m
}
func TestRpcService_CreateTask(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"})
a.Error(err)
a.Empty(res)
}
func TestRpcService_Status(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
res, err := m.aria2RPC.Status(&model.Download{})
a.Error(err)
a.Empty(res)
}
func TestRpcService_Cancel(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
a.Error(m.aria2RPC.Cancel(&model.Download{}))
}
func TestRpcService_Select(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
a.NotNil(m.aria2RPC.GetConfig())
a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3}))
}
func TestRpcService_DeleteTempFile(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
fdName := "TestRpcService_DeleteTempFile"
a.NoError(os.Mkdir(fdName, 0644))
a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName}))
time.Sleep(500 * time.Millisecond)
a.False(util.Exists(fdName))
}

60
pkg/cluster/node.go Normal file
View File

@@ -0,0 +1,60 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
type Node interface {
// Init a node from database model
Init(node *model.Node)
// Check if given feature is enabled
IsFeatureEnabled(feature string) bool
// Subscribe node status change to a callback function
SubscribeStatusChange(callback func(isActive bool, id uint))
// Ping the node
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
// Returns if the node is active
IsActive() bool
// Get instances for aria2 calls
GetAria2Instance() common.Aria2
// Returns unique id of this node
ID() uint
// Kill node and recycle resources
Kill()
// Returns if current node is master node
IsMater() bool
// Get auth instance used to check RPC call from slave to master
MasterAuthInstance() auth.Auth
// Get auth instance used to check RPC call from master to slave
SlaveAuthInstance() auth.Auth
// Get node DB model
DBModel() *model.Node
}
// Create new node from DB model
func NewNodeFromDBModel(node *model.Node) Node {
switch node.Type {
case model.SlaveNodeType:
slave := &SlaveNode{}
slave.Init(node)
return slave
default:
master := &MasterNode{}
master.Init(node)
return master
}
}

17
pkg/cluster/node_test.go Normal file
View File

@@ -0,0 +1,17 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewNodeFromDBModel(t *testing.T) {
a := assert.New(t)
a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{
Type: model.SlaveNodeType,
}))
a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{
Type: model.MasterNodeType,
}))
}

190
pkg/cluster/pool.go Normal file
View File

@@ -0,0 +1,190 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"sync"
)
var Default *NodePool
// 需要分类的节点组
var featureGroup = []string{"aria2"}
// Pool 节点池
type Pool interface {
// Returns active node selected by given feature and load balancer
BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
// Returns node by ID
GetNodeByID(id uint) Node
// Add given node into pool. If node existed, refresh node.
Add(node *model.Node)
// Delete and kill node from pool by given node id
Delete(id uint)
}
// NodePool 通用节点池
type NodePool struct {
active map[uint]Node
inactive map[uint]Node
featureMap map[string][]Node
lock sync.RWMutex
}
// Init 初始化从机节点池
func Init() {
Default = &NodePool{}
Default.Init()
if err := Default.initFromDB(); err != nil {
util.Log().Warning("节点池初始化失败, %s", err)
}
}
func (pool *NodePool) Init() {
pool.lock.Lock()
defer pool.lock.Unlock()
pool.featureMap = make(map[string][]Node)
pool.active = make(map[uint]Node)
pool.inactive = make(map[uint]Node)
}
func (pool *NodePool) buildIndexMap() {
pool.lock.Lock()
for _, feature := range featureGroup {
pool.featureMap[feature] = make([]Node, 0)
}
for _, v := range pool.active {
for _, feature := range featureGroup {
if v.IsFeatureEnabled(feature) {
pool.featureMap[feature] = append(pool.featureMap[feature], v)
}
}
}
pool.lock.Unlock()
}
func (pool *NodePool) GetNodeByID(id uint) Node {
pool.lock.RLock()
defer pool.lock.RUnlock()
if node, ok := pool.active[id]; ok {
return node
}
return pool.inactive[id]
}
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive)
var node Node
pool.lock.Lock()
if n, ok := pool.inactive[id]; ok {
node = n
delete(pool.inactive, id)
} else {
node = pool.active[id]
delete(pool.active, id)
}
if isActive {
pool.active[id] = node
} else {
pool.inactive[id] = node
}
pool.lock.Unlock()
pool.buildIndexMap()
}
func (pool *NodePool) initFromDB() error {
nodes, err := model.GetNodesByStatus(model.NodeActive)
if err != nil {
return err
}
pool.lock.Lock()
for i := 0; i < len(nodes); i++ {
pool.add(&nodes[i])
}
pool.lock.Unlock()
pool.buildIndexMap()
return nil
}
func (pool *NodePool) add(node *model.Node) {
newNode := NewNodeFromDBModel(node)
if newNode.IsActive() {
pool.active[node.ID] = newNode
} else {
pool.inactive[node.ID] = newNode
}
// 订阅节点状态变更
newNode.SubscribeStatusChange(func(isActive bool, id uint) {
pool.nodeStatusChange(isActive, id)
})
}
func (pool *NodePool) Add(node *model.Node) {
pool.lock.Lock()
defer pool.buildIndexMap()
defer pool.lock.Unlock()
var (
old Node
ok bool
)
if old, ok = pool.active[node.ID]; !ok {
old, ok = pool.inactive[node.ID]
}
if old != nil {
go old.Init(node)
return
}
pool.add(node)
}
func (pool *NodePool) Delete(id uint) {
pool.lock.Lock()
defer pool.buildIndexMap()
defer pool.lock.Unlock()
if node, ok := pool.active[id]; ok {
node.Kill()
delete(pool.active, id)
return
}
if node, ok := pool.inactive[id]; ok {
node.Kill()
delete(pool.inactive, id)
return
}
}
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) {
pool.lock.RLock()
defer pool.lock.RUnlock()
if nodes, ok := pool.featureMap[feature]; ok {
err, res := lb.NextPeer(nodes)
if err == nil {
return nil, res.(Node)
}
return err, nil
}
return ErrFeatureNotExist, nil
}

161
pkg/cluster/pool_test.go Normal file
View File

@@ -0,0 +1,161 @@
package cluster
import (
"database/sql"
"errors"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestInitFailed(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
Init()
a.NoError(mock.ExpectationsWereMet())
}
func TestInitSuccess(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType))
Init()
a.NoError(mock.ExpectationsWereMet())
}
func TestNodePool_GetNodeByID(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
mockNode := &nodeMock{}
// inactive
{
p.inactive[1] = mockNode
a.Equal(mockNode, p.GetNodeByID(1))
}
// active
{
delete(p.inactive, 1)
p.active[1] = mockNode
a.Equal(mockNode, p.GetNodeByID(1))
}
}
func TestNodePool_NodeStatusChange(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
n := &MasterNode{Model: &model.Node{}}
p.Init()
p.inactive[1] = n
p.nodeStatusChange(true, 1)
a.Len(p.inactive, 0)
a.Equal(n, p.active[1])
p.nodeStatusChange(false, 1)
a.Len(p.active, 0)
a.Equal(n, p.inactive[1])
p.nodeStatusChange(false, 1)
a.Len(p.active, 0)
a.Equal(n, p.inactive[1])
}
func TestNodePool_Add(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// new node
{
p.Add(&model.Node{})
a.Len(p.active, 1)
}
// old node
{
p.inactive[0] = p.active[0]
delete(p.active, 0)
p.Add(&model.Node{})
a.Len(p.active, 0)
a.Len(p.inactive, 1)
}
}
func TestNodePool_Delete(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// active
{
mockNode := &nodeMock{}
mockNode.On("Kill")
p.active[0] = mockNode
p.Delete(0)
a.Len(p.active, 0)
a.Len(p.inactive, 0)
mockNode.AssertExpectations(t)
}
p.Init()
// inactive
{
mockNode := &nodeMock{}
mockNode.On("Kill")
p.inactive[0] = mockNode
p.Delete(0)
a.Len(p.active, 0)
a.Len(p.inactive, 0)
mockNode.AssertExpectations(t)
}
}
func TestNodePool_BalanceNodeByFeature(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// success
{
p.featureMap["test"] = []Node{&MasterNode{}}
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
a.NoError(err)
a.Equal(p.featureMap["test"][0], res)
}
// NoNodes
{
p.featureMap["test"] = []Node{}
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
a.Error(err)
a.Nil(res)
}
// No match feature
{
err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin"))
a.Error(err)
a.Nil(res)
}
}

410
pkg/cluster/slave.go Normal file
View File

@@ -0,0 +1,410 @@
package cluster
import (
"encoding/json"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"net/url"
"strings"
"sync"
"time"
)
type SlaveNode struct {
Model *model.Node
Active bool
caller slaveCaller
callback func(bool, uint)
close chan bool
lock sync.RWMutex
}
type slaveCaller struct {
parent *SlaveNode
Client request.Client
}
// Init 初始化节点
func (node *SlaveNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
// Init http request client
var endpoint *url.URL
if serverURL, err := url.Parse(node.Model.Server); err == nil {
var controller *url.URL
controller, _ = url.Parse("/api/v3/slave")
endpoint = serverURL.ResolveReference(controller)
}
signTTL := model.GetIntSetting("slave_api_timeout", 60)
node.caller.Client = request.NewClient(
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
request.WithEndpoint(endpoint.String()),
)
node.caller.parent = node
if node.close != nil {
node.lock.Unlock()
node.close <- true
go node.StartPingLoop()
} else {
node.Active = true
node.lock.Unlock()
go node.StartPingLoop()
}
}
// IsFeatureEnabled 查询节点的某项功能是否启用
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
node.lock.RLock()
defer node.lock.RUnlock()
switch feature {
case "aria2":
return node.Model.Aria2Enabled
default:
return false
}
}
// SubscribeStatusChange 订阅节点状态更改
func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
node.lock.Lock()
node.callback = callback
node.lock.Unlock()
}
// Ping 从机节点,返回从机负载
func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
node.lock.RLock()
defer node.lock.RUnlock()
reqBodyEncoded, err := json.Marshal(req)
if err != nil {
return nil, err
}
bodyReader := strings.NewReader(string(reqBodyEncoded))
resp, err := node.caller.Client.Request(
"POST",
"heartbeat",
bodyReader,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return nil, err
}
// 处理列取结果
if resp.Code != 0 {
return nil, serializer.NewErrorFromResponse(resp)
}
var res serializer.NodePingResp
if resStr, ok := resp.Data.(string); ok {
err = json.Unmarshal([]byte(resStr), &res)
if err != nil {
return nil, err
}
}
return &res, nil
}
// IsActive 返回节点是否在线
func (node *SlaveNode) IsActive() bool {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Active
}
// Kill 结束节点内相关循环
func (node *SlaveNode) Kill() {
node.lock.RLock()
defer node.lock.RUnlock()
if node.close != nil {
close(node.close)
}
}
// GetAria2Instance 获取从机Aria2实例
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
node.lock.RLock()
defer node.lock.RUnlock()
if !node.Model.Aria2Enabled {
return &common.DummyAria2{}
}
return &node.caller
}
func (node *SlaveNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}
func (node *SlaveNode) StartPingLoop() {
node.lock.Lock()
node.close = make(chan bool)
node.lock.Unlock()
tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
pingTicker := time.Duration(0)
util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name)
retry := 0
recoverMode := false
isFirstLoop := true
loop:
for {
select {
case <-time.After(pingTicker):
if pingTicker == 0 {
pingTicker = tickDuration
}
util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name)
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
isFirstLoop = false
if err != nil {
util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err)
retry++
if retry >= model.GetIntSetting("slave_node_retry", 3) {
util.Log().Debug("从机节点 [%s] Ping 重试已达到最大限制,将从机节点标记为不可用", node.Model.Name)
node.changeStatus(false)
if !recoverMode {
// 启动恢复监控循环
util.Log().Debug("从机节点 [%s] 进入恢复模式", node.Model.Name)
pingTicker = recoverDuration
recoverMode = true
}
}
} else {
if recoverMode {
util.Log().Debug("从机节点 [%s] 复活", node.Model.Name)
pingTicker = tickDuration
recoverMode = false
isFirstLoop = true
}
util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res)
node.changeStatus(true)
retry = 0
}
case <-node.close:
util.Log().Debug("从机节点 [%s] 收到关闭信号", node.Model.Name)
break loop
}
}
}
func (node *SlaveNode) IsMater() bool {
return false
}
func (node *SlaveNode) MasterAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
}
func (node *SlaveNode) DBModel() *model.Node {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model
}
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
return &serializer.NodePingReq{
SiteURL: model.GetSiteURL().String(),
IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"),
Node: node.Model,
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
}
}
func (node *SlaveNode) changeStatus(isActive bool) {
node.lock.RLock()
id := node.Model.ID
if isActive != node.Active {
node.lock.RUnlock()
node.lock.Lock()
node.Active = isActive
node.lock.Unlock()
node.callback(isActive, id)
} else {
node.lock.RUnlock()
}
}
func (s *slaveCaller) Init() error {
return nil
}
// SendAria2Call send remote aria2 call to slave node
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
reqReader, err := getAria2RequestBody(body)
if err != nil {
return nil, err
}
return s.Client.Request(
"POST",
"aria2/"+scope,
reqReader,
).CheckHTTPResponse(200).DecodeResponse()
}
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
GroupOptions: options,
}
res, err := s.SendAria2Call(req, "task")
if err != nil {
return "", err
}
if res.Code != 0 {
return "", serializer.NewErrorFromResponse(res)
}
return res.Data.(string), err
}
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "status")
if err != nil {
return rpc.StatusInfo{}, err
}
if res.Code != 0 {
return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
}
var status rpc.StatusInfo
res.GobDecode(&status)
return status, err
}
func (s *slaveCaller) Cancel(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "cancel")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) Select(task *model.Download, files []int) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
Files: files,
}
res, err := s.SendAria2Call(req, "select")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) GetConfig() model.Aria2Option {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
return s.parent.Model.Aria2OptionsSerialized
}
func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "delete")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
reqBodyEncoded, err := json.Marshal(body)
if err != nil {
return nil, err
}
return strings.NewReader(string(reqBodyEncoded)), nil
}

443
pkg/cluster/slave_test.go Normal file
View File

@@ -0,0 +1,443 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
)
func TestSlaveNode_InitAndKill(t *testing.T) {
a := assert.New(t)
n := &SlaveNode{
callback: func(b bool, u uint) {
},
}
a.NotPanics(func() {
n.Init(&model.Node{})
time.Sleep(time.Millisecond * 500)
n.Init(&model.Node{})
n.Kill()
})
}
func TestSlaveNode_DummyMethods(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
m.Model.ID = 5
a.Equal(m.Model.ID, m.ID())
a.Equal(m.Model.ID, m.DBModel().ID)
a.False(m.IsActive())
a.False(m.IsMater())
m.SubscribeStatusChange(func(isActive bool, id uint) {})
}
func TestSlaveNode_IsFeatureEnabled(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.False(m.IsFeatureEnabled("aria2"))
a.False(m.IsFeatureEnabled("random"))
m.Model.Aria2Enabled = true
a.True(m.IsFeatureEnabled("aria2"))
}
func TestSlaveNode_Ping(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
// master return error code
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.Error(err)
a.Nil(res)
a.Equal(1, err.(serializer.AppError).Code)
}
// return unexpected json
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.Error(err)
a.Nil(res)
}
// return success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.NoError(err)
a.NotNil(res)
}
}
func TestSlaveNode_GetAria2Instance(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.NotNil(m.GetAria2Instance())
m.Model.Aria2Enabled = true
a.NotNil(m.GetAria2Instance())
a.NotNil(m.GetAria2Instance())
}
func TestSlaveNode_StartPingLoop(t *testing.T) {
callbackCount := 0
finishedChan := make(chan struct{})
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m := &SlaveNode{
Active: true,
Model: &model.Node{},
callback: func(b bool, u uint) {
callbackCount++
if callbackCount == 2 {
close(finishedChan)
}
if callbackCount == 1 {
mockRequest.AssertExpectations(t)
mockRequest = requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
},
})
}
},
}
cache.Set("setting_slave_ping_interval", "0", 0)
cache.Set("setting_slave_recover_interval", "0", 0)
cache.Set("setting_slave_node_retry", "1", 0)
m.caller.Client = &mockRequest
go func() {
select {
case <-finishedChan:
m.Kill()
}
}()
m.StartPingLoop()
mockRequest.AssertExpectations(t)
}
func TestSlaveNode_AuthInstance(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.NotNil(m.MasterAuthInstance())
a.NotNil(m.SlaveAuthInstance())
}
func TestSlaveNode_ChangeStatus(t *testing.T) {
a := assert.New(t)
isActive := false
m := &SlaveNode{
Model: &model.Node{},
callback: func(b bool, u uint) {
isActive = b
},
}
a.NotPanics(func() {
m.changeStatus(false)
})
m.changeStatus(true)
a.True(isActive)
}
func getTestRPCNodeSlave() *SlaveNode {
m := &SlaveNode{
Model: &model.Node{},
}
m.caller.parent = m
return m
}
func TestSlaveCaller_CreateTask(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Empty(res)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Empty(res)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Equal("res", res)
a.NoError(err)
}
}
func TestSlaveCaller_Status(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.NoError(err)
}
}
func TestSlaveCaller_Cancel(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.NoError(err)
}
}
func TestSlaveCaller_Select(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
m.caller.Init()
m.caller.GetConfig()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.NoError(err)
}
}
func TestSlaveCaller_DeleteTempFile(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
m.caller.Init()
m.caller.GetConfig()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.NoError(err)
}
}

View File

@@ -70,9 +70,13 @@ type redis struct {
// 缩略图 配置
type thumb struct {
MaxWidth uint
MaxHeight uint
FileSuffix string `validate:"min=1"`
MaxWidth uint
MaxHeight uint
FileSuffix string `validate:"min=1"`
MaxTaskCount int
EncodeMethod string `validate:"eq=jpg|eq=png"`
EncodeQuality int `validate:"gte=1,lte=100"`
GCAfterGen bool
}
// 跨域配置

View File

@@ -44,16 +44,20 @@ var CaptchaConfig = &captcha{
var CORSConfig = &cors{
AllowOrigins: []string{"UNSET"},
AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"},
AllowHeaders: []string{"Cookie", "X-Policy", "Authorization", "Content-Length", "Content-Type", "X-Path", "X-FileName"},
AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Path", "X-FileName"},
AllowCredentials: false,
ExposeHeaders: nil,
}
// ThumbConfig 缩略图配置
var ThumbConfig = &thumb{
MaxWidth: 400,
MaxHeight: 300,
FileSuffix: "._thumb",
MaxWidth: 400,
MaxHeight: 300,
FileSuffix: "._thumb",
MaxTaskCount: -1,
EncodeMethod: "jpg",
GCAfterGen: false,
EncodeQuality: 85,
}
// SlaveConfig 从机配置

View File

@@ -1,13 +1,13 @@
package conf
// BackendVersion 当前后端版本号
var BackendVersion = "3.3.2"
var BackendVersion = "3.4.0"
// RequiredDBVersion 与当前版本匹配的数据库版本
var RequiredDBVersion = "3.3.2"
var RequiredDBVersion = "3.4.0"
// RequiredStaticVersion 与当前版本匹配的静态资源版本
var RequiredStaticVersion = "3.3.2"
var RequiredStaticVersion = "3.4.0"
// IsPro 是否为Pro版本
var IsPro = "false"

View File

@@ -0,0 +1,39 @@
package driver
import (
"context"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"io"
"net/url"
)
// Handler 存储策略适配器
type Handler interface {
// 上传文件, dst为文件存储路径size 为文件大小。上下文关闭
// 时,应取消上传并清理临时文件
Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error
// 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误
Delete(ctx context.Context, files []string) ([]string, error)
// 获取文件内容
Get(ctx context.Context, path string) (response.RSCloser, error)
// 获取缩略图可直接在ContentResponse中返回文件数据流也可指
// 定为重定向
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
// 获取外链/下载地址,
// url - 站点本身地址,
// isDownload - 是否直接下载
Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error)
// Token 获取有效期为ttl的上传凭证和签名同时回调会话有效期为sessionTTL
Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error)
// List 递归列取远程端path路径下文件、目录不包含path本身
// 返回的对象路径以path作为起始根目录.
// recursive - 是否递归列出
List(ctx context.Context, path string, recursive bool) ([]response.Object, error)
}

View File

@@ -190,7 +190,7 @@ func TestHandler_Source(t *testing.T) {
// 设定了CDN解析失败
{
handler.Policy.BaseURL = string(0x7f)
handler.Policy.BaseURL = string([]byte{0x7f})
file := model.File{
Model: gorm.Model{
ID: 1,

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"io"
"io/ioutil"
"net/http"
@@ -573,7 +574,7 @@ func sysError(err error) *RespError {
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, *RespError) {
// 获取凭证
err := client.UpdateCredential(ctx)
err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
if err != nil {
return "", sysError(err)
}

View File

@@ -2,6 +2,7 @@ package onedrive
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
@@ -28,7 +29,8 @@ type Client struct {
ClientSecret string
Redirect string
Request request.Client
Request request.Client
ClusterController cluster.Controller
}
// Endpoints OneDrive客户端相关设置
@@ -51,11 +53,12 @@ func NewClient(policy *model.Policy) (*Client, error) {
Credential: &Credential{
RefreshToken: policy.AccessKey,
},
Policy: policy,
ClientID: policy.BucketName,
ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OdRedirect,
Request: request.HTTPClient{},
Policy: policy,
ClientID: policy.BucketName,
ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OdRedirect,
Request: request.NewClient(),
ClusterController: cluster.DefaultController,
}
if client.Endpoints.DriverResource == "" {

View File

@@ -14,6 +14,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
@@ -27,6 +28,16 @@ type Driver struct {
HTTPClient request.Client
}
// NewDriver 从存储策略初始化新的Driver实例
func NewDriver(policy *model.Policy) (driver.Handler, error) {
client, err := NewClient(policy)
return Driver{
Policy: policy,
Client: client,
HTTPClient: request.NewClient(),
}, err
}
// List 列取项目
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
base = strings.TrimPrefix(base, "/")

View File

@@ -0,0 +1,25 @@
package onedrive
import "sync"
// CredentialLock 针对存储策略凭证的锁
type CredentialLock interface {
Lock(uint)
Unlock(uint)
}
var GlobalMutex = mutexMap{}
type mutexMap struct {
locks sync.Map
}
func (m *mutexMap) Lock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Lock()
}
func (m *mutexMap) Unlock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Unlock()
}

View File

@@ -123,7 +123,14 @@ func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credent
}
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) UpdateCredential(ctx context.Context) error {
func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
if isSlave {
return client.fetchCredentialFromMaster(ctx)
}
GlobalMutex.Lock(client.Policy.ID)
defer GlobalMutex.Unlock(client.Policy.ID)
// 如果已存在凭证
if client.Credential != nil && client.Credential.AccessToken != "" {
// 检查已有凭证是否过期
@@ -160,11 +167,21 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
client.Credential = credential
// 更新存储策略的 RefreshToken
client.Policy.AccessKey = credential.RefreshToken
client.Policy.SaveAndClearCache()
client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken)
// 更新缓存
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
return nil
}
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
res, err := client.ClusterController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
if err != nil {
return err
}
client.Credential = &Credential{AccessToken: res}
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
"io"
"io/ioutil"
"net/http"
@@ -269,10 +270,10 @@ func TestClient_UpdateCredential(t *testing.T) {
// 无有效的RefreshToken
{
err := client.UpdateCredential(context.Background())
err := client.UpdateCredential(context.Background(), false)
asserts.Equal(ErrInvalidRefreshToken, err)
client.Credential = nil
err = client.UpdateCredential(context.Background())
err = client.UpdateCredential(context.Background(), false)
asserts.Equal(ErrInvalidRefreshToken, err)
}
@@ -299,7 +300,7 @@ func TestClient_UpdateCredential(t *testing.T) {
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := client.UpdateCredential(context.Background())
err := client.UpdateCredential(context.Background(), false)
clientMock.AssertExpectations(t)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
@@ -331,7 +332,7 @@ func TestClient_UpdateCredential(t *testing.T) {
client.Credential = &Credential{
RefreshToken: "old_refresh_token",
}
err := client.UpdateCredential(context.Background())
err := client.UpdateCredential(context.Background(), false)
clientMock.AssertExpectations(t)
asserts.Error(err)
}
@@ -346,7 +347,7 @@ func TestClient_UpdateCredential(t *testing.T) {
client.Credential = &Credential{
RefreshToken: "old_refresh_token",
}
err := client.UpdateCredential(context.Background())
err := client.UpdateCredential(context.Background(), false)
asserts.NoError(err)
asserts.Equal("AccessToken", client.Credential.AccessToken)
asserts.Equal("RefreshToken", client.Credential.RefreshToken)
@@ -359,8 +360,27 @@ func TestClient_UpdateCredential(t *testing.T) {
AccessToken: "AccessToken2",
ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(),
}
err := client.UpdateCredential(context.Background())
err := client.UpdateCredential(context.Background(), false)
asserts.NoError(err)
asserts.Equal("AccessToken2", client.Credential.AccessToken)
}
// slave failed
{
mockController := &controllermock.SlaveControllerMock{}
mockController.On("GetOneDriveToken", testMock.Anything, testMock.Anything).Return("", errors.New("error"))
client.ClusterController = mockController
err := client.UpdateCredential(context.Background(), true)
asserts.Error(err)
}
// slave success
{
mockController := &controllermock.SlaveControllerMock{}
mockController.On("GetOneDriveToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil)
client.ClusterController = mockController
err := client.UpdateCredential(context.Background(), true)
asserts.NoError(err)
asserts.Equal("AccessToken3", client.Credential.AccessToken)
}
}

View File

@@ -42,7 +42,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) {
}
// 获取公钥
client := request.HTTPClient{}
client := request.NewClient()
body, err := client.Request("GET", string(pubURL), nil).
CheckHTTPResponse(200).
GetResponse()

View File

@@ -292,7 +292,7 @@ func TestDriver_Get(t *testing.T) {
BucketName: "test",
Server: "oss-cn-shanghai.aliyuncs.com",
},
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
cache.Set("setting_preview_timeout", "3600", 0)

View File

@@ -49,6 +49,7 @@ func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]
handler.getAPIUrl("list"),
bodyReader,
request.WithCredential(handler.AuthInstance, int64(signTTL)),
request.WithMasterMeta(),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return res, err
@@ -97,7 +98,7 @@ func (handler Driver) getAPIUrl(scope string, routes ...string) string {
// Get 获取文件内容
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 尝试获取速度限制 TODO 是否需要在这里限制?
// 尝试获取速度限制
speedLimit := 0
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
speedLimit = user.Group.SpeedLimit
@@ -116,6 +117,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
request.WithMasterMeta(),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
@@ -151,10 +153,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
}
// 对文件名进行URLEncode
fileName, err := url.QueryUnescape(path.Base(dst))
if err != nil {
return err
}
fileName := url.QueryEscape(path.Base(dst))
// 决定是否要禁用文件覆盖
overwrite := "true"
@@ -168,13 +167,15 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
handler.Policy.GetUploadURL(),
file,
request.WithHeader(map[string][]string{
"Authorization": {credential.Token},
"X-Policy": {credential.Policy},
"X-FileName": {fileName},
"X-Overwrite": {overwrite},
"X-Cr-Policy": {credential.Policy},
"X-Cr-FileName": {fileName},
"X-Cr-Overwrite": {overwrite},
}),
request.WithContentLength(int64(size)),
request.WithTimeout(time.Duration(0)),
request.WithMasterMeta(),
request.WithSlaveMeta(handler.Policy.AccessKey),
request.WithCredential(handler.AuthInstance, int64(credentialTTL)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
@@ -206,6 +207,8 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
handler.getAPIUrl("delete"),
bodyReader,
request.WithCredential(handler.AuthInstance, int64(signTTL)),
request.WithMasterMeta(),
request.WithSlaveMeta(handler.Policy.AccessKey),
).CheckHTTPResponse(200).GetResponse()
if err != nil {
return files, err
@@ -288,7 +291,7 @@ func (handler Driver) Source(
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path))
signedURI, err = auth.SignURI(
handler.AuthInstance,
fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, fileName),
fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, url.PathEscape(fileName)),
ttl,
)
@@ -329,8 +332,8 @@ func (handler Driver) getUploadCredential(ctx context.Context, policy serializer
// 签名上传策略
uploadRequest, _ := http.NewRequest("POST", "/api/v3/slave/upload", nil)
uploadRequest.Header = map[string][]string{
"X-Policy": {policyEncoded},
"X-Overwrite": {"false"},
"X-Cr-Policy": {policyEncoded},
"X-Cr-Overwrite": {"false"},
}
auth.SignRequest(handler.AuthInstance, uploadRequest, TTL)

View File

@@ -105,7 +105,7 @@ func TestHandler_Source(t *testing.T) {
// 解析失败 自定义CDN
{
handler := Driver{
Policy: &model.Policy{Server: "/", BaseURL: string(0x7f)},
Policy: &model.Policy{Server: "/", BaseURL: string([]byte{0x7f})},
AuthInstance: auth.HMACAuth{},
}
file := model.File{

View File

@@ -172,7 +172,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
}
// 获取文件数据流
client := request.HTTPClient{}
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,
@@ -344,6 +344,7 @@ func (handler Driver) Token(ctx context.Context, TTL int64, key string) (seriali
map[string]string{"bucket": handler.Policy.BucketName},
[]string{"starts-with", "$key", savePath},
[]string{"starts-with", "$success_action_redirect", apiURL.String()},
[]string{"starts-with", "$name", ""},
[]string{"starts-with", "$Content-Type", ""},
map[string]string{"x-amz-algorithm": "AWS4-HMAC-SHA256"},
},

View File

@@ -0,0 +1,7 @@
package masterinslave
import "errors"
var (
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
)

View File

@@ -0,0 +1,56 @@
package masterinslave
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"io"
"net/url"
)
// Driver 影子存储策略,用于在从机端上传文件
type Driver struct {
master cluster.Node
handler driver.Handler
policy *model.Policy
}
// NewDriver 返回新的处理器
func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
return &Driver{
master: master,
handler: handler,
policy: policy,
}
}
func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
return d.handler.Put(ctx, file, dst, size)
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return d.handler.Delete(ctx, files)
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented
}
func (d *Driver) Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) {
return serializer.UploadCredential{}, ErrNotImplemented
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
return nil, ErrNotImplemented
}

View File

@@ -0,0 +1,9 @@
package slaveinmaster
import "errors"
var (
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
ErrWaitResultTimeout = errors.New("timeout waiting for slave transfer result")
)

View File

@@ -0,0 +1,121 @@
package slaveinmaster
import (
"bytes"
"context"
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"io"
"net/url"
"time"
)
// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果
type Driver struct {
node cluster.Node
handler driver.Handler
policy *model.Policy
client request.Client
}
// NewDriver 返回新的从机指派处理器
func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
var endpoint *url.URL
if serverURL, err := url.Parse(node.DBModel().Server); err == nil {
var controller *url.URL
controller, _ = url.Parse("/api/v3/slave")
endpoint = serverURL.ResolveReference(controller)
}
signTTL := model.GetIntSetting("slave_api_timeout", 60)
return &Driver{
node: node,
handler: handler,
policy: policy,
client: request.NewClient(
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(node.SlaveAuthInstance(), int64(signTTL)),
request.WithEndpoint(endpoint.String()),
),
}
}
// Put 将ctx中指定的从机物理文件由从机上传到目标存储策略
func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
src, ok := ctx.Value(fsctx.SlaveSrcPath).(string)
if !ok {
return ErrSlaveSrcPathNotExist
}
req := serializer.SlaveTransferReq{
Src: src,
Dst: dst,
Policy: d.policy,
}
body, err := json.Marshal(req)
if err != nil {
return err
}
// 订阅转存结果
resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0)
defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan)
res, err := d.client.Request("PUT", "task/transfer", bytes.NewReader(body)).
CheckHTTPResponse(200).
DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
// 等待转存结果或者超时
waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800)
select {
case <-time.After(time.Duration(waitTimeout) * time.Second):
return ErrWaitResultTimeout
case msg := <-resChan:
if msg.Event != serializer.SlaveTransferSuccess {
return errors.New(msg.Content.(serializer.SlaveTransferResult).Error)
}
}
return nil
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return d.handler.Delete(ctx, files)
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented
}
func (d *Driver) Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error) {
return serializer.UploadCredential{}, ErrNotImplemented
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
return nil, ErrNotImplemented
}

View File

@@ -228,25 +228,25 @@ func TestFileSystem_deleteGroupedFile(t *testing.T) {
},
}
// 全部失败
// 全部不存在
{
failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
asserts.Equal(map[uint][]string{
1: {"1_1.txt", "1_2.txt"},
2: {"2_1.txt", "2_2.txt"},
3: {"3_1.txt"},
1: {},
2: {},
3: {},
}, failed)
}
// 部分失败
// 部分不存在
{
file, err := os.Create(util.RelativePath("1_1.txt"))
asserts.NoError(err)
_ = file.Close()
failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
asserts.Equal(map[uint][]string{
1: {"1_2.txt"},
2: {"2_1.txt", "2_2.txt"},
3: {"3_1.txt"},
1: {},
2: {},
3: {},
}, failed)
}
// 部分失败,包含整组未知存储策略导致的失败
@@ -259,9 +259,9 @@ func TestFileSystem_deleteGroupedFile(t *testing.T) {
files[3].Policy.Type = "unknown"
failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
asserts.Equal(map[uint][]string{
1: {"1_2.txt"},
1: {},
2: {"2_1.txt", "2_2.txt"},
3: {"3_1.txt"},
3: {},
}, failed)
}
}

View File

@@ -1,8 +1,12 @@
package filesystem
import (
"context"
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
"io"
"net/http"
"net/url"
@@ -19,7 +23,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
@@ -43,36 +46,6 @@ type FileHeader interface {
GetVirtualPath() string
}
// Handler 存储策略适配器
type Handler interface {
// 上传文件, dst为文件存储路径size 为文件大小。上下文关闭
// 时,应取消上传并清理临时文件
Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error
// 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误
Delete(ctx context.Context, files []string) ([]string, error)
// 获取文件内容
Get(ctx context.Context, path string) (response.RSCloser, error)
// 获取缩略图可直接在ContentResponse中返回文件数据流也可指
// 定为重定向
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
// 获取外链/下载地址,
// url - 站点本身地址,
// isDownload - 是否直接下载
Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error)
// Token 获取有效期为ttl的上传凭证和签名同时回调会话有效期为sessionTTL
Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error)
// List 递归列取远程端path路径下文件、目录不包含path本身
// 返回的对象路径以path作为起始根目录.
// recursive - 是否递归列出
List(ctx context.Context, path string, recursive bool) ([]response.Object, error)
}
// FileSystem 管理文件的文件系统
type FileSystem struct {
// 文件系统所有者
@@ -96,7 +69,7 @@ type FileSystem struct {
/*
文件系统处理适配器
*/
Handler Handler
Handler driver.Handler
// 回收锁
recycleLock sync.Mutex
@@ -134,7 +107,6 @@ func NewFileSystem(user *model.User) (*FileSystem, error) {
// 分配存储策略适配器
err := fs.DispatchHandler()
// TODO 分配默认钩子
return fs, err
}
@@ -159,7 +131,6 @@ func NewAnonymousFileSystem() (*FileSystem, error) {
}
// DispatchHandler 根据存储策略分配文件适配器
// TODO 完善测试
func (fs *FileSystem) DispatchHandler() error {
var policyType string
var currentPolicy *model.Policy
@@ -184,7 +155,7 @@ func (fs *FileSystem) DispatchHandler() error {
case "remote":
fs.Handler = remote.Driver{
Policy: currentPolicy,
Client: request.HTTPClient{},
Client: request.NewClient(),
AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)},
}
return nil
@@ -196,7 +167,7 @@ func (fs *FileSystem) DispatchHandler() error {
case "oss":
fs.Handler = oss.Driver{
Policy: currentPolicy,
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
return nil
case "upyun":
@@ -205,13 +176,9 @@ func (fs *FileSystem) DispatchHandler() error {
}
return nil
case "onedrive":
client, err := onedrive.NewClient(currentPolicy)
fs.Handler = onedrive.Driver{
Policy: currentPolicy,
Client: client,
HTTPClient: request.HTTPClient{},
}
return err
var odErr error
fs.Handler, odErr = onedrive.NewDriver(currentPolicy)
return odErr
case "cos":
u, _ := url.Parse(currentPolicy.Server)
b := &cossdk.BaseURL{BucketURL: u}
@@ -223,7 +190,7 @@ func (fs *FileSystem) DispatchHandler() error {
SecretKey: currentPolicy.SecretKey,
},
}),
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
return nil
case "s3":
@@ -272,6 +239,30 @@ func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) {
return fs, err
}
// SwitchToSlaveHandler 将负责上传的 Handler 切换为从机节点
func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) {
fs.Handler = slaveinmaster.NewDriver(node, fs.Handler, &fs.User.Policy)
}
// SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器
func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) {
switch fs.Policy.Type {
case "remote":
fs.Policy.Type = "local"
fs.DispatchHandler()
case "local":
fs.Policy.Type = "remote"
fs.Policy.Server = masterURL
fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID())
fs.Policy.SecretKey = master.DBModel().MasterKey
fs.DispatchHandler()
case "onedrive":
fs.Policy.MasterID = masterID
}
fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy)
}
// SetTargetFile 设置当前处理的目标文件
func (fs *FileSystem) SetTargetFile(files *[]model.File) {
if len(fs.FileTarget) == 0 {

View File

@@ -1,6 +1,9 @@
package filesystem
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
"net/http/httptest"
"github.com/DATA-DOG/go-sqlmock"
@@ -104,6 +107,10 @@ func TestDispatchHandler(t *testing.T) {
err = fs.DispatchHandler()
asserts.NoError(err)
fs.Policy = &model.Policy{Type: "cos"}
err = fs.DispatchHandler()
asserts.NoError(err)
fs.Policy = &model.Policy{Type: "s3"}
err = fs.DispatchHandler()
asserts.NoError(err)
@@ -262,3 +269,40 @@ func TestFileSystem_SetTargetByInterface(t *testing.T) {
asserts.Len(fs.FileTarget, 1)
}
}
func TestFileSystem_SwitchToSlaveHandler(t *testing.T) {
a := assert.New(t)
fs := FileSystem{
User: &model.User{},
}
mockNode := &cluster.MasterNode{
Model: &model.Node{},
}
fs.SwitchToSlaveHandler(mockNode)
a.IsType(&slaveinmaster.Driver{}, fs.Handler)
}
func TestFileSystem_SwitchToShadowHandler(t *testing.T) {
a := assert.New(t)
fs := FileSystem{
User: &model.User{},
Policy: &model.Policy{},
}
mockNode := &cluster.MasterNode{
Model: &model.Node{},
}
// remote to local
{
fs.Policy.Type = "remote"
fs.SwitchToShadowHandler(mockNode, "", "")
a.IsType(&masterinslave.Driver{}, fs.Handler)
}
// local to remote
{
fs.Policy.Type = "local"
fs.SwitchToShadowHandler(mockNode, "", "")
a.IsType(&masterinslave.Driver{}, fs.Handler)
}
}

View File

@@ -41,4 +41,6 @@ const (
ValidateCapacityOnceCtx
// 禁止上传时同名覆盖操作
DisableOverwrite
// 文件在从机节点中的路径
SlaveSrcPath
)

View File

@@ -708,3 +708,29 @@ func TestHookGiveBackCapacity(t *testing.T) {
asserts.EqualValues(7, fs.User.Storage)
}
}
func TestHookValidateCapacityWithoutIncrease(t *testing.T) {
a := assert.New(t)
fs := &FileSystem{
User: &model.User{
Model: gorm.Model{ID: 1},
Storage: 10,
Group: model.Group{},
},
}
ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{Size: 1})
// not enough
{
fs.User.Group.MaxStorage = 10
a.Error(HookValidateCapacityWithoutIncrease(ctx, fs))
a.EqualValues(10, fs.User.Storage)
}
// enough
{
fs.User.Group.MaxStorage = 11
a.NoError(HookValidateCapacityWithoutIncrease(ctx, fs))
a.EqualValues(10, fs.User.Storage)
}
}

View File

@@ -4,6 +4,9 @@ import (
"context"
"fmt"
"strconv"
"sync"
"runtime"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
@@ -35,18 +38,53 @@ func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentR
ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, [2]uint{w, h})
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
res, err := fs.Handler.Thumb(ctx, fs.FileTarget[0].SourceName)
if err == nil && conf.SystemConfig.Mode == "master" {
res.MaxAge = model.GetIntSetting("preview_timeout", 60)
}
// 本地存储策略出错时重新生成缩略图
if err != nil && fs.Policy.Type == "local" {
fs.GenerateThumbnail(ctx, &fs.FileTarget[0])
res, err = fs.Handler.Thumb(ctx, fs.FileTarget[0].SourceName)
}
if err == nil && conf.SystemConfig.Mode == "master" {
res.MaxAge = model.GetIntSetting("preview_timeout", 60)
}
return res, err
}
// thumbPool 要使用的任务池
var thumbPool *Pool
var once sync.Once
// Pool 带有最大配额的任务池
type Pool struct {
// 容量
worker chan int
}
// Init 初始化任务池
func getThumbWorker() *Pool {
once.Do(func() {
maxWorker := conf.ThumbConfig.MaxTaskCount
if maxWorker <= 0 {
maxWorker = runtime.GOMAXPROCS(0)
}
thumbPool = &Pool{
worker: make(chan int, maxWorker),
}
util.Log().Debug("初始化Thumb任务队列WorkerNum = %d", maxWorker)
})
return thumbPool
}
func (pool *Pool) addWorker() {
pool.worker <- 1
util.Log().Debug("Thumb任务队列addWorker")
}
func (pool *Pool) releaseWorker() {
util.Log().Debug("Thumb任务队列releaseWorker")
<-pool.worker
}
// GenerateThumbnail 尝试为本地策略文件生成缩略图并获取图像原始大小
// TODO 失败时,如果之前还有图像信息,则清除
func (fs *FileSystem) GenerateThumbnail(ctx context.Context, file *model.File) {
@@ -65,6 +103,8 @@ func (fs *FileSystem) GenerateThumbnail(ctx context.Context, file *model.File) {
return
}
defer source.Close()
getThumbWorker().addWorker()
defer getThumbWorker().releaseWorker()
image, err := thumb.NewThumbFromFile(source, file.Name)
if err != nil {
@@ -79,6 +119,12 @@ func (fs *FileSystem) GenerateThumbnail(ctx context.Context, file *model.File) {
image.GetThumb(fs.GenerateThumbnailSize(w, h))
// 保存到文件
err = image.Save(util.RelativePath(file.SourceName + conf.ThumbConfig.FileSuffix))
image = nil
if conf.ThumbConfig.GCAfterGen {
util.Log().Debug("GenerateThumbnail runtime.GC")
runtime.GC()
}
if err != nil {
util.Log().Warning("无法保存缩略图:%s", err)
return

View File

@@ -38,3 +38,12 @@ func TestFileSystem_GetThumb(t *testing.T) {
asserts.EqualValues(50, res.MaxAge)
}
}
func TestFileSystem_ThumbWorker(t *testing.T) {
asserts := assert.New(t)
asserts.NotPanics(func() {
getThumbWorker().addWorker()
getThumbWorker().releaseWorker()
})
}

View File

@@ -356,47 +356,6 @@ func TestFileSystem_Delete(t *testing.T) {
}}
ctx := context.Background()
//全部未成功
{
// 列出要删除的目录
mock.ExpectQuery("SELECT(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3),
)
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).
AddRow(4, "1.txt", "1.txt", 2, 1),
)
// 查询顶级的文件
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "1.txt", "1.txt", 603, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
// 查找软连接
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
// 查询上传策略
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(603, "local"))
// 删除文件记录
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)files").
WillReturnResult(sqlmock.NewResult(0, 3))
mock.ExpectCommit()
// 删除对应分享
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)shares").
WillReturnResult(sqlmock.NewResult(0, 3))
mock.ExpectCommit()
err := fs.Delete(ctx, []uint{1}, []uint{1}, false)
asserts.Error(err)
asserts.Equal(203, err.(serializer.AppError).Code)
asserts.Equal(uint64(3), fs.User.Storage)
asserts.NoError(mock.ExpectationsWereMet())
}
//全部未成功,强制
{
fs.CleanTargets()

View File

@@ -228,12 +228,14 @@ func (fs *FileSystem) UploadFromStream(ctx context.Context, src io.ReadCloser, d
}
// UploadFromPath 将本机已有文件上传到用户的文件系统
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string) error {
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, resetPolicy bool) error {
// 重设存储策略
fs.Policy = &fs.User.Policy
err := fs.DispatchHandler()
if err != nil {
return err
if resetPolicy {
fs.Policy = &fs.User.Policy
err := fs.DispatchHandler()
if err != nil {
return err
}
}
file, err := os.Open(util.RelativePath(src))

View File

@@ -226,13 +226,13 @@ func TestFileSystem_UploadFromPath(t *testing.T) {
// 文件不存在
{
err := fs.UploadFromPath(ctx, "test/not_exist", "/")
err := fs.UploadFromPath(ctx, "test/not_exist", "/", true)
asserts.Error(err)
}
// 文存在,上传失败
{
err := fs.UploadFromPath(ctx, "tests/test.zip", "/")
err := fs.UploadFromPath(ctx, "tests/test.zip", "/", true)
asserts.Error(err)
}
}

View File

@@ -0,0 +1,43 @@
package controllermock
import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/mock"
)
type SlaveControllerMock struct {
mock.Mock
}
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
args := s.Called(pingReq)
return args.Get(0).(serializer.NodePingResp), args.Error(1)
}
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
args := s.Called(s2)
return args.Get(0).(common.Aria2), args.Error(1)
}
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
args := s.Called(s3, s2, message)
return args.Error(0)
}
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
args := s.Called(s3, i, s2, f)
return args.Error(0)
}
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
args := s.Called(s2)
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
}
func (s SlaveControllerMock) GetOneDriveToken(s2 string, u uint) (string, error) {
args := s.Called(s2, u)
return args.String(0), args.Error(1)
}

Some files were not shown because too many files have changed in this diff Show More