mirror of
https://github.com/cloudreve/cloudreve.git
synced 2026-03-10 23:47:01 +00:00
Compare commits
217 Commits
3.0.0-beta
...
3.4.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c31c77a089 | ||
|
|
6b15cae0b5 | ||
|
|
84d81f201f | ||
|
|
af4d9767c2 | ||
|
|
45597adcd3 | ||
|
|
762f0f9c68 | ||
|
|
c5074df1c7 | ||
|
|
7ea72cf364 | ||
|
|
4eb7525c51 | ||
|
|
3948ee7f3a | ||
|
|
865a801fa8 | ||
|
|
05941616df | ||
|
|
51b1e5b854 | ||
|
|
4dbe867020 | ||
|
|
8c8ad3e149 | ||
|
|
fce38209bc | ||
|
|
700e13384e | ||
|
|
7fd984f95d | ||
|
|
9fc08292a0 | ||
|
|
8c5445a26d | ||
|
|
96b84bb5e5 | ||
|
|
9056ef9171 | ||
|
|
532bff820a | ||
|
|
fcd9eddc54 | ||
|
|
6c9967b120 | ||
|
|
416f4c1dd2 | ||
|
|
f0089045d7 | ||
|
|
4b88eacb6a | ||
|
|
54ed7e43ca | ||
|
|
4d7b8685b9 | ||
|
|
eeee43d569 | ||
|
|
3064ed60f3 | ||
|
|
e41ec9defa | ||
|
|
eaa0f6be91 | ||
|
|
5db476634a | ||
|
|
1f06ee3af6 | ||
|
|
22bbfe7da1 | ||
|
|
f1dc4c4758 | ||
|
|
5f861b963a | ||
|
|
056de22edb | ||
|
|
a3b4a22dbc | ||
|
|
9ff1b47646 | ||
|
|
65c4367689 | ||
|
|
db7489fb61 | ||
|
|
622b928a90 | ||
|
|
c0158ea224 | ||
|
|
e6959a5026 | ||
|
|
9d64bdd9f6 | ||
|
|
c85c2da523 | ||
|
|
8659bdcf77 | ||
|
|
641fe352da | ||
|
|
96712fb066 | ||
|
|
a1252c810b | ||
|
|
e781185ad2 | ||
|
|
95802efcec | ||
|
|
233648b956 | ||
|
|
53acadf098 | ||
|
|
c0f7214cdb | ||
|
|
ccaefdab33 | ||
|
|
6efd8e8183 | ||
|
|
144b534486 | ||
|
|
e160154d3b | ||
|
|
2381eca230 | ||
|
|
adde486a30 | ||
|
|
a9c0d6ed17 | ||
|
|
595f4a1350 | ||
|
|
a5f80a4431 | ||
|
|
6fb419d998 | ||
|
|
3f0f33b4fc | ||
|
|
052e6be393 | ||
|
|
a4b0ad81e9 | ||
|
|
8431906b94 | ||
|
|
40476953aa | ||
|
|
270f617b9d | ||
|
|
170f2279c1 | ||
|
|
d1377262e3 | ||
|
|
c9acf7e64e | ||
|
|
4e2f243436 | ||
|
|
a54acd71c2 | ||
|
|
fec2fe14f8 | ||
|
|
1f1bc056e3 | ||
|
|
e44ec0e6bf | ||
|
|
a93b964d8b | ||
|
|
d9cff24c75 | ||
|
|
e2488841b4 | ||
|
|
a276be4098 | ||
|
|
4cf6c81534 | ||
|
|
5a66af3105 | ||
|
|
fc5c67cc20 | ||
|
|
5e226efea1 | ||
|
|
c949d47161 | ||
|
|
e699287ffd | ||
|
|
9c78515c72 | ||
|
|
3b22b4fd25 | ||
|
|
08d998b41e | ||
|
|
488e62f762 | ||
|
|
f35ad3fe0a | ||
|
|
61e6d9b591 | ||
|
|
feb1134a7c | ||
|
|
9f2f14cacf | ||
|
|
055ed0e075 | ||
|
|
c87109c8b1 | ||
|
|
8057c4b8bc | ||
|
|
5ab93a6e0d | ||
|
|
5d406f1c6a | ||
|
|
5b44606276 | ||
|
|
bd2bdf253b | ||
|
|
0cfa61e264 | ||
|
|
f7c8039116 | ||
|
|
6486e8799b | ||
|
|
7279be2924 | ||
|
|
33f8419999 | ||
|
|
a5805b022a | ||
|
|
ae89b402f6 | ||
|
|
0d210e87b3 | ||
|
|
f0a68236a8 | ||
|
|
c6110e9e75 | ||
|
|
d97bc26042 | ||
|
|
11c218eb94 | ||
|
|
79b8784934 | ||
|
|
59d50b1b98 | ||
|
|
746aa3e8ef | ||
|
|
95f318e069 | ||
|
|
77394313aa | ||
|
|
41eb84a221 | ||
|
|
40414fe6ae | ||
|
|
7df09537e0 | ||
|
|
f478c38307 | ||
|
|
bfd2340732 | ||
|
|
dd50ef1c25 | ||
|
|
27bf8ca9b2 | ||
|
|
c71a2c5b64 | ||
|
|
a7ba357cb8 | ||
|
|
14f5982b47 | ||
|
|
e607311268 | ||
|
|
acc5d53bab | ||
|
|
aa3e8913ab | ||
|
|
ee0f8e964d | ||
|
|
60745ac8ba | ||
|
|
bfb5b34edc | ||
|
|
a5000c0621 | ||
|
|
e038350cf0 | ||
|
|
5af3c4e244 | ||
|
|
7ed14c4d81 | ||
|
|
869c0006c5 | ||
|
|
4c458df666 | ||
|
|
ed684420a2 | ||
|
|
2076d56f0f | ||
|
|
280308bc05 | ||
|
|
1172765c58 | ||
|
|
58856612e2 | ||
|
|
ee0f224cbb | ||
|
|
e8a6df9a86 | ||
|
|
b02d27ca0a | ||
|
|
51f66eb06b | ||
|
|
0df8a9ba65 | ||
|
|
79daf92896 | ||
|
|
8a2be58ef3 | ||
|
|
ce7784090f | ||
|
|
dfb663a6e0 | ||
|
|
cfaf20926f | ||
|
|
305497e7cb | ||
|
|
52c2422be9 | ||
|
|
1afc750dae | ||
|
|
960c886496 | ||
|
|
5d579cdadc | ||
|
|
d5fc5745b4 | ||
|
|
a732025d5a | ||
|
|
030fd4ac57 | ||
|
|
9eeb4b6d19 | ||
|
|
36e5b31f73 | ||
|
|
26d4d34837 | ||
|
|
8c547a05fd | ||
|
|
f7311f906b | ||
|
|
6006ff4d22 | ||
|
|
034ed956a3 | ||
|
|
700c5795f5 | ||
|
|
8c3287d380 | ||
|
|
304e7b502c | ||
|
|
514e069113 | ||
|
|
7b571499a7 | ||
|
|
8b68d46bdf | ||
|
|
ab3b59e63d | ||
|
|
9910f8d732 | ||
|
|
99033d61c6 | ||
|
|
e9f3a55eb8 | ||
|
|
3cc9940924 | ||
|
|
bcdf94fd93 | ||
|
|
3c09ad7c02 | ||
|
|
7be0366b1f | ||
|
|
0575b0aa92 | ||
|
|
2e342806b6 | ||
|
|
cf9dc1c24f | ||
|
|
e58fb82463 | ||
|
|
fa900b166a | ||
|
|
2e43f8ed5b | ||
|
|
554493dea4 | ||
|
|
816b537787 | ||
|
|
e07b09186d | ||
|
|
8c7d075484 | ||
|
|
46743f3c1e | ||
|
|
1a1543f190 | ||
|
|
4aef12bf7e | ||
|
|
691c9aeb7d | ||
|
|
6285e45e34 | ||
|
|
f594d0ab83 | ||
|
|
25d1735c1d | ||
|
|
c4c174f560 | ||
|
|
175c4d781f | ||
|
|
87fde687eb | ||
|
|
65cf0f57aa | ||
|
|
0eb04ed0ea | ||
|
|
96983ddc70 | ||
|
|
b98e5efb83 | ||
|
|
ff2dae80f0 | ||
|
|
32c0232105 | ||
|
|
a05a3de0e1 |
1
.github/FUNDING.yml
vendored
Normal file
1
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1 @@
|
||||
custom: ["https://cloudreve.org/buy.php"]
|
||||
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. iOS]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Smartphone (please complete the following information):**
|
||||
- Device: [e.g. iPhone6]
|
||||
- OS: [e.g. iOS8.1]
|
||||
- Browser [e.g. stock browser, safari]
|
||||
- Version [e.g. 22]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
61
.github/workflows/build.yml
vendored
Normal file
61
.github/workflows/build.yml
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
name: Build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
uses: actions/setup-go@v1
|
||||
with:
|
||||
go-version: 1.13
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
clean: false
|
||||
submodules: 'recursive'
|
||||
- run: |
|
||||
git fetch --prune --unshallow --tags
|
||||
|
||||
- name: Get dependencies and build
|
||||
run: |
|
||||
go get github.com/rakyll/statik
|
||||
export PATH=$PATH:~/go/bin/
|
||||
statik -src=models -f
|
||||
sudo apt-get update
|
||||
sudo apt-get -y install gcc-mingw-w64-x86-64
|
||||
sudo apt-get -y install gcc-arm-linux-gnueabihf libc6-dev-armhf-cross
|
||||
sudo apt-get -y install gcc-aarch64-linux-gnu libc6-dev-arm64-cross
|
||||
chmod +x ./build.sh
|
||||
./build.sh -r b
|
||||
|
||||
- name: Upload binary files (windows_amd64)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_windows_amd64
|
||||
path: release/cloudreve*windows_amd64.*
|
||||
|
||||
- name: Upload binary files (linux_amd64)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_linux_amd64
|
||||
path: release/cloudreve*linux_amd64.*
|
||||
|
||||
- name: Upload binary files (linux_arm)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_linux_arm
|
||||
path: release/cloudreve*linux_arm.*
|
||||
|
||||
- name: Upload binary files (linux_arm64)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_linux_arm64
|
||||
path: release/cloudreve*linux_arm64.*
|
||||
47
.github/workflows/test.yml
vendored
Normal file
47
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
push:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-18.04
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
uses: actions/setup-go@v1
|
||||
with:
|
||||
go-version: 1.13
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: 'recursive'
|
||||
|
||||
- name: Get dependencies
|
||||
run: |
|
||||
go get github.com/rakyll/statik
|
||||
export PATH=$PATH:~/go/bin/
|
||||
statik -src=models -f
|
||||
|
||||
- name: Test
|
||||
run: go test -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
- name: Upload binary files (linux_arm)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_linux_arm
|
||||
path: release/cloudreve*linux_arm.*
|
||||
|
||||
- name: Upload binary files (linux_arm64)
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: cloudreve_linux_arm64
|
||||
path: release/cloudreve*linux_arm64.*
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
# Binaries for programs and plugins
|
||||
cloudreve
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
@@ -26,3 +27,4 @@ version.lock
|
||||
*.ini
|
||||
conf/conf.ini
|
||||
/statik/
|
||||
.vscode/
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
language: go
|
||||
go:
|
||||
- 1.13.x
|
||||
node_js: "12.16.3"
|
||||
git:
|
||||
depth: 1
|
||||
install:
|
||||
@@ -15,6 +16,7 @@ before_deploy:
|
||||
- sudo apt-get update
|
||||
- sudo apt-get -y install gcc-mingw-w64-x86-64
|
||||
- sudo apt-get -y install gcc-arm-linux-gnueabihf libc6-dev-armhf-cross
|
||||
- sudo apt-get -y install gcc-aarch64-linux-gnu libc6-dev-arm64-cross
|
||||
- chmod +x ./build.sh
|
||||
- ./build.sh -r b
|
||||
deploy:
|
||||
|
||||
72
Dockerfile
Normal file
72
Dockerfile
Normal file
@@ -0,0 +1,72 @@
|
||||
# build frontend
|
||||
FROM node:lts-buster AS fe-builder
|
||||
|
||||
COPY ./assets /assets
|
||||
|
||||
WORKDIR /assets
|
||||
|
||||
# If encountered problems like JavaScript heap out of memory, please uncomment the following options
|
||||
ENV NODE_OPTIONS --max_old_space_size=4096
|
||||
|
||||
# yarn repo connection is unstable, adjust the network timeout to 10 min.
|
||||
RUN set -ex \
|
||||
&& yarn install --network-timeout 600000 \
|
||||
&& yarn run build
|
||||
|
||||
# build backend
|
||||
FROM golang:1.15.1-alpine3.12 AS be-builder
|
||||
|
||||
ENV GO111MODULE on
|
||||
|
||||
COPY . /go/src/github.com/cloudreve/Cloudreve/v3
|
||||
COPY --from=fe-builder /assets/build/ /go/src/github.com/cloudreve/Cloudreve/v3/assets/build/
|
||||
|
||||
WORKDIR /go/src/github.com/cloudreve/Cloudreve/v3
|
||||
|
||||
RUN set -ex \
|
||||
&& apk upgrade \
|
||||
&& apk add gcc libc-dev git \
|
||||
&& export COMMIT_SHA=$(git rev-parse --short HEAD) \
|
||||
&& export VERSION=$(git describe --tags) \
|
||||
&& (cd && go get github.com/rakyll/statik) \
|
||||
&& statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico -f \
|
||||
&& go install -ldflags "-X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=${VERSION}' \
|
||||
-X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=${COMMIT_SHA}'\
|
||||
-w -s"
|
||||
|
||||
# build final image
|
||||
FROM alpine:3.12 AS dist
|
||||
|
||||
LABEL maintainer="mritd <mritd@linux.com>"
|
||||
|
||||
# we use the Asia/Shanghai timezone by default, you can be modified
|
||||
# by `docker build --build-arg=TZ=Other_Timezone ...`
|
||||
ARG TZ="Asia/Shanghai"
|
||||
|
||||
ENV TZ ${TZ}
|
||||
|
||||
COPY --from=be-builder /go/bin/Cloudreve /cloudreve/cloudreve
|
||||
COPY docker-bootstrap.sh /cloudreve/bootstrap.sh
|
||||
|
||||
RUN apk upgrade \
|
||||
&& apk add bash tzdata aria2 \
|
||||
&& ln -s /cloudreve/cloudreve /usr/bin/cloudreve \
|
||||
&& ln -sf /usr/share/zoneinfo/${TZ} /etc/localtime \
|
||||
&& echo ${TZ} > /etc/timezone \
|
||||
&& rm -rf /var/cache/apk/* \
|
||||
&& mkdir /etc/cloudreve \
|
||||
&& ln -s /etc/cloudreve/cloureve.db /cloudreve/cloudreve.db \
|
||||
&& ln -s /etc/cloudreve/conf.ini /cloudreve/conf.ini
|
||||
|
||||
# cloudreve use tcp 5212 port by default
|
||||
EXPOSE 5212/tcp
|
||||
|
||||
# cloudreve stores all files(including executable file) in the `/cloudreve`
|
||||
# directory by default; users should mount the configfile to the `/etc/cloudreve`
|
||||
# directory by themselves for persistence considerations, and the data storage
|
||||
# directory recommends using `/data` directory.
|
||||
VOLUME /etc/cloudreve
|
||||
|
||||
VOLUME /data
|
||||
|
||||
ENTRYPOINT ["sh", "/cloudreve/bootstrap.sh"]
|
||||
14
README.md
14
README.md
@@ -9,8 +9,8 @@
|
||||
<h4 align="center">支持多家云存储驱动的公有云文件系统.</h4>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://travis-ci.org/github/cloudreve/Cloudreve">
|
||||
<img src="https://img.shields.io/travis/cloudreve/Cloudreve?style=flat-square"
|
||||
<a href="https://travis-ci.com/github/cloudreve/Cloudreve/">
|
||||
<img src="https://img.shields.io/travis/com/cloudreve/Cloudreve?style=flat-square"
|
||||
alt="travis">
|
||||
</a>
|
||||
<a href="https://codecov.io/gh/cloudreve/Cloudreve"><img src="https://img.shields.io/codecov/c/github/cloudreve/Cloudreve?style=flat-square"></a>
|
||||
@@ -37,7 +37,7 @@
|
||||
|
||||
* :cloud: 支持本机、从机、七牛、阿里云 OSS、腾讯云 COS、又拍云、OneDrive (包括世纪互联版) 作为存储端
|
||||
* :outbox_tray: 上传/下载 支持客户端直传,支持下载限速
|
||||
* 💾 可对接 Aria2 离线下载
|
||||
* 💾 可对接 Aria2 离线下载,可使用多个从机机点分担下载任务
|
||||
* 📚 在线 压缩/解压缩、多文件打包下载
|
||||
* 💻 覆盖全部存储策略的 WebDAV 协议支持
|
||||
* :zap: 拖拽上传、目录上传、流式上传处理
|
||||
@@ -55,7 +55,7 @@
|
||||
|
||||
```shell
|
||||
# 解压程序包
|
||||
tar - czvf cloudreve_VERSION_OS_ARCH.tar.gz
|
||||
tar -zxvf cloudreve_VERSION_OS_ARCH.tar.gz
|
||||
|
||||
# 赋予执行权限
|
||||
chmod +x ./cloudreve
|
||||
@@ -80,7 +80,7 @@ git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git
|
||||
|
||||
```shell
|
||||
# 进入前端子模块
|
||||
cd asserts
|
||||
cd assets
|
||||
# 安装依赖
|
||||
yarn install
|
||||
# 开始构建
|
||||
@@ -108,7 +108,7 @@ export COMMIT_SHA=$(git rev-parse --short HEAD)
|
||||
export VERSION=$(git describe --tags)
|
||||
|
||||
# 开始编译
|
||||
go build -a -o cloudreve -ldflags " -X 'github.com/HFO4/cloudreve/pkg/conf.BackendVersion=$VERSION' -X 'github.com/HFO4/cloudreve/pkg/conf.LastCommit=$COMMIT_SHA'"
|
||||
go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
|
||||
```
|
||||
|
||||
你也可以使用项目根目录下的`build.sh`快速开始构建:
|
||||
@@ -132,4 +132,4 @@ GPL V3
|
||||
|
||||
---
|
||||
> GitHub [@HFO4](https://github.com/HFO4) ·
|
||||
> Twitter [@abslant00](https://twitter.com/abslant00)
|
||||
> Twitter [@abslant00](https://twitter.com/abslant00)
|
||||
|
||||
2
assets
2
assets
Submodule assets updated: f297f331f9...88c1133306
@@ -1,8 +1,13 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
// InitApplication 初始化应用常量
|
||||
@@ -10,7 +15,7 @@ func InitApplication() {
|
||||
fmt.Print(`
|
||||
___ _ _
|
||||
/ __\ | ___ _ _ __| |_ __ _____ _____
|
||||
/ / | |/ _ \| | | |/ _ | '__/ _ \ \ / / _ \
|
||||
/ / | |/ _ \| | | |/ _ | '__/ _ \ \ / / _ \
|
||||
/ /___| | (_) | |_| | (_| | | | __/\ V / __/
|
||||
\____/|_|\___/ \__,_|\__,_|_| \___| \_/ \___|
|
||||
|
||||
@@ -18,4 +23,36 @@ func InitApplication() {
|
||||
================================================
|
||||
|
||||
`)
|
||||
go CheckUpdate()
|
||||
}
|
||||
|
||||
type GitHubRelease struct {
|
||||
URL string `json:"html_url"`
|
||||
Name string `json:"name"`
|
||||
Tag string `json:"tag_name"`
|
||||
}
|
||||
|
||||
// CheckUpdate 检查更新
|
||||
func CheckUpdate() {
|
||||
client := request.NewClient()
|
||||
res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse()
|
||||
if err != nil {
|
||||
util.Log().Warning("更新检查失败, %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var list []GitHubRelease
|
||||
if err := json.Unmarshal([]byte(res), &list); err != nil {
|
||||
util.Log().Warning("更新检查失败, %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(list) > 0 {
|
||||
present, err1 := version.NewVersion(conf.BackendVersion)
|
||||
latest, err2 := version.NewVersion(list[0].Tag)
|
||||
if err1 == nil && err2 == nil && latest.GreaterThan(present) {
|
||||
util.Log().Info("有新的版本 [%s] 可用,下载:%s", list[0].Name, list[0].URL)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/aria2"
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/crontab"
|
||||
"github.com/HFO4/cloudreve/pkg/email"
|
||||
"github.com/HFO4/cloudreve/pkg/task"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/models/scripts"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -20,14 +23,92 @@ func Init(path string) {
|
||||
if !conf.SystemConfig.Debug {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
cache.Init()
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
model.Init()
|
||||
task.Init()
|
||||
aria2.Init(false)
|
||||
email.Init()
|
||||
crontab.Init()
|
||||
InitStatic()
|
||||
|
||||
dependencies := []struct {
|
||||
mode string
|
||||
factory func()
|
||||
}{
|
||||
{
|
||||
"both",
|
||||
func() {
|
||||
scripts.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"both",
|
||||
func() {
|
||||
cache.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"master",
|
||||
func() {
|
||||
model.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"both",
|
||||
func() {
|
||||
task.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"master",
|
||||
func() {
|
||||
cluster.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"master",
|
||||
func() {
|
||||
aria2.Init(false, cluster.Default, mq.GlobalMQ)
|
||||
},
|
||||
},
|
||||
{
|
||||
"master",
|
||||
func() {
|
||||
email.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"master",
|
||||
func() {
|
||||
crontab.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"master",
|
||||
func() {
|
||||
InitStatic()
|
||||
},
|
||||
},
|
||||
{
|
||||
"slave",
|
||||
func() {
|
||||
cluster.InitController()
|
||||
},
|
||||
},
|
||||
{
|
||||
"both",
|
||||
func() {
|
||||
auth.Init()
|
||||
},
|
||||
},
|
||||
}
|
||||
auth.Init()
|
||||
|
||||
for _, dependency := range dependencies {
|
||||
switch dependency.mode {
|
||||
case "master":
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
dependency.factory()
|
||||
}
|
||||
case "slave":
|
||||
if conf.SystemConfig.Mode == "slave" {
|
||||
dependency.factory()
|
||||
}
|
||||
default:
|
||||
dependency.factory()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
18
bootstrap/script.go
Normal file
18
bootstrap/script.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
func RunScript(name string) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if err := invoker.RunDBScript(name, ctx); err != nil {
|
||||
util.Log().Error("数据库脚本执行失败: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("数据库脚本 [%s] 执行完毕", name)
|
||||
}
|
||||
@@ -1,17 +1,30 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
_ "github.com/HFO4/cloudreve/statik"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"path"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
_ "github.com/cloudreve/Cloudreve/v3/statik"
|
||||
"github.com/gin-contrib/static"
|
||||
"github.com/rakyll/statik/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const StaticFolder = "statics"
|
||||
|
||||
type GinFS struct {
|
||||
FS http.FileSystem
|
||||
}
|
||||
|
||||
type staticVersion struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// StaticFS 内置静态文件资源
|
||||
var StaticFS static.ServeFileSystem
|
||||
|
||||
@@ -34,9 +47,44 @@ func (b *GinFS) Exists(prefix string, filepath string) bool {
|
||||
func InitStatic() {
|
||||
var err error
|
||||
|
||||
if util.Exists(util.RelativePath("statics")) {
|
||||
if util.Exists(util.RelativePath(StaticFolder)) {
|
||||
util.Log().Info("检测到 statics 目录存在,将使用此目录下的静态资源文件")
|
||||
StaticFS = static.LocalFile(util.RelativePath("statics"), false)
|
||||
|
||||
// 检查静态资源的版本
|
||||
f, err := StaticFS.Open("version.json")
|
||||
if err != nil {
|
||||
util.Log().Warning("静态资源版本标识文件不存在,请重新构建或删除 statics 目录")
|
||||
return
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法读取静态资源文件版本,请重新构建或删除 statics 目录")
|
||||
return
|
||||
}
|
||||
|
||||
var v staticVersion
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
util.Log().Warning("无法解析静态资源文件版本, %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
staticName := "cloudreve-frontend"
|
||||
if conf.IsPro == "true" {
|
||||
staticName += "-pro"
|
||||
}
|
||||
|
||||
if v.Name != staticName {
|
||||
util.Log().Warning("静态资源版本不匹配,请重新构建或删除 statics 目录")
|
||||
return
|
||||
}
|
||||
|
||||
if v.Version != conf.RequiredStaticVersion {
|
||||
util.Log().Warning("静态资源版本不匹配 [当前 %s, 需要: %s],请重新构建或删除 statics 目录", v.Version, conf.RequiredStaticVersion)
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
StaticFS = &GinFS{}
|
||||
StaticFS.(*GinFS).FS, err = fs.New()
|
||||
@@ -46,3 +94,65 @@ func InitStatic() {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Eject 抽离内置静态资源
|
||||
func Eject() {
|
||||
staticFS, err := fs.New()
|
||||
if err != nil {
|
||||
util.Log().Panic("无法初始化静态资源, %s", err)
|
||||
}
|
||||
|
||||
root, err := staticFS.Open("/")
|
||||
if err != nil {
|
||||
util.Log().Panic("根目录不存在, %s", err)
|
||||
}
|
||||
|
||||
var walk func(relPath string, object http.File)
|
||||
walk = func(relPath string, object http.File) {
|
||||
stat, err := object.Stat()
|
||||
if err != nil {
|
||||
util.Log().Error("无法获取[%s]的信息, %s, 跳过...", relPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !stat.IsDir() {
|
||||
// 写入文件
|
||||
out, err := util.CreatNestedFile(util.RelativePath(StaticFolder + relPath))
|
||||
defer out.Close()
|
||||
|
||||
if err != nil {
|
||||
util.Log().Error("无法创建文件[%s], %s, 跳过...", relPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("导出 [%s]...", relPath)
|
||||
if _, err := io.Copy(out, object); err != nil {
|
||||
util.Log().Error("无法写入文件[%s], %s, 跳过...", relPath, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 列出目录
|
||||
objects, err := object.Readdir(0)
|
||||
if err != nil {
|
||||
util.Log().Error("无法步入子目录[%s], %s, 跳过...", relPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 递归遍历子目录
|
||||
for _, newObject := range objects {
|
||||
newPath := path.Join(relPath, newObject.Name())
|
||||
newRoot, err := staticFS.Open(newPath)
|
||||
if err != nil {
|
||||
util.Log().Error("无法打开对象[%s], %s, 跳过...", newPath, err)
|
||||
continue
|
||||
}
|
||||
walk(newPath, newRoot)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
util.Log().Info("开始导出内置静态资源...")
|
||||
walk("/", root)
|
||||
util.Log().Info("内置静态资源导出完成")
|
||||
}
|
||||
|
||||
17
build.sh
17
build.sh
@@ -34,12 +34,12 @@ buildAssets () {
|
||||
fi
|
||||
|
||||
cd $REPO
|
||||
statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico -f
|
||||
statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico,*.ttf -f
|
||||
}
|
||||
|
||||
buildBinary () {
|
||||
cd $REPO
|
||||
go build -a -o cloudreve -ldflags " -X 'github.com/HFO4/cloudreve/pkg/conf.BackendVersion=$VERSION' -X 'github.com/HFO4/cloudreve/pkg/conf.LastCommit=$COMMIT_SHA'"
|
||||
go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
|
||||
}
|
||||
|
||||
_build() {
|
||||
@@ -55,8 +55,13 @@ _build() {
|
||||
export CC=$gcc
|
||||
export CGO_ENABLED=1
|
||||
|
||||
out="release/cloudreve_${VERSION}_${os}_${arch}"
|
||||
go build -a -o "${out}" -ldflags " -X 'github.com/HFO4/cloudreve/pkg/conf.BackendVersion=$VERSION' -X 'github.com/HFO4/cloudreve/pkg/conf.LastCommit=$COMMIT_SHA'"
|
||||
if [ -n "$VERSION" ]; then
|
||||
out="release/cloudreve_${VERSION}_${os}_${arch}"
|
||||
else
|
||||
out="release/cloudreve_${COMMIT_SHA}_${os}_${arch}"
|
||||
fi
|
||||
|
||||
go build -a -o "${out}" -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
|
||||
|
||||
if [ "$os" = "windows" ]; then
|
||||
mv $out release/cloudreve.exe
|
||||
@@ -72,7 +77,7 @@ _build() {
|
||||
release(){
|
||||
cd $REPO
|
||||
## List of architectures and OS to test coss compilation.
|
||||
SUPPORTED_OSARCH="linux/amd64/gcc linux/arm/arm-linux-gnueabihf-gcc windows/amd64/x86_64-w64-mingw32-gcc"
|
||||
SUPPORTED_OSARCH="linux/amd64/gcc linux/arm/arm-linux-gnueabihf-gcc windows/amd64/x86_64-w64-mingw32-gcc linux/arm64/aarch64-linux-gnu-gcc"
|
||||
|
||||
echo "Release builds for OS/Arch/CC: ${SUPPORTED_OSARCH}"
|
||||
for each_osarch in ${SUPPORTED_OSARCH}; do
|
||||
@@ -125,4 +130,4 @@ fi
|
||||
|
||||
if [ "$RELEASE" = "true" ]; then
|
||||
release
|
||||
fi
|
||||
fi
|
||||
|
||||
15
docker-bootstrap.sh
Normal file
15
docker-bootstrap.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/sh
|
||||
GREEN='\033[0;32m'
|
||||
RESET='\033[0m'
|
||||
if [ ! -f /etc/cloudreve/aria2c.conf ]; then
|
||||
echo -e "[${GREEN}aria2c${RESET}] aria2c config not found. Generating..."
|
||||
secret=$(tr -dc A-Za-z0-9 </dev/urandom | head -c 13)
|
||||
echo -e "[${GREEN}aria2c${RESET}] Generated port: 6800, secret: $secret"
|
||||
cat <<EOF > /etc/cloudreve/aria2c.conf
|
||||
enable-rpc=true
|
||||
rpc-listen-port=6800
|
||||
rpc-secret=$secret
|
||||
EOF
|
||||
fi
|
||||
aria2c --conf-path /etc/cloudreve/aria2c.conf -D
|
||||
cloudreve
|
||||
16
go.mod
16
go.mod
@@ -1,43 +1,45 @@
|
||||
module github.com/HFO4/cloudreve
|
||||
module github.com/cloudreve/Cloudreve/v3
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.3.3
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.0.5+incompatible
|
||||
github.com/aws/aws-sdk-go v1.31.5
|
||||
github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect
|
||||
github.com/duo-labs/webauthn v0.0.0-20191119193225-4bf9a0f776d4
|
||||
github.com/fatih/color v1.7.0
|
||||
github.com/gin-contrib/cors v1.3.0
|
||||
github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8
|
||||
github.com/gin-contrib/sessions v0.0.1
|
||||
github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2
|
||||
github.com/gin-gonic/gin v1.5.0
|
||||
github.com/go-ini/ini v1.50.0
|
||||
github.com/go-mail/mail v2.3.1+incompatible
|
||||
github.com/gofrs/uuid v3.2.0+incompatible
|
||||
github.com/gofrs/uuid v4.0.0+incompatible
|
||||
github.com/gomodule/redigo v2.0.0+incompatible
|
||||
github.com/google/go-querystring v1.0.0
|
||||
github.com/gorilla/websocket v1.4.1
|
||||
github.com/hashicorp/go-version v1.2.0
|
||||
github.com/jinzhu/gorm v1.9.11
|
||||
github.com/juju/ratelimit v1.0.1
|
||||
github.com/mattn/go-colorable v0.1.4 // indirect
|
||||
github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/pkg/errors v0.8.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/otp v1.2.0
|
||||
github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371
|
||||
github.com/qiniu/api.v7/v7 v7.4.0
|
||||
github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1
|
||||
github.com/rakyll/statik v0.1.7
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/smartwalle/alipay/v3 v3.0.13
|
||||
github.com/smartystreets/goconvey v1.6.4 // indirect
|
||||
github.com/speps/go-hashids v2.0.0+incompatible
|
||||
github.com/stretchr/testify v1.4.0
|
||||
github.com/stretchr/testify v1.5.1
|
||||
github.com/tencentcloud/tencentcloud-sdk-go v3.0.125+incompatible
|
||||
github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac
|
||||
github.com/upyun/go-sdk v2.1.0+incompatible
|
||||
golang.org/x/text v0.3.2
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
|
||||
golang.org/x/text v0.3.6
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
|
||||
gopkg.in/go-playground/validator.v9 v9.29.1
|
||||
gopkg.in/ini.v1 v1.51.0 // indirect
|
||||
|
||||
43
go.sum
43
go.sum
@@ -12,9 +12,12 @@ github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 h1:w1UutsfOrms1J05zt7I
|
||||
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412/go.mod h1:WPjqKcmVOxf0XSf3YxCJs6N6AOSrOx3obionmG7T0y0=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.0.0/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.0.5+incompatible h1:A3oZlWPD/Poa19FvNbw+Zu4yKAurDBTjlRDilYGBiS4=
|
||||
github.com/aliyun/aliyun-oss-go-sdk v2.0.5+incompatible/go.mod h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8=
|
||||
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
|
||||
github.com/aws/aws-sdk-go v1.31.5 h1:DFA7BzTydO4etqsTja+x7UfkOKQUv1xzEluLvNk81L0=
|
||||
github.com/aws/aws-sdk-go v1.31.5/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0=
|
||||
github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f h1:ZNv7On9kyUzm7fvRZumSyy/IUiSC7AzL0I1jKKtwooA=
|
||||
github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f/go.mod h1:AuiFmCCPBSrqvVMvuqFuk0qogytodnVFVSN5CeJB8Gc=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
@@ -48,15 +51,15 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/gin-contrib/cors v1.3.0 h1:PolezCc89peu+NgkIWt9OB01Kbzt6IP0J/JvkG6xxlg=
|
||||
github.com/gin-contrib/cors v1.3.0/go.mod h1:artPvLlhkF7oG06nK8v3U8TNz6IeX+w1uzCSEId5/Vc=
|
||||
github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8 h1:/DnKeA2+K83hkii3nqMJ5koknI+/qlojjxgcSyiAyJw=
|
||||
github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8/go.mod h1:M+xPw/lXk+uAU4iYVnwPZs0iIpR/KwSQSXcJabN+gPs=
|
||||
github.com/gin-contrib/sessions v0.0.1 h1:xr9V/u3ERQnkugKSY/u36cNnC4US4bHJpdxcB6eIZLk=
|
||||
github.com/gin-contrib/sessions v0.0.1/go.mod h1:iziXm/6pvTtf7og1uxT499sel4h3S9DfwsrhNZ+REXM=
|
||||
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3 h1:t8FVkw33L+wilf2QiWkw0UV77qRpcH/JHPKGpKa2E8g=
|
||||
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2 h1:xLG16iua01X7Gzms9045s2Y2niNpvSY/Zb1oBwgNYZY=
|
||||
github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2/go.mod h1:VhW/Ch/3FhimwZb8Oj+qJmdMmoB8r7lmJ5auRjm50oQ=
|
||||
github.com/gin-gonic/gin v1.4.0 h1:3tMoCCfM7ppqsR0ptz/wi1impNpT7/9wQtMZ8lr1mCQ=
|
||||
github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
|
||||
github.com/gin-gonic/gin v1.5.0 h1:fi+bqFAx/oLK54somfCtEZs9HeH1LHVoEPUgARpTqyc=
|
||||
github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
|
||||
@@ -71,11 +74,12 @@ github.com/go-playground/locales v0.12.1 h1:2FITxuFt/xuCNP1Acdhv62OzaCiviiE4kotf
|
||||
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
|
||||
github.com/go-playground/universal-translator v0.16.0 h1:X++omBR/4cE2MNg91AoC3rmGrCjJ8eAeUP/K/EKx4DM=
|
||||
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
|
||||
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
|
||||
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
|
||||
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE=
|
||||
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
|
||||
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||
@@ -84,7 +88,6 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
@@ -116,6 +119,8 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R
|
||||
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
|
||||
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/go-version v1.2.0 h1:3vNe/fWF5CBgRIguda1meWhsZHy3m8gCJ5wx+dIzX/E=
|
||||
github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/jinzhu/gorm v1.9.11 h1:gaHGvE+UnWGlbWG4Y3FUwY1EcZ5n6S9WtqBA/uySMLE=
|
||||
@@ -124,6 +129,8 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M=
|
||||
github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc=
|
||||
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
@@ -152,7 +159,6 @@ github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
|
||||
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
|
||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
@@ -180,8 +186,9 @@ github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
|
||||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
|
||||
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/otp v1.2.0 h1:/A3+Jn+cagqayeR3iHs/L62m5ue7710D35zl1zJ1kok=
|
||||
@@ -193,8 +200,6 @@ github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:
|
||||
github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
|
||||
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
|
||||
github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371 h1:8VWtyY2IwjEQZSNT4Kyyct9zv9hoegD5GQhFr+TMdCI=
|
||||
github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371/go.mod h1:9UFrQveqNm3ELF6HSvMtDR3KYpJ7Ib9s0WVmYhaUBlU=
|
||||
github.com/qiniu/api.v7/v7 v7.4.0 h1:9dZMVQifh31QGFLVaHls6akCaS2rlj3du8MnEFd7XjQ=
|
||||
github.com/qiniu/api.v7/v7 v7.4.0/go.mod h1:VE5oC5rkE1xul0u1S2N0b2Uxq9/6hZzhyqjgK25XDcM=
|
||||
github.com/quasoft/memstore v0.0.0-20180925164028-84a050167438 h1:jnz/4VenymvySjE+Ez511s0pqVzkUOmr1fwCVytNNWk=
|
||||
@@ -209,10 +214,6 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
|
||||
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
github.com/smartwalle/alipay/v3 v3.0.13 h1:f1Cdnxh6TfbaziLw0i/4h+f8tw9RJwG8y4xye7vTTgY=
|
||||
github.com/smartwalle/alipay/v3 v3.0.13/go.mod h1:cZUMCCnsux9YAxA0/f3PWUR+7wckWtE1BqxbVRtGij0=
|
||||
github.com/smartwalle/crypto4go v1.0.2 h1:9DUEOOsPhmp00438L4oBdcL8EZG1zumecft5bWj5phI=
|
||||
github.com/smartwalle/crypto4go v1.0.2/go.mod h1:LQ7vCZIb7BE5+MuMtJBuO8ORkkQ01m4DXDBWPzLbkMY=
|
||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
|
||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
||||
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
||||
@@ -225,8 +226,9 @@ github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/tencentcloud/tencentcloud-sdk-go v3.0.125+incompatible h1:dqpmYaez7VBT7PCRBcBxkzlDOiTk7Td8ATiia1b1GuE=
|
||||
github.com/tencentcloud/tencentcloud-sdk-go v3.0.125+incompatible/go.mod h1:0PfYow01SHPMhKY31xa+EFz2RStxIqj6JFAJS+IkCi4=
|
||||
github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac h1:PSBhZblOjdwH7SIVgcue+7OlnLHkM45KuScLZ+PiVbQ=
|
||||
@@ -242,12 +244,13 @@ go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc=
|
||||
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75 h1:TbGuee8sSq15Iguxu4deQ7+Bqq/d2rsQejGcEtADAMQ=
|
||||
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
|
||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
|
||||
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
@@ -262,8 +265,9 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 h1:Ao/3l156eZf2AW5wK8a7/smtodRU+gha3+BeqJ69lRk=
|
||||
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -280,7 +284,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e h1:D5TXcfTk7xF7hvieo4QErS3qqCB4teTffacDWr7CI+0=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -288,6 +291,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c h1:fqgJT0MGcGpPgpWU7VRdRjuArfcOvC4AoJmILihzhDg=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -299,7 +304,6 @@ golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3
|
||||
golang.org/x/tools v0.0.0-20190729092621-ff9f1409240a/go.mod h1:jcCCGcm9btYwXyDqrUWc6MKQKKGJCWEQ3AfLSRIbEuI=
|
||||
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
@@ -315,7 +319,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
|
||||
gopkg.in/go-playground/validator.v8 v8.18.2 h1:lFB4DoMU6B626w8ny76MV7VX6W2VHct2GVOI3xgiMrQ=
|
||||
gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y=
|
||||
gopkg.in/go-playground/validator.v9 v9.29.1 h1:SvGtYmN60a5CVKTOzMSyfzWDeZRxRuGvRQyEAKbw1xc=
|
||||
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
|
||||
|
||||
50
main.go
50
main.go
@@ -2,22 +2,62 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"github.com/HFO4/cloudreve/bootstrap"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/HFO4/cloudreve/routers"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/bootstrap"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v3/routers"
|
||||
)
|
||||
|
||||
var confPath string
|
||||
var (
|
||||
isEject bool
|
||||
confPath string
|
||||
scriptName string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "配置文件路径")
|
||||
flag.BoolVar(&isEject, "eject", false, "导出内置静态资源")
|
||||
flag.StringVar(&scriptName, "database-script", "", "运行内置数据库助手脚本")
|
||||
flag.Parse()
|
||||
bootstrap.Init(confPath)
|
||||
}
|
||||
|
||||
func main() {
|
||||
if isEject {
|
||||
// 开始导出内置静态资源文件
|
||||
bootstrap.Eject()
|
||||
return
|
||||
}
|
||||
|
||||
if scriptName != "" {
|
||||
// 开始运行助手数据库脚本
|
||||
bootstrap.RunScript(scriptName)
|
||||
return
|
||||
}
|
||||
|
||||
api := routers.InitRouter()
|
||||
|
||||
// 如果启用了SSL
|
||||
if conf.SSLConfig.CertPath != "" {
|
||||
go func() {
|
||||
util.Log().Info("开始监听 %s", conf.SSLConfig.Listen)
|
||||
if err := api.RunTLS(conf.SSLConfig.Listen,
|
||||
conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil {
|
||||
util.Log().Error("无法监听[%s],%s", conf.SSLConfig.Listen, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 如果启用了Unix
|
||||
if conf.UnixConfig.Listen != "" {
|
||||
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
|
||||
if err := api.RunUnix(conf.UnixConfig.Listen); err != nil {
|
||||
util.Log().Error("无法监听[%s],%s", conf.UnixConfig.Listen, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("开始监听 %s", conf.SystemConfig.Listen)
|
||||
if err := api.Run(conf.SystemConfig.Listen); err != nil {
|
||||
util.Log().Error("无法监听[%s],%s", conf.SystemConfig.Listen, err)
|
||||
|
||||
@@ -5,39 +5,39 @@ import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/driver/onedrive"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/driver/oss"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/driver/upyun"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/qiniu/api.v7/v7/auth/qbox"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SignRequired 验证请求签名
|
||||
func SignRequired() gin.HandlerFunc {
|
||||
func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var err error
|
||||
switch c.Request.Method {
|
||||
case "PUT", "POST":
|
||||
err = auth.CheckRequest(auth.General, c.Request)
|
||||
// TODO 生产环境去掉下一行
|
||||
//err = nil
|
||||
case "PUT", "POST", "PATCH":
|
||||
err = auth.CheckRequest(authInstance, c.Request)
|
||||
default:
|
||||
err = auth.CheckURI(auth.General, c.Request.URL)
|
||||
err = auth.CheckURI(authInstance, c.Request.URL)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err))
|
||||
c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func WebDAVAuth() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
expectedUser, err := model.GetUserByEmail(username)
|
||||
expectedUser, err := model.GetActiveUserByEmail(username)
|
||||
if err != nil {
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
@@ -174,7 +174,7 @@ func QiniuCallbackAuth() gin.HandlerFunc {
|
||||
// 验证key并查找用户
|
||||
resp, user := uploadCallbackCheck(c)
|
||||
if resp.Code != 0 {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: resp.Msg})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -184,12 +184,12 @@ func QiniuCallbackAuth() gin.HandlerFunc {
|
||||
ok, err := mac.VerifyCallback(c.Request)
|
||||
if err != nil {
|
||||
util.Log().Debug("无法验证回调请求,%s", err)
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: "无法验证回调请求"})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "无法验证回调请求"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: "回调签名无效"})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名无效"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -204,7 +204,7 @@ func OSSCallbackAuth() gin.HandlerFunc {
|
||||
// 验证key并查找用户
|
||||
resp, _ := uploadCallbackCheck(c)
|
||||
if resp.Code != 0 {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: resp.Msg})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -212,7 +212,7 @@ func OSSCallbackAuth() gin.HandlerFunc {
|
||||
err := oss.VerifyCallbackSignature(c.Request)
|
||||
if err != nil {
|
||||
util.Log().Debug("回调签名验证失败,%s", err)
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: "回调签名验证失败"})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名验证失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -227,7 +227,7 @@ func UpyunCallbackAuth() gin.HandlerFunc {
|
||||
// 验证key并查找用户
|
||||
resp, user := uploadCallbackCheck(c)
|
||||
if resp.Code != 0 {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: resp.Msg})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -236,7 +236,7 @@ func UpyunCallbackAuth() gin.HandlerFunc {
|
||||
body, err := ioutil.ReadAll(c.Request.Body)
|
||||
c.Request.Body.Close()
|
||||
if err != nil {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: err.Error()})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -252,7 +252,7 @@ func UpyunCallbackAuth() gin.HandlerFunc {
|
||||
// 计算正文MD5
|
||||
actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body))
|
||||
if actualContentMD5 != contentMD5 {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: "MD5不一致"})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5不一致"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -267,7 +267,7 @@ func UpyunCallbackAuth() gin.HandlerFunc {
|
||||
|
||||
// 对比签名
|
||||
if signature != actualSignature {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: "鉴权失败"})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "鉴权失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -283,7 +283,7 @@ func OneDriveCallbackAuth() gin.HandlerFunc {
|
||||
// 验证key并查找用户
|
||||
resp, _ := uploadCallbackCheck(c)
|
||||
if resp.Code != 0 {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: resp.Msg})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -302,7 +302,22 @@ func COSCallbackAuth() gin.HandlerFunc {
|
||||
// 验证key并查找用户
|
||||
resp, _ := uploadCallbackCheck(c)
|
||||
if resp.Code != 0 {
|
||||
c.JSON(401, serializer.QiniuCallbackFailed{Error: resp.Msg})
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// S3CallbackAuth Amazon S3回调签名验证
|
||||
func S3CallbackAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 验证key并查找用户
|
||||
resp, _ := uploadCallbackCheck(c)
|
||||
if resp.Code != 0 {
|
||||
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,21 +3,22 @@ package middleware
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/qiniu/api.v7/v7/auth/qbox"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/qiniu/api.v7/v7/auth/qbox"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
@@ -86,19 +87,30 @@ func TestAuthRequired(t *testing.T) {
|
||||
|
||||
func TestSignRequired(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
SignRequiredFunc := SignRequired()
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
SignRequiredFunc := SignRequired(authInstance)
|
||||
|
||||
// 鉴权失败
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("PUT", "/test", nil)
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
// Sign verify success
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("PUT", "/test", nil)
|
||||
c.Request = auth.SignRequest(authInstance, c.Request, 0)
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
func TestWebDAVAuth(t *testing.T) {
|
||||
@@ -654,7 +666,7 @@ func TestOneDriveCallbackAuth(t *testing.T) {
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[522]"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[657]"))
|
||||
mock.ExpectQuery("SELECT(.+)policies(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
@@ -699,7 +711,7 @@ func TestCOSCallbackAuth(t *testing.T) {
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[522]"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
|
||||
mock.ExpectQuery("SELECT(.+)policies(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
@@ -747,3 +759,46 @@ func TestIsAdmin(t *testing.T) {
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3CallbackAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := S3CallbackAuth()
|
||||
|
||||
// Callback Key 相关验证失败
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"key", "testUpyunBackRemote"},
|
||||
}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testUpyunBackRemote", nil)
|
||||
AuthFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
cache.Set(
|
||||
"callback_testCallBackUpyun",
|
||||
serializer.UploadSession{
|
||||
UID: 1,
|
||||
PolicyID: 512,
|
||||
VirtualPath: "/",
|
||||
},
|
||||
0,
|
||||
)
|
||||
cache.Deletes([]string{"1"}, "policy_")
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"key", "testCallBackUpyun"},
|
||||
}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
|
||||
AuthFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
|
||||
122
middleware/captcha.go
Normal file
122
middleware/captcha.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/recaptcha"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mojocn/base64Captcha"
|
||||
captcha "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha/v20190722"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type req struct {
|
||||
CaptchaCode string `json:"captchaCode"`
|
||||
Ticket string `json:"ticket"`
|
||||
Randstr string `json:"randstr"`
|
||||
}
|
||||
|
||||
// CaptchaRequired 验证请求签名
|
||||
func CaptchaRequired(configName string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 相关设定
|
||||
options := model.GetSettingByNames(configName,
|
||||
"captcha_type",
|
||||
"captcha_ReCaptchaSecret",
|
||||
"captcha_TCaptcha_SecretId",
|
||||
"captcha_TCaptcha_SecretKey",
|
||||
"captcha_TCaptcha_CaptchaAppId",
|
||||
"captcha_TCaptcha_AppSecretKey")
|
||||
// 检查验证码
|
||||
isCaptchaRequired := model.IsTrueVal(options[configName])
|
||||
|
||||
if isCaptchaRequired {
|
||||
var service req
|
||||
bodyCopy := new(bytes.Buffer)
|
||||
_, err := io.Copy(bodyCopy, c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.ParamErr("验证码错误", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
bodyData := bodyCopy.Bytes()
|
||||
err = json.Unmarshal(bodyData, &service)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.ParamErr("验证码错误", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = ioutil.NopCloser(bytes.NewReader(bodyData))
|
||||
switch options["captcha_type"] {
|
||||
case "normal":
|
||||
captchaID := util.GetSession(c, "captchaID")
|
||||
util.DeleteSession(c, "captchaID")
|
||||
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
|
||||
c.JSON(200, serializer.ParamErr("验证码错误", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
case "recaptcha":
|
||||
reCAPTCHA, err := recaptcha.NewReCAPTCHA(options["captcha_ReCaptchaSecret"], recaptcha.V2, 10*time.Second)
|
||||
if err != nil {
|
||||
util.Log().Warning("reCAPTCHA 验证错误, %s", err)
|
||||
c.Abort()
|
||||
break
|
||||
}
|
||||
|
||||
err = reCAPTCHA.Verify(service.CaptchaCode)
|
||||
if err != nil {
|
||||
util.Log().Warning("reCAPTCHA 验证错误, %s", err)
|
||||
c.JSON(200, serializer.ParamErr("验证失败,请刷新网页后再次验证", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
case "tcaptcha":
|
||||
credential := common.NewCredential(
|
||||
options["captcha_TCaptcha_SecretId"],
|
||||
options["captcha_TCaptcha_SecretKey"],
|
||||
)
|
||||
cpf := profile.NewClientProfile()
|
||||
cpf.HttpProfile.Endpoint = "captcha.tencentcloudapi.com"
|
||||
client, _ := captcha.NewClient(credential, "", cpf)
|
||||
request := captcha.NewDescribeCaptchaResultRequest()
|
||||
request.CaptchaType = common.Uint64Ptr(9)
|
||||
appid, _ := strconv.Atoi(options["captcha_TCaptcha_CaptchaAppId"])
|
||||
request.CaptchaAppId = common.Uint64Ptr(uint64(appid))
|
||||
request.AppSecretKey = common.StringPtr(options["captcha_TCaptcha_AppSecretKey"])
|
||||
request.Ticket = common.StringPtr(service.Ticket)
|
||||
request.Randstr = common.StringPtr(service.Randstr)
|
||||
request.UserIp = common.StringPtr(c.ClientIP())
|
||||
response, err := client.DescribeCaptchaResult(request)
|
||||
if err != nil {
|
||||
util.Log().Warning("TCaptcha 验证错误, %s", err)
|
||||
c.Abort()
|
||||
break
|
||||
}
|
||||
|
||||
if *response.Response.CaptchaCode != int64(1) {
|
||||
c.JSON(200, serializer.ParamErr("验证失败,请刷新网页后再次验证", nil))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
177
middleware/captcha_test.go
Normal file
177
middleware/captcha_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type errReader int
|
||||
|
||||
func (errReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("test error")
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_General(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 未启用验证码
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "0",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// body 无法读取
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", errReader(1))
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// body JSON 解析失败
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "1",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("123"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Normal(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 验证码错误
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "normal",
|
||||
"captcha_ReCaptchaSecret": "1",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
Session("233")(c)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Recaptcha(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 无法初始化reCaptcha实例
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "recaptcha",
|
||||
"captcha_ReCaptchaSecret": "",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 验证码错误
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "recaptcha",
|
||||
"captcha_ReCaptchaSecret": "233",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCaptchaRequired_Tcaptcha(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 验证出错
|
||||
{
|
||||
cache.SetSettings(map[string]string{
|
||||
"login_captcha": "1",
|
||||
"captcha_type": "tcaptcha",
|
||||
"captcha_ReCaptchaSecret": "",
|
||||
"captcha_TCaptcha_SecretId": "1",
|
||||
"captcha_TCaptcha_SecretKey": "1",
|
||||
"captcha_TCaptcha_CaptchaAppId": "1",
|
||||
"captcha_TCaptcha_AppSecretKey": "1",
|
||||
}, "setting_")
|
||||
TestFunc := CaptchaRequired("login_captcha")
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
r := bytes.NewReader([]byte("{}"))
|
||||
c.Request, _ = http.NewRequest("GET", "/", r)
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
61
middleware/cluster.go
Normal file
61
middleware/cluster.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
|
||||
func MasterMetadata() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set("MasterSiteID", c.GetHeader("X-Cr-Site-Id"))
|
||||
c.Set("MasterSiteURL", c.GetHeader("X-Cr-Site-Url"))
|
||||
c.Set("MasterVersion", c.GetHeader("X-Cr-Cloudreve-Version"))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
|
||||
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||
// 获取对应主机节点的从机Aria2实例
|
||||
caller, err := clusterController.GetAria2Instance(siteID.(string))
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("MasterAria2Instance", caller)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.ParamErr("未知的主机节点ID", nil))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
nodeID, err := strconv.ParseUint(c.GetHeader("X-Cr-Node-Id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
slaveNode := nodePool.GetNodeByID(uint(nodeID))
|
||||
if slaveNode == nil {
|
||||
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
SignRequired(slaveNode.MasterAuthInstance())(c)
|
||||
|
||||
}
|
||||
}
|
||||
120
middleware/cluster_test.go
Normal file
120
middleware/cluster_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMasterMetadata(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
masterMetaDataFunc := MasterMetadata()
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
c.Request.Header = map[string][]string{
|
||||
"X-Cr-Site-Id": {"expectedSiteID"},
|
||||
"X-Cr-Site-Url": {"expectedSiteURL"},
|
||||
"X-Cr-Cloudreve-Version": {"expectedMasterVersion"},
|
||||
}
|
||||
masterMetaDataFunc(c)
|
||||
siteID, _ := c.Get("MasterSiteID")
|
||||
siteURL, _ := c.Get("MasterSiteURL")
|
||||
siteVersion, _ := c.Get("MasterVersion")
|
||||
|
||||
a.Equal("expectedSiteID", siteID.(string))
|
||||
a.Equal("expectedSiteURL", siteURL.(string))
|
||||
a.Equal("expectedMasterVersion", siteVersion.(string))
|
||||
}
|
||||
|
||||
func TestSlaveRPCSignRequired(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
np := &cluster.NodePool{}
|
||||
np.Init()
|
||||
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// id parse failed
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Request.Header.Set("X-Cr-Node-Id", "unknown")
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// node id not exist
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Request.Header.Set("X-Cr-Node-Id", "38")
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte("")}
|
||||
np.Add(&model.Node{Model: gorm.Model{
|
||||
ID: 38,
|
||||
}})
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
c.Request.Header.Set("X-Cr-Node-Id", "38")
|
||||
c.Request = auth.SignRequest(authInstance, c.Request, 0)
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseSlaveAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
// MasterSiteID not set
|
||||
{
|
||||
testController := &controllermock.SlaveControllerMock{}
|
||||
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
useSlaveAria2InstanceFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// Cannot get aria2 instances
|
||||
{
|
||||
testController := &controllermock.SlaveControllerMock{}
|
||||
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Set("MasterSiteID", "expectedSiteID")
|
||||
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error"))
|
||||
useSlaveAria2InstanceFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
testController.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Success
|
||||
{
|
||||
testController := &controllermock.SlaveControllerMock{}
|
||||
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Set("MasterSiteID", "expectedSiteID")
|
||||
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil)
|
||||
useSlaveAria2InstanceFunc(c)
|
||||
a.False(c.IsAborted())
|
||||
res, _ := c.Get("MasterAria2Instance")
|
||||
a.NotNil(res)
|
||||
testController.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
69
middleware/frontend.go
Normal file
69
middleware/frontend.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/bootstrap"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FrontendFileHandler 前端静态文件处理
|
||||
func FrontendFileHandler() gin.HandlerFunc {
|
||||
ignoreFunc := func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
|
||||
if bootstrap.StaticFS == nil {
|
||||
return ignoreFunc
|
||||
}
|
||||
|
||||
// 读取index.html
|
||||
file, err := bootstrap.StaticFS.Open("/index.html")
|
||||
if err != nil {
|
||||
util.Log().Warning("静态文件[index.html]不存在,可能会影响首页展示")
|
||||
return ignoreFunc
|
||||
}
|
||||
|
||||
fileContentBytes, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
util.Log().Warning("静态文件[index.html]读取失败,可能会影响首页展示")
|
||||
return ignoreFunc
|
||||
}
|
||||
fileContent := string(fileContentBytes)
|
||||
|
||||
fileServer := http.FileServer(bootstrap.StaticFS)
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// API 跳过
|
||||
if strings.HasPrefix(path, "/api") || strings.HasPrefix(path, "/custom") || strings.HasPrefix(path, "/dav") || path == "/manifest.json" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 不存在的路径和index.html均返回index.html
|
||||
if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) {
|
||||
// 读取、替换站点设置
|
||||
options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript",
|
||||
"pwa_small_icon")
|
||||
finalHTML := util.Replace(map[string]string{
|
||||
"{siteName}": options["siteName"],
|
||||
"{siteDes}": options["siteDes"],
|
||||
"{siteScript}": options["siteScript"],
|
||||
"{pwa_small_icon}": options["pwa_small_icon"],
|
||||
}, fileContent)
|
||||
|
||||
c.Header("Content-Type", "text/html")
|
||||
c.String(200, finalHTML)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 存在的静态文件
|
||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
144
middleware/frontend_test.go
Normal file
144
middleware/frontend_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/bootstrap"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type StaticMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (m StaticMock) Open(name string) (http.File, error) {
|
||||
args := m.Called(name)
|
||||
return args.Get(0).(http.File), args.Error(1)
|
||||
}
|
||||
|
||||
func (m StaticMock) Exists(prefix string, filepath string) bool {
|
||||
args := m.Called(prefix, filepath)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func TestFrontendFileHandler(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 静态资源未加载
|
||||
{
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// index.html 不存在
|
||||
{
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(&os.File{}, errors.New("error"))
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// index.html 读取失败
|
||||
{
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功且命中
|
||||
{
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
defer file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/", nil)
|
||||
|
||||
cache.Set("setting_siteName", "cloudreve", 0)
|
||||
cache.Set("setting_siteKeywords", "cloudreve", 0)
|
||||
cache.Set("setting_siteScript", "cloudreve", 0)
|
||||
cache.Set("setting_pwa_small_icon", "cloudreve", 0)
|
||||
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// 成功且命中静态文件
|
||||
{
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
defer file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
testStatic.On("Exists", "/", "/2").
|
||||
Return(true)
|
||||
testStatic.On("Open", "/2").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", "/2", nil)
|
||||
|
||||
TestFunc(c)
|
||||
asserts.True(c.IsAborted())
|
||||
testStatic.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// API 相关跳过
|
||||
{
|
||||
for _, reqPath := range []string{"/api/user", "/manifest.json", "/dav/path"} {
|
||||
file, _ := util.CreatNestedFile("tests/index.html")
|
||||
defer file.Close()
|
||||
testStatic := &StaticMock{}
|
||||
bootstrap.StaticFS = testStatic
|
||||
testStatic.On("Open", "/index.html").
|
||||
Return(file, nil)
|
||||
TestFunc := FrontendFileHandler()
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("GET", reqPath, nil)
|
||||
|
||||
TestFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMockHelper(t *testing.T) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/hashid"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/hashid"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHashID(t *testing.T) {
|
||||
@@ -62,7 +63,6 @@ func TestIsFunctionEnabled(t *testing.T) {
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
|
||||
TestFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
// 开启
|
||||
@@ -72,7 +72,6 @@ func TestIsFunctionEnabled(t *testing.T) {
|
||||
c.Params = []gin.Param{}
|
||||
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
|
||||
TestFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/memstore"
|
||||
"github.com/gin-contrib/sessions/redis"
|
||||
@@ -17,7 +18,7 @@ func Session(secret string) gin.HandlerFunc {
|
||||
// Redis设置不为空,且非测试模式时使用Redis
|
||||
if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode {
|
||||
var err error
|
||||
Store, err = redis.NewStoreWithDB(10, "tcp", conf.RedisConfig.Server, conf.RedisConfig.Password, conf.RedisConfig.DB, []byte(secret))
|
||||
Store, err = redis.NewStoreWithDB(10, conf.RedisConfig.Network, conf.RedisConfig.Server, conf.RedisConfig.Password, conf.RedisConfig.DB, []byte(secret))
|
||||
if err != nil {
|
||||
util.Log().Panic("无法连接到 Redis:%s", err)
|
||||
}
|
||||
@@ -32,3 +33,24 @@ func Session(secret string) gin.HandlerFunc {
|
||||
Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"})
|
||||
return sessions.Sessions("cloudreve-session", Store)
|
||||
}
|
||||
|
||||
// CSRFInit 初始化CSRF标记
|
||||
func CSRFInit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
util.SetSession(c, map[string]interface{}{"CSRF": true})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFCheck 检查CSRF标记
|
||||
func CSRFCheck() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if check, ok := util.GetSession(c, "CSRF").(bool); ok && check {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "来源非法", nil))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSession(t *testing.T) {
|
||||
@@ -21,6 +25,7 @@ func TestSession(t *testing.T) {
|
||||
asserts.Panics(func() {
|
||||
Session("2333")
|
||||
})
|
||||
conf.RedisConfig.Server = ""
|
||||
}
|
||||
|
||||
}
|
||||
@@ -28,3 +33,41 @@ func TestSession(t *testing.T) {
|
||||
func emptyFunc() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {}
|
||||
}
|
||||
|
||||
func TestCSRFInit(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
sessionFunc := Session("233")
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFInit()(c)
|
||||
asserts.True(util.GetSession(c, "CSRF").(bool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFCheck(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
sessionFunc := Session("233")
|
||||
|
||||
// 通过检查
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFInit()(c)
|
||||
CSRFCheck()(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
// 未通过检查
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
sessionFunc(c)
|
||||
CSRFCheck()(c)
|
||||
asserts.True(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,10 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShareAvailable(t *testing.T) {
|
||||
|
||||
@@ -2,8 +2,9 @@ package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/HFO4/cloudreve/pkg/aria2/rpc"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
@@ -17,12 +18,13 @@ type Download struct {
|
||||
DownloadedSize uint64 // 文件大小
|
||||
GID string `gorm:"size:32,index:gid"` // 任务ID
|
||||
Speed int // 下载速度
|
||||
Parent string `gorm:"type:text"` // 存储目录
|
||||
Attrs string `gorm:"type:text"` // 任务状态属性
|
||||
Error string `gorm:"type:text"` // 错误描述
|
||||
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
|
||||
Parent string `gorm:"type:text"` // 存储目录
|
||||
Attrs string `gorm:"size:4294967295"` // 任务状态属性
|
||||
Error string `gorm:"type:text"` // 错误描述
|
||||
Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径
|
||||
UserID uint // 发起者UID
|
||||
TaskID uint // 对应的转存任务ID
|
||||
NodeID uint // 处理任务的节点ID
|
||||
|
||||
// 关联模型
|
||||
User *User `gorm:"PRELOAD:false,association_autoupdate:false"`
|
||||
@@ -108,3 +110,18 @@ func (task *Download) GetOwner() *User {
|
||||
}
|
||||
return task.User
|
||||
}
|
||||
|
||||
// Delete 删除离线下载记录
|
||||
func (download *Download) Delete() error {
|
||||
return DB.Model(download).Delete(download).Error
|
||||
}
|
||||
|
||||
// GetNodeID 返回任务所属节点ID
|
||||
func (task *Download) GetNodeID() uint {
|
||||
// 兼容3.4版本之前生成的下载记录
|
||||
if task.NodeID == 0 {
|
||||
return 1
|
||||
}
|
||||
|
||||
return task.NodeID
|
||||
}
|
||||
|
||||
@@ -161,3 +161,30 @@ func TestGetDownloadsByStatusAndUser(t *testing.T) {
|
||||
asserts.Len(res, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownload_Delete(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
share := Download{}
|
||||
|
||||
{
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
err := share.Delete()
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDownload_GetNodeID(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
record := Download{}
|
||||
|
||||
// compatible with 3.4
|
||||
a.EqualValues(1, record.GetNodeID())
|
||||
|
||||
record.NodeID = 5
|
||||
a.EqualValues(5, record.GetNodeID())
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package model
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// File 文件
|
||||
@@ -89,7 +90,7 @@ func GetFilesByKeywords(uid uint, keywords ...interface{}) ([]File, error) {
|
||||
|
||||
// 生成查询条件
|
||||
for i := 0; i < len(keywords); i++ {
|
||||
conditions += "LOWER(name) like ?"
|
||||
conditions += "name like ?"
|
||||
if i != len(keywords)-1 {
|
||||
conditions += " or "
|
||||
}
|
||||
@@ -185,17 +186,17 @@ func (file *File) Rename(new string) error {
|
||||
|
||||
// UpdatePicInfo 更新文件的图像信息
|
||||
func (file *File) UpdatePicInfo(value string) error {
|
||||
return DB.Model(&file).Update("pic_info", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("pic_info", value).Error
|
||||
}
|
||||
|
||||
// UpdateSize 更新文件的大小信息
|
||||
func (file *File) UpdateSize(value uint64) error {
|
||||
return DB.Model(&file).Update("size", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value).Error
|
||||
}
|
||||
|
||||
// UpdateSourceName 更新文件的源文件名
|
||||
func (file *File) UpdateSourceName(value string) error {
|
||||
return DB.Model(&file).Update("source_name", value).Error
|
||||
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("source_name", value).Error
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -2,10 +2,11 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// Folder 目录
|
||||
@@ -43,6 +44,26 @@ func (folder *Folder) GetChild(name string) (*Folder, error) {
|
||||
return &resFolder, err
|
||||
}
|
||||
|
||||
// TraceRoot 向上递归查找父目录
|
||||
func (folder *Folder) TraceRoot() error {
|
||||
if folder.ParentID == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parentFolder Folder
|
||||
err := DB.
|
||||
Where("id = ? AND owner_id = ?", folder.ParentID, folder.OwnerID).
|
||||
First(&parentFolder).Error
|
||||
|
||||
if err == nil {
|
||||
err := parentFolder.TraceRoot()
|
||||
folder.Position = path.Join(parentFolder.Position, parentFolder.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetChildFolder 查找子目录
|
||||
func (folder *Folder) GetChildFolder() ([]Folder, error) {
|
||||
var folders []Folder
|
||||
|
||||
@@ -2,12 +2,13 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFolder_Create(t *testing.T) {
|
||||
@@ -529,3 +530,37 @@ func TestFolder_FileInfoInterface(t *testing.T) {
|
||||
asserts.True(folder.IsDir())
|
||||
asserts.Equal("/test", folder.GetPosition())
|
||||
}
|
||||
|
||||
func TestTraceRoot(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
var parentId uint
|
||||
parentId = 5
|
||||
folder := Folder{
|
||||
ParentID: &parentId,
|
||||
OwnerID: 1,
|
||||
Name: "test_name",
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "/"))
|
||||
asserts.NoError(folder.TraceRoot())
|
||||
asserts.Equal("/parent", folder.Position)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 出现错误
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
|
||||
WillReturnError(errors.New("error"))
|
||||
asserts.Error(folder.TraceRoot())
|
||||
asserts.Equal("parent", folder.Position)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
_ "github.com/jinzhu/gorm/dialects/mssql"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/postgres"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
)
|
||||
|
||||
@@ -28,18 +31,35 @@ func Init() {
|
||||
// 测试模式下,使用内存数据库
|
||||
db, err = gorm.Open("sqlite3", ":memory:")
|
||||
} else {
|
||||
if conf.DatabaseConfig.Type == "UNSET" {
|
||||
// 未指定数据库时,使用SQLite
|
||||
db, err = gorm.Open("sqlite3", util.RelativePath("cloudreve.db"))
|
||||
} else {
|
||||
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local",
|
||||
switch conf.DatabaseConfig.Type {
|
||||
case "UNSET", "sqlite", "sqlite3":
|
||||
// 未指定数据库或者明确指定为 sqlite 时,使用 SQLite3 数据库
|
||||
db, err = gorm.Open("sqlite3", util.RelativePath(conf.DatabaseConfig.DBFile))
|
||||
case "postgres":
|
||||
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
|
||||
conf.DatabaseConfig.Host,
|
||||
conf.DatabaseConfig.User,
|
||||
conf.DatabaseConfig.Password,
|
||||
conf.DatabaseConfig.Name,
|
||||
conf.DatabaseConfig.Port))
|
||||
case "mysql", "mssql":
|
||||
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("%s:%s@(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
||||
conf.DatabaseConfig.User,
|
||||
conf.DatabaseConfig.Password,
|
||||
conf.DatabaseConfig.Host,
|
||||
conf.DatabaseConfig.Name))
|
||||
conf.DatabaseConfig.Port,
|
||||
conf.DatabaseConfig.Name,
|
||||
conf.DatabaseConfig.Charset))
|
||||
default:
|
||||
util.Log().Panic("不支持数据库类型: %s", conf.DatabaseConfig.Type)
|
||||
}
|
||||
}
|
||||
|
||||
//db.SetLogger(util.Log())
|
||||
if err != nil {
|
||||
util.Log().Panic("连接数据库不成功, %s", err)
|
||||
}
|
||||
|
||||
// 处理表前缀
|
||||
gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string {
|
||||
return conf.DatabaseConfig.TablePrefix + defaultTableName
|
||||
@@ -52,11 +72,6 @@ func Init() {
|
||||
db.LogMode(false)
|
||||
}
|
||||
|
||||
//db.SetLogger(util.Log())
|
||||
if err != nil {
|
||||
util.Log().Panic("连接数据库不成功, %s", err)
|
||||
}
|
||||
|
||||
//设置连接池
|
||||
//空闲
|
||||
db.DB().SetMaxIdleConns(50)
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/fatih/color"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/jinzhu/gorm"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 是否需要迁移
|
||||
@@ -34,8 +40,9 @@ func migration() {
|
||||
if conf.DatabaseConfig.Type == "mysql" {
|
||||
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
|
||||
}
|
||||
|
||||
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
|
||||
&Task{}, &Download{}, &Tag{}, &Webdav{})
|
||||
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{})
|
||||
|
||||
// 创建初始存储策略
|
||||
addDefaultPolicy()
|
||||
@@ -46,9 +53,15 @@ func migration() {
|
||||
// 创建初始管理员账户
|
||||
addDefaultUser()
|
||||
|
||||
// 创建初始节点
|
||||
addDefaultNode()
|
||||
|
||||
// 向设置数据表添加初始设置
|
||||
addDefaultSettings()
|
||||
|
||||
// 执行数据库升级脚本
|
||||
execUpgradeScripts()
|
||||
|
||||
util.Log().Info("数据库初始化结束")
|
||||
|
||||
}
|
||||
@@ -73,14 +86,19 @@ func addDefaultPolicy() {
|
||||
}
|
||||
|
||||
func addDefaultSettings() {
|
||||
siteID, _ := uuid.NewV4()
|
||||
|
||||
defaultSettings := []Setting{
|
||||
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
|
||||
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
|
||||
{Name: "siteICPId", Value: ``, Type: "basic"},
|
||||
{Name: "register_enabled", Value: `1`, Type: "register"},
|
||||
{Name: "default_group", Value: `2`, Type: "register"},
|
||||
{Name: "siteKeywords", Value: `网盘,网盘`, Type: "basic"},
|
||||
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
|
||||
{Name: "siteTitle", Value: `平步云端`, Type: "basic"},
|
||||
{Name: "siteScript", Value: ``, Type: "basic"},
|
||||
{Name: "siteID", Value: siteID.String(), Type: "basic"},
|
||||
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
|
||||
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
|
||||
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},
|
||||
@@ -89,6 +107,7 @@ func addDefaultSettings() {
|
||||
{Name: "replyTo", Value: `abslant@126.com`, Type: "mail"},
|
||||
{Name: "smtpUser", Value: `no-reply@acg.blue`, Type: "mail"},
|
||||
{Name: "smtpPass", Value: ``, Type: "mail"},
|
||||
{Name: "smtpEncryption", Value: `0`, Type: "mail"},
|
||||
{Name: "maxEditSize", Value: `4194304`, Type: "file_edit"},
|
||||
{Name: "archive_timeout", Value: `60`, Type: "timeout"},
|
||||
{Name: "download_timeout", Value: `60`, Type: "timeout"},
|
||||
@@ -97,11 +116,16 @@ func addDefaultSettings() {
|
||||
{Name: "upload_credential_timeout", Value: `1800`, Type: "timeout"},
|
||||
{Name: "upload_session_timeout", Value: `86400`, Type: "timeout"},
|
||||
{Name: "slave_api_timeout", Value: `60`, Type: "timeout"},
|
||||
{Name: "slave_node_retry", Value: `3`, Type: "slave"},
|
||||
{Name: "slave_ping_interval", Value: `60`, Type: "slave"},
|
||||
{Name: "slave_recover_interval", Value: `120`, Type: "slave"},
|
||||
{Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"},
|
||||
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
|
||||
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
|
||||
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
|
||||
{Name: "aria2_call_timeout", Value: `5`, Type: "timeout"},
|
||||
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
|
||||
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
|
||||
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
|
||||
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
|
||||
{Name: "login_captcha", Value: `0`, Type: "login"},
|
||||
{Name: "reg_captcha", Value: `0`, Type: "login"},
|
||||
@@ -126,11 +150,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
|
||||
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
|
||||
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
|
||||
{Name: "aria2_token", Value: ``, Type: "aria2"},
|
||||
{Name: "aria2_rpcurl", Value: ``, Type: "aria2"},
|
||||
{Name: "aria2_temp_path", Value: ``, Type: "aria2"},
|
||||
{Name: "aria2_options", Value: `{}`, Type: "aria2"},
|
||||
{Name: "aria2_interval", Value: `60`, Type: "aria2"},
|
||||
{Name: "max_worker_num", Value: `10`, Type: "task"},
|
||||
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
|
||||
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
|
||||
@@ -144,6 +163,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
{Name: "share_view_method", Value: "list", Type: "view"},
|
||||
{Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"},
|
||||
{Name: "authn_enabled", Value: "0", Type: "authn"},
|
||||
{Name: "captcha_type", Value: "normal", Type: "captcha"},
|
||||
{Name: "captcha_height", Value: "60", Type: "captcha"},
|
||||
{Name: "captcha_width", Value: "240", Type: "captcha"},
|
||||
{Name: "captcha_mode", Value: "3", Type: "captcha"},
|
||||
@@ -155,6 +175,12 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
{Name: "captcha_IsShowSlimeLine", Value: "1", Type: "captcha"},
|
||||
{Name: "captcha_IsShowSineLine", Value: "0", Type: "captcha"},
|
||||
{Name: "captcha_CaptchaLen", Value: "6", Type: "captcha"},
|
||||
{Name: "captcha_ReCaptchaKey", Value: "defaultKey", Type: "captcha"},
|
||||
{Name: "captcha_ReCaptchaSecret", Value: "defaultSecret", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_CaptchaAppId", Value: "", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_AppSecretKey", Value: "", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_SecretId", Value: "", Type: "captcha"},
|
||||
{Name: "captcha_TCaptcha_SecretKey", Value: "", Type: "captcha"},
|
||||
{Name: "thumb_width", Value: "400", Type: "thumb"},
|
||||
{Name: "thumb_height", Value: "300", Type: "thumb"},
|
||||
{Name: "pwa_small_icon", Value: "/static/img/favicon.ico", Type: "pwa"},
|
||||
@@ -163,6 +189,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
{Name: "pwa_display", Value: "standalone", Type: "pwa"},
|
||||
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
|
||||
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
|
||||
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
|
||||
}
|
||||
|
||||
for _, value := range defaultSettings {
|
||||
@@ -253,3 +280,36 @@ func addDefaultUser() {
|
||||
util.Log().Info("初始管理员密码:" + c.Sprint(password))
|
||||
}
|
||||
}
|
||||
|
||||
func addDefaultNode() {
|
||||
_, err := GetNodeByID(1)
|
||||
|
||||
if gorm.IsRecordNotFoundError(err) {
|
||||
defaultAdminGroup := Node{
|
||||
Name: "主机(本机)",
|
||||
Status: NodeActive,
|
||||
Type: MasterNodeType,
|
||||
Aria2OptionsSerialized: Aria2Option{
|
||||
Interval: 10,
|
||||
Timeout: 10,
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
|
||||
util.Log().Panic("无法创建初始节点记录, %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func execUpgradeScripts() {
|
||||
s := invoker.ListPrefix("UpgradeTo")
|
||||
versions := make([]*version.Version, len(s))
|
||||
for i, raw := range s {
|
||||
v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo"))
|
||||
versions[i] = v
|
||||
}
|
||||
sort.Sort(version.Collection(versions))
|
||||
|
||||
for i := 0; i < len(versions); i++ {
|
||||
invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMigration(t *testing.T) {
|
||||
|
||||
91
models/node.go
Normal file
91
models/node.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// Node 从机节点信息模型
|
||||
type Node struct {
|
||||
gorm.Model
|
||||
Status NodeStatus // 节点状态
|
||||
Name string // 节点别名
|
||||
Type ModelType // 节点状态
|
||||
Server string // 服务器地址
|
||||
SlaveKey string `gorm:"type:text"` // 主->从 通信密钥
|
||||
MasterKey string `gorm:"type:text"` // 从->主 通信密钥
|
||||
Aria2Enabled bool // 是否支持用作离线下载节点
|
||||
Aria2Options string `gorm:"type:text"` // 离线下载配置
|
||||
Rank int // 负载均衡权重
|
||||
|
||||
// 数据库忽略字段
|
||||
Aria2OptionsSerialized Aria2Option `gorm:"-"`
|
||||
}
|
||||
|
||||
// Aria2Option 非公有的Aria2配置属性
|
||||
type Aria2Option struct {
|
||||
// RPC 服务器地址
|
||||
Server string `json:"server,omitempty"`
|
||||
// RPC 密钥
|
||||
Token string `json:"token,omitempty"`
|
||||
// 临时下载目录
|
||||
TempPath string `json:"temp_path,omitempty"`
|
||||
// 附加下载配置
|
||||
Options string `json:"options,omitempty"`
|
||||
// 下载监控间隔
|
||||
Interval int `json:"interval,omitempty"`
|
||||
// RPC API 请求超时
|
||||
Timeout int `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
type NodeStatus int
|
||||
type ModelType int
|
||||
|
||||
const (
|
||||
NodeActive NodeStatus = iota
|
||||
NodeSuspend
|
||||
)
|
||||
|
||||
const (
|
||||
SlaveNodeType ModelType = iota
|
||||
MasterNodeType
|
||||
)
|
||||
|
||||
// GetNodeByID 用ID获取节点
|
||||
func GetNodeByID(ID interface{}) (Node, error) {
|
||||
var node Node
|
||||
result := DB.First(&node, ID)
|
||||
return node, result.Error
|
||||
}
|
||||
|
||||
// GetNodesByStatus 根据给定状态获取节点
|
||||
func GetNodesByStatus(status ...NodeStatus) ([]Node, error) {
|
||||
var nodes []Node
|
||||
result := DB.Where("status in (?)", status).Find(&nodes)
|
||||
return nodes, result.Error
|
||||
}
|
||||
|
||||
// AfterFind 找到节点后的钩子
|
||||
func (node *Node) AfterFind() (err error) {
|
||||
// 解析离线下载设置到 Aria2OptionsSerialized
|
||||
if node.Aria2Options != "" {
|
||||
err = json.Unmarshal([]byte(node.Aria2Options), &node.Aria2OptionsSerialized)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// BeforeSave Save策略前的钩子
|
||||
func (node *Node) BeforeSave() (err error) {
|
||||
optionsValue, err := json.Marshal(&node.Aria2OptionsSerialized)
|
||||
node.Aria2Options = string(optionsValue)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetStatus 设置节点启用状态
|
||||
func (node *Node) SetStatus(status NodeStatus) error {
|
||||
node.Status = status
|
||||
return DB.Model(node).Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
}).Error
|
||||
}
|
||||
64
models/node_test.go
Normal file
64
models/node_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetNodeByID(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
res, err := GetNodeByID(1)
|
||||
a.NoError(err)
|
||||
a.EqualValues(1, res.ID)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestGetNodesByStatus(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive))
|
||||
res, err := GetNodesByStatus(NodeActive)
|
||||
a.NoError(err)
|
||||
a.Len(res, 1)
|
||||
a.EqualValues(NodeActive, res[0].Status)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestNode_AfterFind(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
node := &Node{}
|
||||
|
||||
// No aria2 options
|
||||
{
|
||||
a.NoError(node.AfterFind())
|
||||
}
|
||||
|
||||
// with aria2 options
|
||||
{
|
||||
node.Aria2Options = `{"timeout":1}`
|
||||
a.NoError(node.AfterFind())
|
||||
a.Equal(1, node.Aria2OptionsSerialized.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNode_BeforeSave(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
node := &Node{}
|
||||
|
||||
node.Aria2OptionsSerialized.Timeout = 1
|
||||
a.NoError(node.BeforeSave())
|
||||
a.Contains(node.Aria2Options, "1")
|
||||
}
|
||||
|
||||
func TestNode_SetStatus(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
node := &Node{}
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
a.NoError(node.SetStatus(NodeActive))
|
||||
a.Equal(NodeActive, node.Status)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
@@ -3,13 +3,17 @@ package model
|
||||
import (
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// Policy 存储策略
|
||||
@@ -33,6 +37,7 @@ type Policy struct {
|
||||
|
||||
// 数据库忽略字段
|
||||
OptionsSerialized PolicyOption `gorm:"-"`
|
||||
MasterID string `gorm:"-"`
|
||||
}
|
||||
|
||||
// PolicyOption 非公有的存储策略属性
|
||||
@@ -43,9 +48,27 @@ type PolicyOption struct {
|
||||
FileType []string `json:"file_type"`
|
||||
// MimeType
|
||||
MimeType string `json:"mimetype"`
|
||||
|
||||
// OdRedirect Onedrive重定向地址
|
||||
// OdRedirect Onedrive 重定向地址
|
||||
OdRedirect string `json:"od_redirect,omitempty"`
|
||||
// OdProxy Onedrive 反代地址
|
||||
OdProxy string `json:"od_proxy,omitempty"`
|
||||
// OdDriver OneDrive 驱动器定位符
|
||||
OdDriver string `json:"od_driver,omitempty"`
|
||||
// Region 区域代码
|
||||
Region string `json:"region,omitempty"`
|
||||
// ServerSideEndpoint 服务端请求使用的 Endpoint,为空时使用 Policy.Server 字段
|
||||
ServerSideEndpoint string `json:"server_side_endpoint,omitempty"`
|
||||
}
|
||||
|
||||
var thumbSuffix = map[string][]string{
|
||||
"local": {},
|
||||
"qiniu": {".psd", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
|
||||
"oss": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
|
||||
"cos": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
|
||||
"upyun": {".svg", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
|
||||
"s3": {},
|
||||
"remote": {},
|
||||
"onedrive": {"*"},
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -178,6 +201,17 @@ func (policy *Policy) IsDirectlyPreview() bool {
|
||||
return policy.Type == "local"
|
||||
}
|
||||
|
||||
// IsThumbExist 给定文件名,返回此存储策略下是否可能存在缩略图
|
||||
func (policy *Policy) IsThumbExist(name string) bool {
|
||||
if list, ok := thumbSuffix[policy.Type]; ok {
|
||||
if len(list) == 1 && list[0] == "*" {
|
||||
return true
|
||||
}
|
||||
return util.ContainsString(list, strings.ToLower(filepath.Ext(name)))
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTransitUpload 返回此策略上传给定size文件时是否需要服务端中转
|
||||
func (policy *Policy) IsTransitUpload(size uint64) bool {
|
||||
if policy.Type == "local" {
|
||||
@@ -199,9 +233,9 @@ func (policy *Policy) IsThumbGenerateNeeded() bool {
|
||||
return policy.Type == "local"
|
||||
}
|
||||
|
||||
// IsMockThumbNeeded 返回此策略是否需要在上传后默认当图像文件
|
||||
func (policy *Policy) IsMockThumbNeeded() bool {
|
||||
return policy.Type == "onedrive"
|
||||
// CanStructureBeListed 返回存储策略是否能被前台列物理目录
|
||||
func (policy *Policy) CanStructureBeListed() bool {
|
||||
return policy.Type != "local" && policy.Type != "remote"
|
||||
}
|
||||
|
||||
// GetUploadURL 获取文件上传服务API地址
|
||||
@@ -211,7 +245,7 @@ func (policy *Policy) GetUploadURL() string {
|
||||
return policy.Server
|
||||
}
|
||||
|
||||
var controller *url.URL
|
||||
controller, _ := url.Parse("")
|
||||
switch policy.Type {
|
||||
case "local", "onedrive":
|
||||
return "/api/v3/file/upload"
|
||||
@@ -223,20 +257,34 @@ func (policy *Policy) GetUploadURL() string {
|
||||
return policy.Server
|
||||
case "upyun":
|
||||
return "https://v0.api.upyun.com/" + policy.BucketName
|
||||
default:
|
||||
controller, _ = url.Parse("")
|
||||
case "s3":
|
||||
if policy.Server == "" {
|
||||
return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/", policy.BucketName,
|
||||
policy.OptionsSerialized.Region)
|
||||
}
|
||||
|
||||
if !strings.Contains(policy.Server, policy.BucketName) {
|
||||
controller, _ = url.Parse("/" + policy.BucketName)
|
||||
}
|
||||
}
|
||||
|
||||
return server.ResolveReference(controller).String()
|
||||
}
|
||||
|
||||
// UpdateAccessKey 更新 AccessKey
|
||||
func (policy *Policy) UpdateAccessKey(key string) error {
|
||||
policy.AccessKey = key
|
||||
// SaveAndClearCache 更新并清理缓存
|
||||
func (policy *Policy) SaveAndClearCache() error {
|
||||
err := DB.Save(policy).Error
|
||||
policy.ClearCache()
|
||||
return err
|
||||
}
|
||||
|
||||
// SaveAndClearCache 更新并清理缓存
|
||||
func (policy *Policy) UpdateAccessKeyAndClearCache(s string) error {
|
||||
err := DB.Model(policy).UpdateColumn("access_key", s).Error
|
||||
policy.ClearCache()
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearCache 清空policy缓存
|
||||
func (policy *Policy) ClearCache() {
|
||||
cache.Deletes([]string{strconv.FormatUint(uint64(policy.ID), 10)}, "policy_")
|
||||
|
||||
@@ -2,13 +2,14 @@ package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetPolicyByID(t *testing.T) {
|
||||
@@ -209,6 +210,28 @@ func TestPolicy_GetUploadURL(t *testing.T) {
|
||||
asserts.Equal("http://127.0.0.1", policy.GetUploadURL())
|
||||
}
|
||||
|
||||
// S3 未填写自动生成
|
||||
{
|
||||
policy := Policy{
|
||||
Type: "s3",
|
||||
Server: "",
|
||||
BucketName: "bucket",
|
||||
OptionsSerialized: PolicyOption{Region: "us-east"},
|
||||
}
|
||||
asserts.Equal("https://bucket.s3.us-east.amazonaws.com/", policy.GetUploadURL())
|
||||
}
|
||||
|
||||
// s3 自己指定
|
||||
{
|
||||
policy := Policy{
|
||||
Type: "s3",
|
||||
Server: "https://s3.us-east.amazonaws.com/",
|
||||
BucketName: "bucket",
|
||||
OptionsSerialized: PolicyOption{Region: "us-east"},
|
||||
}
|
||||
asserts.Equal("https://s3.us-east.amazonaws.com/bucket", policy.GetUploadURL())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPolicy_IsPathGenerateNeeded(t *testing.T) {
|
||||
@@ -234,7 +257,8 @@ func TestPolicy_UpdateAccessKey(t *testing.T) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
err := policy.UpdateAccessKey("123")
|
||||
policy.AccessKey = "123"
|
||||
err := policy.SaveAndClearCache()
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
}
|
||||
@@ -242,13 +266,79 @@ func TestPolicy_UpdateAccessKey(t *testing.T) {
|
||||
func TestPolicy_Props(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
policy := Policy{Type: "onedrive"}
|
||||
asserts.True(policy.IsMockThumbNeeded())
|
||||
asserts.False(policy.IsThumbGenerateNeeded())
|
||||
asserts.True(policy.IsPathGenerateNeeded())
|
||||
asserts.True(policy.IsTransitUpload(4))
|
||||
asserts.False(policy.IsTransitUpload(5 * 1024 * 1024))
|
||||
asserts.True(policy.CanStructureBeListed())
|
||||
policy.Type = "local"
|
||||
asserts.False(policy.IsMockThumbNeeded())
|
||||
asserts.True(policy.IsThumbGenerateNeeded())
|
||||
asserts.True(policy.IsPathGenerateNeeded())
|
||||
asserts.False(policy.CanStructureBeListed())
|
||||
}
|
||||
|
||||
func TestPolicy_IsThumbExist(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
expect bool
|
||||
policy string
|
||||
}{
|
||||
{
|
||||
"1.png",
|
||||
false,
|
||||
"unknown",
|
||||
},
|
||||
{
|
||||
"1.png",
|
||||
false,
|
||||
"local",
|
||||
},
|
||||
{
|
||||
"1.png",
|
||||
true,
|
||||
"cos",
|
||||
},
|
||||
{
|
||||
"1",
|
||||
false,
|
||||
"cos",
|
||||
},
|
||||
{
|
||||
"1.txt.png",
|
||||
true,
|
||||
"cos",
|
||||
},
|
||||
{
|
||||
"1.png.txt",
|
||||
false,
|
||||
"cos",
|
||||
},
|
||||
{
|
||||
"1",
|
||||
true,
|
||||
"onedrive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
policy := Policy{Type: testCase.policy}
|
||||
asserts.Equal(testCase.expect, policy.IsThumbExist(testCase.name))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicy_UpdateAccessKeyAndClearCache(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
cache.Set("policy_1331", Policy{}, 3600)
|
||||
p := &Policy{}
|
||||
p.ID = 1331
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WithArgs("ak", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.NoError(p.UpdateAccessKeyAndClearCache("ak"))
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
_, ok := cache.Get("policy_1331")
|
||||
a.False(ok)
|
||||
}
|
||||
|
||||
9
models/scripts/init.go
Normal file
9
models/scripts/init.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package scripts
|
||||
|
||||
import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||
|
||||
func Init() {
|
||||
invoker.Register("ResetAdminPassword", ResetAdminPassword(0))
|
||||
invoker.Register("CalibrateUserStorage", UserStorageCalibration(0))
|
||||
invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0))
|
||||
}
|
||||
38
models/scripts/invoker/invoker.go
Normal file
38
models/scripts/invoker/invoker.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package invoker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DBScript interface {
|
||||
Run(ctx context.Context)
|
||||
}
|
||||
|
||||
var availableScripts = make(map[string]DBScript)
|
||||
|
||||
func RunDBScript(name string, ctx context.Context) error {
|
||||
if script, ok := availableScripts[name]; ok {
|
||||
util.Log().Info("开始执行数据库脚本 [%s]", name)
|
||||
script.Run(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
|
||||
}
|
||||
|
||||
func Register(name string, script DBScript) {
|
||||
availableScripts[name] = script
|
||||
}
|
||||
|
||||
func ListPrefix(prefix string) []string {
|
||||
var scripts []string
|
||||
for name := range availableScripts {
|
||||
if strings.HasPrefix(name, prefix) {
|
||||
scripts = append(scripts, name)
|
||||
}
|
||||
}
|
||||
return scripts
|
||||
}
|
||||
39
models/scripts/invoker/invoker_test.go
Normal file
39
models/scripts/invoker/invoker_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package invoker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestScript int
|
||||
|
||||
func (script TestScript) Run(ctx context.Context) {
|
||||
|
||||
}
|
||||
|
||||
func TestRunDBScript(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
Register("test", TestScript(0))
|
||||
|
||||
// 不存在
|
||||
{
|
||||
asserts.Error(RunDBScript("else", context.Background()))
|
||||
}
|
||||
|
||||
// 存在
|
||||
{
|
||||
asserts.NoError(RunDBScript("test", context.Background()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListPrefix(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
Register("U1", TestScript(0))
|
||||
Register("U2", TestScript(0))
|
||||
Register("U3", TestScript(0))
|
||||
Register("P1", TestScript(0))
|
||||
|
||||
res := ListPrefix("U")
|
||||
asserts.Len(res, 3)
|
||||
}
|
||||
31
models/scripts/reset.go
Normal file
31
models/scripts/reset.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
type ResetAdminPassword int
|
||||
|
||||
// Run 运行脚本从社区版升级至 Pro 版
|
||||
func (script ResetAdminPassword) Run(ctx context.Context) {
|
||||
// 查找用户
|
||||
user, err := model.GetUserByID(1)
|
||||
if err != nil {
|
||||
util.Log().Panic("初始管理员用户不存在, %s", err)
|
||||
}
|
||||
|
||||
// 生成密码
|
||||
password := util.RandStringRunes(8)
|
||||
|
||||
// 更改为新密码
|
||||
user.SetPassword(password)
|
||||
if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil {
|
||||
util.Log().Panic("密码更改失败, %s", err)
|
||||
}
|
||||
|
||||
c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold)
|
||||
util.Log().Info("初始管理员密码已更改为:" + c.Sprint(password))
|
||||
}
|
||||
50
models/scripts/reset_test.go
Normal file
50
models/scripts/reset_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResetAdminPassword_Run(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
script := ResetAdminPassword(0)
|
||||
|
||||
// 初始用户不存在
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}))
|
||||
asserts.Panics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 密码更新失败
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
asserts.Panics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
asserts.NotPanics(func() {
|
||||
script.Run(context.Background())
|
||||
})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
33
models/scripts/storage.go
Normal file
33
models/scripts/storage.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
type UserStorageCalibration int
|
||||
|
||||
type storageResult struct {
|
||||
Total uint64
|
||||
}
|
||||
|
||||
// Run 运行脚本校准所有用户容量
|
||||
func (script UserStorageCalibration) Run(ctx context.Context) {
|
||||
// 列出所有用户
|
||||
var res []model.User
|
||||
model.DB.Model(&model.User{}).Find(&res)
|
||||
|
||||
// 逐个检查容量
|
||||
for _, user := range res {
|
||||
// 计算正确的容量
|
||||
var total storageResult
|
||||
model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total)
|
||||
// 更新用户的容量
|
||||
if user.Storage != total.Total {
|
||||
util.Log().Info("将用户 [%s] 的容量由 %d 校准为 %d", user.Email,
|
||||
user.Storage, total.Total)
|
||||
model.DB.Model(&user).Update("storage", total.Total)
|
||||
}
|
||||
}
|
||||
}
|
||||
58
models/scripts/storage_test.go
Normal file
58
models/scripts/storage_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
var mockDB *gorm.DB
|
||||
|
||||
// TestMain 初始化数据库Mock
|
||||
func TestMain(m *testing.M) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
model.DB, _ = gorm.Open("mysql", db)
|
||||
mockDB = model.DB
|
||||
defer db.Close()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestUserStorageCalibration_Run(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
script := UserStorageCalibration(0)
|
||||
|
||||
// 容量异常
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectQuery("SELECT(.+)files(.+)").
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(11))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
script.Run(context.Background())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// 容量正常
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
|
||||
mock.ExpectQuery("SELECT(.+)files(.+)").
|
||||
WithArgs(1).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10))
|
||||
script.Run(context.Background())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
43
models/scripts/upgrade.go
Normal file
43
models/scripts/upgrade.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type UpgradeTo340 int
|
||||
|
||||
// Run upgrade from older version to 3.4.0
|
||||
func (script UpgradeTo340) Run(ctx context.Context) {
|
||||
// 取回老版本 aria2 设定
|
||||
old := model.GetSettingByType([]string{"aria2"})
|
||||
if len(old) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 写入到新版本的节点设定
|
||||
n, err := model.GetNodeByID(1)
|
||||
if err != nil {
|
||||
util.Log().Error("找不到主机节点, %s", err)
|
||||
}
|
||||
|
||||
n.Aria2Enabled = old["aria2_rpcurl"] != ""
|
||||
n.Aria2OptionsSerialized.Options = old["aria2_options"]
|
||||
n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"]
|
||||
|
||||
interval, err := strconv.Atoi(old["aria2_interval"])
|
||||
if err != nil {
|
||||
interval = 10
|
||||
}
|
||||
n.Aria2OptionsSerialized.Interval = interval
|
||||
n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"]
|
||||
n.Aria2OptionsSerialized.Token = old["aria2_token"]
|
||||
if err := model.DB.Save(&n).Error; err != nil {
|
||||
util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err)
|
||||
} else {
|
||||
model.DB.Where("type = ?", "aria2").Delete(model.Setting{})
|
||||
util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式")
|
||||
}
|
||||
}
|
||||
66
models/scripts/upgrade_test.go
Normal file
66
models/scripts/upgrade_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpgradeTo340_Run(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
script := UpgradeTo340(0)
|
||||
|
||||
// skip
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// node not found
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1"))
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
|
||||
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
|
||||
AddRow("aria2_interval", "expected_aria2_interval").
|
||||
AddRow("aria2_temp_path", "expected_aria2_temp_path").
|
||||
AddRow("aria2_token", "expected_aria2_token").
|
||||
AddRow("aria2_options", "{}"))
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// failed
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
|
||||
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
|
||||
AddRow("aria2_interval", "expected_aria2_interval").
|
||||
AddRow("aria2_temp_path", "expected_aria2_temp_path").
|
||||
AddRow("aria2_token", "expected_aria2_token").
|
||||
AddRow("aria2_options", "{}"))
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// Setting 系统设置模型
|
||||
@@ -29,12 +30,16 @@ func GetSettingByName(name string) string {
|
||||
if optionValue, ok := cache.Get(cacheKey); ok {
|
||||
return optionValue.(string)
|
||||
}
|
||||
|
||||
// 尝试数据库中查找
|
||||
result := DB.Where("name = ?", name).First(&setting)
|
||||
if result.Error == nil {
|
||||
_ = cache.Set(cacheKey, setting.Value, -1)
|
||||
return setting.Value
|
||||
if DB != nil {
|
||||
result := DB.Where("name = ?", name).First(&setting)
|
||||
if result.Error == nil {
|
||||
_ = cache.Set(cacheKey, setting.Value, -1)
|
||||
return setting.Value
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@ package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
@@ -3,13 +3,14 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/hashid"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
// Share 分享模型
|
||||
|
||||
@@ -2,15 +2,16 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShare_Create(t *testing.T) {
|
||||
@@ -187,18 +188,6 @@ func TestShare_CanBeDownloadBy(t *testing.T) {
|
||||
asserts.Error(share.CanBeDownloadBy(user))
|
||||
}
|
||||
|
||||
// 未登录,需要积分
|
||||
{
|
||||
user := &User{
|
||||
Group: Group{
|
||||
OptionsSerialized: GroupOption{
|
||||
ShareDownload: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
asserts.Error(share.CanBeDownloadBy(user))
|
||||
}
|
||||
|
||||
// 成功
|
||||
{
|
||||
user := &User{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
|
||||
@@ -55,9 +55,9 @@ func TestGetTagsByUID(t *testing.T) {
|
||||
|
||||
func TestGetTagsByID(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
res, err := GetTasksByID(1)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("tag"))
|
||||
res, err := GetTagsByUID(1)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
asserts.EqualValues(1, res.ID)
|
||||
asserts.EqualValues("tag", res[0].Name)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
|
||||
@@ -91,3 +91,14 @@ func TestListTasks(t *testing.T) {
|
||||
asserts.EqualValues(5, total)
|
||||
asserts.Len(res, 1)
|
||||
}
|
||||
|
||||
func TestGetTasksByStatus(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)").
|
||||
WithArgs(1, 2).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
res := GetTasksByStatus(1, 2)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.Len(res, 1)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +35,7 @@ type User struct {
|
||||
Storage uint64
|
||||
TwoFactor string
|
||||
Avatar string
|
||||
Options string `json:"-",gorm:"type:text"`
|
||||
Options string `json:"-" gorm:"type:text"`
|
||||
Authn string `gorm:"type:text"`
|
||||
|
||||
// 关联模型
|
||||
@@ -137,6 +139,13 @@ func GetActiveUserByOpenID(openid string) (User, error) {
|
||||
|
||||
// GetUserByEmail 用Email获取用户
|
||||
func GetUserByEmail(email string) (User, error) {
|
||||
var user User
|
||||
result := DB.Set("gorm:auto_preload", true).Where("email = ?", email).First(&user)
|
||||
return user, result.Error
|
||||
}
|
||||
|
||||
// GetActiveUserByEmail 用Email获取可登录用户
|
||||
func GetActiveUserByEmail(email string) (User, error) {
|
||||
var user User
|
||||
result := DB.Set("gorm:auto_preload", true).Where("status = ? and email = ?", Active, email).First(&user)
|
||||
return user, result.Error
|
||||
@@ -191,11 +200,24 @@ func (user *User) CheckPassword(password string) (bool, error) {
|
||||
|
||||
// 根据存储密码拆分为 Salt 和 Digest
|
||||
passwordStore := strings.Split(user.Password, ":")
|
||||
if len(passwordStore) != 2 {
|
||||
if len(passwordStore) != 2 && len(passwordStore) != 3 {
|
||||
return false, errors.New("Unknown password type")
|
||||
}
|
||||
|
||||
// todo 兼容V2/V1密码
|
||||
// 兼容V2密码,升级后存储格式为: md5:$HASH:$SALT
|
||||
if len(passwordStore) == 3 {
|
||||
if passwordStore[0] != "md5" {
|
||||
return false, errors.New("Unknown password type")
|
||||
}
|
||||
hash := md5.New()
|
||||
_, err := hash.Write([]byte(passwordStore[2] + password))
|
||||
bs := hex.EncodeToString(hash.Sum(nil))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bs == passwordStore[1], nil
|
||||
}
|
||||
|
||||
//计算 Salt 和密码组合的SHA1摘要
|
||||
hash := sha1.New()
|
||||
_, err := hash.Write([]byte(password + passwordStore[0]))
|
||||
|
||||
@@ -5,9 +5,10 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/HFO4/cloudreve/pkg/hashid"
|
||||
"github.com/duo-labs/webauthn/webauthn"
|
||||
"net/url"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/duo-labs/webauthn/webauthn"
|
||||
)
|
||||
|
||||
/*
|
||||
|
||||
@@ -2,12 +2,13 @@ package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetUserByID(t *testing.T) {
|
||||
@@ -144,6 +145,27 @@ func TestUser_CheckPassword(t *testing.T) {
|
||||
asserts.Error(err)
|
||||
asserts.False(res)
|
||||
|
||||
// 未知密码类型
|
||||
user = User{}
|
||||
user.Password = "1:2:3"
|
||||
res, err = user.CheckPassword("Cause Sega does what nintendon't")
|
||||
asserts.Error(err)
|
||||
asserts.False(res)
|
||||
|
||||
// V2密码,错误
|
||||
user = User{}
|
||||
user.Password = "md5:2:3"
|
||||
res, err = user.CheckPassword("Cause Sega does what nintendon't")
|
||||
asserts.NoError(err)
|
||||
asserts.False(res)
|
||||
|
||||
// V2密码,正确
|
||||
user = User{}
|
||||
user.Password = "md5:d8446059f8846a2c111a7f53515665fb:sdshare"
|
||||
res, err = user.CheckPassword("admin")
|
||||
asserts.NoError(err)
|
||||
asserts.True(res)
|
||||
|
||||
}
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
@@ -155,10 +177,10 @@ func TestNewUser(t *testing.T) {
|
||||
|
||||
func TestUser_AfterFind(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
cache.Deletes([]string{"1"}, "policy_")
|
||||
cache.Deletes([]string{"0"}, "policy_")
|
||||
|
||||
policyRows := sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow(1, "默认存储策略")
|
||||
AddRow(144, "默认存储策略")
|
||||
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
|
||||
|
||||
newUser := NewUser()
|
||||
@@ -218,11 +240,6 @@ func TestUser_GetRemainingCapacity(t *testing.T) {
|
||||
newUser.Group.MaxStorage = 100
|
||||
newUser.Storage = 200
|
||||
asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
|
||||
|
||||
cache.Set("pack_size_0", uint64(10), 0)
|
||||
newUser.Group.MaxStorage = 100
|
||||
newUser.Storage = 101
|
||||
asserts.Equal(uint64(9), newUser.GetRemainingCapacity())
|
||||
}
|
||||
|
||||
func TestUser_DeductionCapacity(t *testing.T) {
|
||||
@@ -258,10 +275,6 @@ func TestUser_DeductionCapacity(t *testing.T) {
|
||||
asserts.Equal(false, newUser.IncreaseStorage(1))
|
||||
asserts.Equal(uint64(100), newUser.Storage)
|
||||
|
||||
cache.Set("pack_size_1", uint64(1), 0)
|
||||
asserts.Equal(true, newUser.IncreaseStorage(1))
|
||||
asserts.Equal(uint64(101), newUser.Storage)
|
||||
|
||||
asserts.True(newUser.IncreaseStorage(0))
|
||||
}
|
||||
|
||||
@@ -330,10 +343,20 @@ func TestUser_IncreaseStorageWithoutCheck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
func TestGetActiveUserByEmail(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
|
||||
_, err := GetActiveUserByEmail("abslant@foxmail.com")
|
||||
|
||||
asserts.Error(err)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
|
||||
_, err := GetUserByEmail("abslant@foxmail.com")
|
||||
|
||||
asserts.Error(err)
|
||||
|
||||
@@ -55,6 +55,6 @@ func TestDeleteWebDAVAccountByID(t *testing.T) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
asserts.NoError(DeleteTagByID(1, 1))
|
||||
DeleteWebDAVAccountByID(1, 1)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
@@ -1,168 +1,67 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/aria2/rpc"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
)
|
||||
|
||||
// Instance 默认使用的Aria2处理实例
|
||||
var Instance Aria2 = &DummyAria2{}
|
||||
var Instance common.Aria2 = &common.DummyAria2{}
|
||||
|
||||
// LB 获取 Aria2 节点的负载均衡器
|
||||
var LB balancer.Balancer
|
||||
|
||||
// Lock Instance的读写锁
|
||||
var Lock sync.RWMutex
|
||||
|
||||
// EventNotifier 任务状态更新通知处理器
|
||||
var EventNotifier = &Notifier{}
|
||||
|
||||
// Aria2 离线下载处理接口
|
||||
type Aria2 interface {
|
||||
// CreateTask 创建新的任务
|
||||
CreateTask(task *model.Download, options map[string]interface{}) error
|
||||
// 返回状态信息
|
||||
Status(task *model.Download) (rpc.StatusInfo, error)
|
||||
// 取消任务
|
||||
Cancel(task *model.Download) error
|
||||
// 选择要下载的文件
|
||||
Select(task *model.Download, files []int) error
|
||||
}
|
||||
|
||||
const (
|
||||
// URLTask 从URL添加的任务
|
||||
URLTask = iota
|
||||
// TorrentTask 种子任务
|
||||
TorrentTask
|
||||
)
|
||||
|
||||
const (
|
||||
// Ready 准备就绪
|
||||
Ready = iota
|
||||
// Downloading 下载中
|
||||
Downloading
|
||||
// Paused 暂停中
|
||||
Paused
|
||||
// Error 出错
|
||||
Error
|
||||
// Complete 完成
|
||||
Complete
|
||||
// Canceled 取消/停止
|
||||
Canceled
|
||||
// Unknown 未知状态
|
||||
Unknown
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotEnabled 功能未开启错误
|
||||
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
|
||||
// ErrUserNotFound 未找到下载任务创建者
|
||||
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
|
||||
)
|
||||
|
||||
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
||||
type DummyAria2 struct {
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务,此处直接返回未开启错误
|
||||
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// Status 返回未开启错误
|
||||
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
return rpc.StatusInfo{}, ErrNotEnabled
|
||||
}
|
||||
|
||||
// Cancel 返回未开启错误
|
||||
func (instance *DummyAria2) Cancel(task *model.Download) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// Select 返回未开启错误
|
||||
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
||||
return ErrNotEnabled
|
||||
// GetLoadBalancer 返回供Aria2使用的负载均衡器
|
||||
func GetLoadBalancer() balancer.Balancer {
|
||||
Lock.RLock()
|
||||
defer Lock.RUnlock()
|
||||
return LB
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func Init(isReload bool) {
|
||||
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
|
||||
Lock.Lock()
|
||||
defer Lock.Unlock()
|
||||
|
||||
// 关闭上个初始连接
|
||||
if previousClient, ok := Instance.(*RPCService); ok {
|
||||
if previousClient.Caller != nil {
|
||||
util.Log().Debug("关闭上个 aria2 连接")
|
||||
previousClient.Caller.Close()
|
||||
}
|
||||
}
|
||||
|
||||
options := model.GetSettingByNames("aria2_rpcurl", "aria2_token", "aria2_options")
|
||||
timeout := model.GetIntSetting("aria2_call_timeout", 5)
|
||||
if options["aria2_rpcurl"] == "" {
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
util.Log().Info("初始化 aria2 RPC 服务[%s]", options["aria2_rpcurl"])
|
||||
client := &RPCService{}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(options["aria2_rpcurl"])
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析 aria2 RPC 服务地址,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
err = json.Unmarshal([]byte(options["aria2_options"]), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析 aria2 全局配置,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.Init(server.String(), options["aria2_token"], timeout, globalOptions); err != nil {
|
||||
util.Log().Warning("初始化 aria2 RPC 服务失败,%s", err)
|
||||
Instance = &DummyAria2{}
|
||||
return
|
||||
}
|
||||
|
||||
Instance = client
|
||||
LB = balancer.NewBalancer("RoundRobin")
|
||||
Lock.Unlock()
|
||||
|
||||
if !isReload {
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(Ready, Paused, Downloading)
|
||||
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading)
|
||||
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
NewMonitor(&unfinished[i])
|
||||
monitor.NewMonitor(&unfinished[i], pool, mqClient)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// getStatus 将给定的状态字符串转换为状态标识数字
|
||||
func getStatus(status string) int {
|
||||
switch status {
|
||||
case "complete":
|
||||
return Complete
|
||||
case "active":
|
||||
return Downloading
|
||||
case "waiting":
|
||||
return Ready
|
||||
case "paused":
|
||||
return Paused
|
||||
case "error":
|
||||
return Error
|
||||
case "removed":
|
||||
return Canceled
|
||||
default:
|
||||
return Unknown
|
||||
// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性
|
||||
func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) {
|
||||
// 解析RPC服务地址
|
||||
rpcServer, err := url.Parse(server)
|
||||
if err != nil {
|
||||
return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err)
|
||||
}
|
||||
|
||||
rpcServer.Path = "/jsonrpc"
|
||||
caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil)
|
||||
if err != nil {
|
||||
return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err)
|
||||
}
|
||||
|
||||
return caller.GetVersion()
|
||||
}
|
||||
|
||||
@@ -2,12 +2,15 @@ package aria2
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
@@ -25,66 +28,39 @@ func TestMain(m *testing.M) {
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestDummyAria2(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
instance := DummyAria2{}
|
||||
asserts.Error(instance.CreateTask(nil, nil))
|
||||
_, err := instance.Status(nil)
|
||||
asserts.Error(err)
|
||||
asserts.Error(instance.Cancel(nil))
|
||||
asserts.Error(instance.Select(nil, nil))
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
MAX_RETRY = 0
|
||||
asserts := assert.New(t)
|
||||
cache.Set("setting_aria2_token", "1", 0)
|
||||
cache.Set("setting_aria2_call_timeout", "5", 0)
|
||||
cache.Set("setting_aria2_options", `[]`, 0)
|
||||
a := assert.New(t)
|
||||
mockPool := &mocks.NodePoolMock{}
|
||||
mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
|
||||
mockQueue := mq.NewMQ()
|
||||
|
||||
// 未指定RPC地址,跳过
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
Init(false, mockPool, mockQueue)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockPool.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTestRPCConnection(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
// url not legal
|
||||
{
|
||||
cache.Set("setting_aria2_rpcurl", "", 0)
|
||||
Init(false)
|
||||
asserts.IsType(&DummyAria2{}, Instance)
|
||||
res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
|
||||
a.Error(err)
|
||||
a.Empty(res.Version)
|
||||
}
|
||||
|
||||
// 无法解析服务器地址
|
||||
// rpc failed
|
||||
{
|
||||
cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0)
|
||||
Init(false)
|
||||
asserts.IsType(&DummyAria2{}, Instance)
|
||||
}
|
||||
|
||||
// 无法解析全局配置
|
||||
{
|
||||
Instance = &RPCService{}
|
||||
cache.Set("setting_aria2_options", "?", 0)
|
||||
cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0)
|
||||
Init(false)
|
||||
asserts.IsType(&DummyAria2{}, Instance)
|
||||
}
|
||||
|
||||
// 连接失败
|
||||
{
|
||||
cache.Set("setting_aria2_options", "{}", 0)
|
||||
cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0)
|
||||
cache.Set("setting_aria2_call_timeout", "1", 0)
|
||||
cache.Set("setting_aria2_interval", "100", 0)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
|
||||
Init(false)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.IsType(&RPCService{}, Instance)
|
||||
res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
|
||||
a.Error(err)
|
||||
a.Empty(res.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
asserts.Equal(4, getStatus("complete"))
|
||||
asserts.Equal(1, getStatus("active"))
|
||||
asserts.Equal(0, getStatus("waiting"))
|
||||
asserts.Equal(2, getStatus("paused"))
|
||||
asserts.Equal(3, getStatus("error"))
|
||||
asserts.Equal(5, getStatus("removed"))
|
||||
asserts.Equal(6, getStatus("?"))
|
||||
func TestGetLoadBalancer(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
a.NotPanics(func() {
|
||||
GetLoadBalancer()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/aria2/rpc"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RPCService 通过RPC服务的Aria2任务管理器
|
||||
type RPCService struct {
|
||||
options *clientOptions
|
||||
Caller rpc.Client
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error {
|
||||
// 客户端已存在,则关闭先前连接
|
||||
if client.Caller != nil {
|
||||
client.Caller.Close()
|
||||
}
|
||||
|
||||
client.options = &clientOptions{
|
||||
Options: options,
|
||||
}
|
||||
caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second,
|
||||
EventNotifier)
|
||||
client.Caller = caller
|
||||
return err
|
||||
}
|
||||
|
||||
// Status 查询下载状态
|
||||
func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
res, err := client.Caller.TellStatus(task.GID)
|
||||
if err != nil {
|
||||
// 失败后重试
|
||||
util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err)
|
||||
time.Sleep(time.Duration(10) * time.Second)
|
||||
res, err = client.Caller.TellStatus(task.GID)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Cancel 取消下载
|
||||
func (client *RPCService) Cancel(task *model.Download) error {
|
||||
// 取消下载任务
|
||||
_, err := client.Caller.Remove(task.GID)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
|
||||
}
|
||||
|
||||
//// 删除临时文件
|
||||
//util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", task.GID)
|
||||
//go func(task *model.Download) {
|
||||
// select {
|
||||
// case <-time.After(time.Duration(60) * time.Second):
|
||||
// err := os.RemoveAll(task.Parent)
|
||||
// if err != nil {
|
||||
// util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err)
|
||||
// }
|
||||
// }
|
||||
//}(task)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Select 选取要下载的文件
|
||||
func (client *RPCService) Select(task *model.Download, files []int) error {
|
||||
var selected = make([]string, len(files))
|
||||
for i := 0; i < len(files); i++ {
|
||||
selected[i] = strconv.Itoa(files[i])
|
||||
}
|
||||
_, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务
|
||||
func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) error {
|
||||
// 生成存储路径
|
||||
path := filepath.Join(
|
||||
model.GetSettingByName("aria2_temp_path"),
|
||||
"aria2",
|
||||
strconv.FormatInt(time.Now().UnixNano(), 10),
|
||||
)
|
||||
|
||||
// 创建下载任务
|
||||
options := map[string]interface{}{
|
||||
"dir": path,
|
||||
}
|
||||
for k, v := range client.options.Options {
|
||||
options[k] = v
|
||||
}
|
||||
for k, v := range groupOptions {
|
||||
options[k] = v
|
||||
}
|
||||
|
||||
gid, err := client.Caller.AddURI(task.Source, options)
|
||||
if err != nil || gid == "" {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
task.GID = gid
|
||||
_, err = task.Create()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建任务监控
|
||||
NewMonitor(task)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRPCService_Init(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.Error(caller.Init("ws://", "", 1, nil))
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
}
|
||||
|
||||
func TestRPCService_Status(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
|
||||
_, err := caller.Status(&model.Download{})
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
func TestRPCService_Cancel(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
|
||||
err := caller.Cancel(&model.Download{Parent: "test"})
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
func TestRPCService_Select(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
|
||||
err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3})
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
func TestRPCService_CreateTask(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
caller := &RPCService{}
|
||||
asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil))
|
||||
cache.Set("setting_aria2_temp_path", "test", 0)
|
||||
err := caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"})
|
||||
asserts.Error(err)
|
||||
}
|
||||
114
pkg/aria2/common/common.go
Normal file
114
pkg/aria2/common/common.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Aria2 离线下载处理接口
|
||||
type Aria2 interface {
|
||||
// Init 初始化客户端连接
|
||||
Init() error
|
||||
// CreateTask 创建新的任务
|
||||
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
|
||||
// 返回状态信息
|
||||
Status(task *model.Download) (rpc.StatusInfo, error)
|
||||
// 取消任务
|
||||
Cancel(task *model.Download) error
|
||||
// 选择要下载的文件
|
||||
Select(task *model.Download, files []int) error
|
||||
// 获取离线下载配置
|
||||
GetConfig() model.Aria2Option
|
||||
// 删除临时下载文件
|
||||
DeleteTempFile(*model.Download) error
|
||||
}
|
||||
|
||||
const (
|
||||
// URLTask 从URL添加的任务
|
||||
URLTask = iota
|
||||
// TorrentTask 种子任务
|
||||
TorrentTask
|
||||
)
|
||||
|
||||
const (
|
||||
// Ready 准备就绪
|
||||
Ready = iota
|
||||
// Downloading 下载中
|
||||
Downloading
|
||||
// Paused 暂停中
|
||||
Paused
|
||||
// Error 出错
|
||||
Error
|
||||
// Complete 完成
|
||||
Complete
|
||||
// Canceled 取消/停止
|
||||
Canceled
|
||||
// Unknown 未知状态
|
||||
Unknown
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotEnabled 功能未开启错误
|
||||
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
|
||||
// ErrUserNotFound 未找到下载任务创建者
|
||||
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
|
||||
)
|
||||
|
||||
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
||||
type DummyAria2 struct {
|
||||
}
|
||||
|
||||
func (instance *DummyAria2) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务,此处直接返回未开启错误
|
||||
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
|
||||
return "", ErrNotEnabled
|
||||
}
|
||||
|
||||
// Status 返回未开启错误
|
||||
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
return rpc.StatusInfo{}, ErrNotEnabled
|
||||
}
|
||||
|
||||
// Cancel 返回未开启错误
|
||||
func (instance *DummyAria2) Cancel(task *model.Download) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// Select 返回未开启错误
|
||||
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// GetConfig 返回空的
|
||||
func (instance *DummyAria2) GetConfig() model.Aria2Option {
|
||||
return model.Aria2Option{}
|
||||
}
|
||||
|
||||
// GetConfig 返回空的
|
||||
func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// GetStatus 将给定的状态字符串转换为状态标识数字
|
||||
func GetStatus(status string) int {
|
||||
switch status {
|
||||
case "complete":
|
||||
return Complete
|
||||
case "active":
|
||||
return Downloading
|
||||
case "waiting":
|
||||
return Ready
|
||||
case "paused":
|
||||
return Paused
|
||||
case "error":
|
||||
return Error
|
||||
case "removed":
|
||||
return Canceled
|
||||
default:
|
||||
return Unknown
|
||||
}
|
||||
}
|
||||
45
pkg/aria2/common/common_test.go
Normal file
45
pkg/aria2/common/common_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDummyAria2(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
d := &DummyAria2{}
|
||||
|
||||
a.NoError(d.Init())
|
||||
|
||||
res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
|
||||
a.Empty(res)
|
||||
a.Error(err)
|
||||
|
||||
_, err = d.Status(&model.Download{})
|
||||
a.Error(err)
|
||||
|
||||
err = d.Cancel(&model.Download{})
|
||||
a.Error(err)
|
||||
|
||||
err = d.Select(&model.Download{}, []int{})
|
||||
a.Error(err)
|
||||
|
||||
configRes := d.GetConfig()
|
||||
a.NotNil(configRes)
|
||||
|
||||
err = d.DeleteTempFile(&model.Download{})
|
||||
a.Error(err)
|
||||
}
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(GetStatus("complete"), Complete)
|
||||
a.Equal(GetStatus("active"), Downloading)
|
||||
a.Equal(GetStatus("waiting"), Ready)
|
||||
a.Equal(GetStatus("paused"), Paused)
|
||||
a.Equal(GetStatus("error"), Error)
|
||||
a.Equal(GetStatus("removed"), Canceled)
|
||||
a.Equal(GetStatus("unknown"), Unknown)
|
||||
}
|
||||
@@ -1,20 +1,23 @@
|
||||
package aria2
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/aria2/rpc"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/driver/local"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||
"github.com/HFO4/cloudreve/pkg/task"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Monitor 离线下载状态监控
|
||||
@@ -22,35 +25,37 @@ type Monitor struct {
|
||||
Task *model.Download
|
||||
Interval time.Duration
|
||||
|
||||
notifier chan StatusEvent
|
||||
notifier <-chan mq.Message
|
||||
node cluster.Node
|
||||
retried int
|
||||
}
|
||||
|
||||
// StatusEvent 状态改变事件
|
||||
type StatusEvent struct {
|
||||
GID string
|
||||
Status int
|
||||
}
|
||||
|
||||
var MAX_RETRY = 10
|
||||
|
||||
// NewMonitor 新建上传状态监控
|
||||
func NewMonitor(task *model.Download) {
|
||||
// NewMonitor 新建离线下载状态监控
|
||||
func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
|
||||
monitor := &Monitor{
|
||||
Task: task,
|
||||
Interval: time.Duration(model.GetIntSetting("aria2_interval", 10)) * time.Second,
|
||||
notifier: make(chan StatusEvent),
|
||||
notifier: make(chan mq.Message),
|
||||
node: pool.GetNodeByID(task.GetNodeID()),
|
||||
}
|
||||
|
||||
if monitor.node != nil {
|
||||
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
|
||||
go monitor.Loop(mqClient)
|
||||
|
||||
monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
|
||||
} else {
|
||||
monitor.setErrorStatus(errors.New("节点不可用"))
|
||||
}
|
||||
go monitor.Loop()
|
||||
EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
|
||||
}
|
||||
|
||||
// Loop 开启监控循环
|
||||
func (monitor *Monitor) Loop() {
|
||||
defer EventNotifier.Unsubscribe(monitor.Task.GID)
|
||||
func (monitor *Monitor) Loop(mqClient mq.MQ) {
|
||||
defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
|
||||
|
||||
// 首次循环立即更新
|
||||
interval := time.Duration(0)
|
||||
interval := 50 * time.Millisecond
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -69,9 +74,7 @@ func (monitor *Monitor) Loop() {
|
||||
|
||||
// Update 更新状态,返回值表示是否退出监控
|
||||
func (monitor *Monitor) Update() bool {
|
||||
Lock.RLock()
|
||||
status, err := Instance.Status(monitor.Task)
|
||||
Lock.RUnlock()
|
||||
status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
|
||||
|
||||
if err != nil {
|
||||
monitor.retried++
|
||||
@@ -101,6 +104,7 @@ func (monitor *Monitor) Update() bool {
|
||||
if err := monitor.UpdateTaskInfo(status); err != nil {
|
||||
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err)
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -108,13 +112,13 @@ func (monitor *Monitor) Update() bool {
|
||||
|
||||
switch status.Status {
|
||||
case "complete":
|
||||
return monitor.Complete(status)
|
||||
return monitor.Complete(task.TaskPoll)
|
||||
case "error":
|
||||
return monitor.Error(status)
|
||||
case "active", "waiting", "paused":
|
||||
return false
|
||||
case "removed":
|
||||
monitor.Task.Status = Canceled
|
||||
monitor.Task.Status = common.Canceled
|
||||
monitor.Task.Save()
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
@@ -129,7 +133,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||
originSize := monitor.Task.TotalSize
|
||||
|
||||
monitor.Task.GID = status.Gid
|
||||
monitor.Task.Status = getStatus(status.Status)
|
||||
monitor.Task.Status = common.GetStatus(status.Status)
|
||||
|
||||
// 文件大小、已下载大小
|
||||
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
|
||||
@@ -163,9 +167,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||
// 文件大小更新后,对文件限制等进行校验
|
||||
if err := monitor.ValidateFile(); err != nil {
|
||||
// 验证失败时取消任务
|
||||
Lock.RLock()
|
||||
Instance.Cancel(monitor.Task)
|
||||
Lock.RUnlock()
|
||||
monitor.node.GetAria2Instance().Cancel(monitor.Task)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -178,7 +180,7 @@ func (monitor *Monitor) ValidateFile() error {
|
||||
// 找到任务创建者
|
||||
user := monitor.Task.GetOwner()
|
||||
if user == nil {
|
||||
return ErrUserNotFound
|
||||
return common.ErrUserNotFound
|
||||
}
|
||||
|
||||
// 创建文件系统
|
||||
@@ -229,35 +231,40 @@ func (monitor *Monitor) Error(status rpc.StatusInfo) bool {
|
||||
|
||||
// RemoveTempFolder 清理下载临时目录
|
||||
func (monitor *Monitor) RemoveTempFolder() {
|
||||
err := os.RemoveAll(monitor.Task.Parent)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err)
|
||||
}
|
||||
|
||||
monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task)
|
||||
}
|
||||
|
||||
// Complete 完成下载,返回是否中断监控
|
||||
func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
|
||||
func (monitor *Monitor) Complete(pool task.Pool) bool {
|
||||
// 创建中转任务
|
||||
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
|
||||
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
|
||||
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
|
||||
if monitor.Task.StatusInfo.Files[i].Selected == "true" {
|
||||
file = append(file, monitor.Task.StatusInfo.Files[i].Path)
|
||||
fileInfo := monitor.Task.StatusInfo.Files[i]
|
||||
if fileInfo.Selected == "true" {
|
||||
file = append(file, fileInfo.Path)
|
||||
size, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
|
||||
sizes[fileInfo.Path] = size
|
||||
}
|
||||
}
|
||||
|
||||
job, err := task.NewTransferTask(
|
||||
monitor.Task.UserID,
|
||||
file,
|
||||
monitor.Task.Dst,
|
||||
monitor.Task.Parent,
|
||||
true,
|
||||
monitor.node.ID(),
|
||||
sizes,
|
||||
)
|
||||
if err != nil {
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
// 提交中转任务
|
||||
task.TaskPoll.Submit(job)
|
||||
pool.Submit(job)
|
||||
|
||||
// 更新任务ID
|
||||
monitor.Task.TaskID = job.Model().ID
|
||||
@@ -267,7 +274,7 @@ func (monitor *Monitor) Complete(status rpc.StatusInfo) bool {
|
||||
}
|
||||
|
||||
func (monitor *Monitor) setErrorStatus(err error) {
|
||||
monitor.Task.Status = Error
|
||||
monitor.Task.Status = common.Error
|
||||
monitor.Task.Error = err.Error()
|
||||
monitor.Task.Save()
|
||||
}
|
||||
438
pkg/aria2/monitor/monitor_test.go
Normal file
438
pkg/aria2/monitor/monitor_test.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
// TestMain 初始化数据库Mock
|
||||
func TestMain(m *testing.M) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
model.DB, _ = gorm.Open("mysql", db)
|
||||
defer db.Close()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestNewMonitor(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockMQ := mq.NewMQ()
|
||||
|
||||
// node not available
|
||||
{
|
||||
mockPool := &mocks.NodePoolMock{}
|
||||
mockPool.On("GetNodeByID", uint(1)).Return(nil)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
task := &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
}
|
||||
NewMonitor(task, mockPool, mockMQ)
|
||||
mockPool.AssertExpectations(t)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(task.Error)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
mockPool := &mocks.NodePoolMock{}
|
||||
mockPool.On("GetNodeByID", uint(1)).Return(mockNode)
|
||||
|
||||
task := &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
}
|
||||
NewMonitor(task, mockPool, mockMQ)
|
||||
mockNode.AssertExpectations(t)
|
||||
mockPool.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMonitor_Loop(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockMQ := mq.NewMQ()
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
retried: MAX_RETRY,
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
notifier: mockMQ.Subscribe("test", 1),
|
||||
}
|
||||
|
||||
// into interval loop
|
||||
{
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
m.Loop(mockMQ)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
// into notifier loop
|
||||
{
|
||||
m.Task.Error = ""
|
||||
mockMQ.Publish("test", mq.Message{})
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
m.Loop(mockMQ)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
for i := 0; i < MAX_RETRY; i++ {
|
||||
a.False(m.Update())
|
||||
}
|
||||
|
||||
mockNode.AssertExpectations(t)
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateMagentoFollow(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
FollowedBy: []string{"next"},
|
||||
}, nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.False(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
a.Equal("next", m.Task.GID)
|
||||
mockAria2.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateCompleted(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "complete",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
mockNode.On("ID").Return(uint(1))
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateError(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "error",
|
||||
ErrorMessage: "error",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
a.NotEmpty(m.Task.Error)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateActive(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "active",
|
||||
}, nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.False(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateRemoved(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "removed",
|
||||
}, nil)
|
||||
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.Equal(common.Canceled, m.Task.Status)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateUnknown(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockAria2 := &mocks.Aria2Mock{}
|
||||
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
Status: "unknown",
|
||||
}, nil)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(mockAria2)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Update())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockAria2.AssertExpectations(t)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
status := rpc.StatusInfo{
|
||||
Status: "completed",
|
||||
TotalLength: "100",
|
||||
CompletedLength: "50",
|
||||
DownloadSpeed: "20",
|
||||
}
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{Model: gorm.Model{ID: 1}},
|
||||
}
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
err := m.UpdateTaskInfo(status)
|
||||
a.Error(err)
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestMonitor_ValidateFile(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &Monitor{
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
TotalSize: 100,
|
||||
},
|
||||
}
|
||||
|
||||
// failed to create filesystem
|
||||
{
|
||||
m.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "random",
|
||||
},
|
||||
}
|
||||
a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile())
|
||||
}
|
||||
|
||||
// User capacity not enough
|
||||
{
|
||||
m.Task.User = &model.User{
|
||||
Group: model.Group{
|
||||
MaxStorage: 99,
|
||||
},
|
||||
Policy: model.Policy{
|
||||
Type: "local",
|
||||
},
|
||||
}
|
||||
a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile())
|
||||
}
|
||||
|
||||
// single file too big
|
||||
{
|
||||
m.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Length: "100",
|
||||
Selected: "true",
|
||||
},
|
||||
}
|
||||
m.Task.User = &model.User{
|
||||
Group: model.Group{
|
||||
MaxStorage: 100,
|
||||
},
|
||||
Policy: model.Policy{
|
||||
Type: "local",
|
||||
MaxSize: 99,
|
||||
},
|
||||
}
|
||||
a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile())
|
||||
}
|
||||
|
||||
// all pass
|
||||
{
|
||||
m.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Length: "100",
|
||||
Selected: "true",
|
||||
},
|
||||
}
|
||||
m.Task.User = &model.User{
|
||||
Group: model.Group{
|
||||
MaxStorage: 100,
|
||||
},
|
||||
Policy: model.Policy{
|
||||
Type: "local",
|
||||
MaxSize: 100,
|
||||
},
|
||||
}
|
||||
a.NoError(m.ValidateFile())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_Complete(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &mocks.NodeMock{}
|
||||
mockNode.On("ID").Return(uint(1))
|
||||
mockPool := &mocks.TaskPoolMock{}
|
||||
mockPool.On("Submit", testMock.Anything)
|
||||
m := &Monitor{
|
||||
node: mockNode,
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
TotalSize: 100,
|
||||
UserID: 9414,
|
||||
},
|
||||
}
|
||||
m.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Length: "100",
|
||||
Selected: "true",
|
||||
},
|
||||
}
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
a.True(m.Complete(mockPool))
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
mockNode.AssertExpectations(t)
|
||||
mockPool.AssertExpectations(t)
|
||||
}
|
||||
@@ -1,323 +0,0 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/aria2/rpc"
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem"
|
||||
"github.com/HFO4/cloudreve/pkg/task"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type InstanceMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (m InstanceMock) CreateTask(task *model.Download, options map[string]interface{}) error {
|
||||
args := m.Called(task, options)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m InstanceMock) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
args := m.Called(task)
|
||||
return args.Get(0).(rpc.StatusInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (m InstanceMock) Cancel(task *model.Download) error {
|
||||
args := m.Called(task)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m InstanceMock) Select(task *model.Download, files []int) error {
|
||||
args := m.Called(task, files)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNewMonitor(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
NewMonitor(&model.Download{GID: "gid"})
|
||||
_, ok := EventNotifier.Subscribes.Load("gid")
|
||||
asserts.True(ok)
|
||||
}
|
||||
|
||||
func TestMonitor_Loop(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
notifier := make(chan StatusEvent)
|
||||
MAX_RETRY = 0
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{GID: "gid"},
|
||||
Interval: time.Duration(1) * time.Second,
|
||||
notifier: notifier,
|
||||
}
|
||||
asserts.NotPanics(func() {
|
||||
monitor.Loop()
|
||||
})
|
||||
}
|
||||
|
||||
func TestMonitor_Update(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_Update",
|
||||
},
|
||||
Interval: time.Duration(1) * time.Second,
|
||||
}
|
||||
|
||||
// 无法获取状态
|
||||
{
|
||||
MAX_RETRY = 1
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
|
||||
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
|
||||
file.Close()
|
||||
Instance = testInstance
|
||||
asserts.False(monitor.Update())
|
||||
asserts.True(monitor.Update())
|
||||
testInstance.AssertExpectations(t)
|
||||
asserts.False(util.Exists("TestMonitor_Update"))
|
||||
}
|
||||
|
||||
// 磁力链下载重定向
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
FollowedBy: []string{"1"},
|
||||
}, nil)
|
||||
monitor.Task.ID = 1
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
Instance = testInstance
|
||||
asserts.False(monitor.Update())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
asserts.EqualValues("1", monitor.Task.GID)
|
||||
}
|
||||
|
||||
// 无法更新任务信息
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||
monitor.Task.ID = 1
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 返回未知状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 返回被取消状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 返回活跃状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
Instance = testInstance
|
||||
asserts.False(monitor.Update())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 返回错误状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// 返回完成
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
Instance = testInstance
|
||||
asserts.True(monitor.Update())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_UpdateTaskInfo(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_UpdateTaskInfo",
|
||||
},
|
||||
}
|
||||
|
||||
// 失败
|
||||
{
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Error(err)
|
||||
}
|
||||
|
||||
// 更新成功,无需校验
|
||||
{
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
}
|
||||
|
||||
// 更新成功,大小改变,需要校验,校验失败
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Cancel", testMock.Anything).Return(nil)
|
||||
Instance = testInstance
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
err := monitor.UpdateTaskInfo(rpc.StatusInfo{TotalLength: "1"})
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Error(err)
|
||||
testInstance.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_ValidateFile(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_ValidateFile",
|
||||
},
|
||||
}
|
||||
|
||||
// 无法创建文件系统
|
||||
{
|
||||
monitor.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "unknown",
|
||||
},
|
||||
}
|
||||
asserts.Error(monitor.ValidateFile())
|
||||
}
|
||||
|
||||
// 文件大小超出容量配额
|
||||
{
|
||||
cache.Set("pack_size_0", uint64(0), 0)
|
||||
monitor.Task.TotalSize = 11
|
||||
monitor.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "mock",
|
||||
},
|
||||
Group: model.Group{
|
||||
MaxStorage: 10,
|
||||
},
|
||||
}
|
||||
asserts.Equal(filesystem.ErrInsufficientCapacity, monitor.ValidateFile())
|
||||
}
|
||||
|
||||
// 单文件大小超出容量配额
|
||||
{
|
||||
cache.Set("pack_size_0", uint64(0), 0)
|
||||
monitor.Task.TotalSize = 10
|
||||
monitor.Task.StatusInfo.Files = []rpc.FileInfo{
|
||||
{
|
||||
Selected: "true",
|
||||
Length: "6",
|
||||
},
|
||||
}
|
||||
monitor.Task.User = &model.User{
|
||||
Policy: model.Policy{
|
||||
Type: "mock",
|
||||
MaxSize: 5,
|
||||
},
|
||||
Group: model.Group{
|
||||
MaxStorage: 10,
|
||||
},
|
||||
}
|
||||
asserts.Equal(filesystem.ErrFileSizeTooBig, monitor.ValidateFile())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMonitor_Complete(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
monitor := &Monitor{
|
||||
Task: &model.Download{
|
||||
Model: gorm.Model{ID: 1},
|
||||
GID: "gid",
|
||||
Parent: "TestMonitor_Complete",
|
||||
StatusInfo: rpc.StatusInfo{
|
||||
Files: []rpc.FileInfo{
|
||||
{
|
||||
Selected: "true",
|
||||
Path: "TestMonitor_Complete",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cache.Set("setting_max_worker_num", "1", 0)
|
||||
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
task.Init()
|
||||
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
asserts.True(monitor.Complete(rpc.StatusInfo{}))
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/aria2/rpc"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Notifier aria2实践通知处理
|
||||
type Notifier struct {
|
||||
Subscribes sync.Map
|
||||
}
|
||||
|
||||
// Subscribe 订阅事件通知
|
||||
func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) {
|
||||
notifier.Subscribes.Store(gid, target)
|
||||
}
|
||||
|
||||
// Unsubscribe 取消订阅事件通知
|
||||
func (notifier *Notifier) Unsubscribe(gid string) {
|
||||
notifier.Subscribes.Delete(gid)
|
||||
}
|
||||
|
||||
// Notify 发送通知
|
||||
func (notifier *Notifier) Notify(events []rpc.Event, status int) {
|
||||
for _, event := range events {
|
||||
if target, ok := notifier.Subscribes.Load(event.Gid); ok {
|
||||
target.(chan StatusEvent) <- StatusEvent{
|
||||
GID: event.Gid,
|
||||
Status: status,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnDownloadStart 下载开始
|
||||
func (notifier *Notifier) OnDownloadStart(events []rpc.Event) {
|
||||
notifier.Notify(events, Downloading)
|
||||
}
|
||||
|
||||
// OnDownloadPause 下载暂停
|
||||
func (notifier *Notifier) OnDownloadPause(events []rpc.Event) {
|
||||
notifier.Notify(events, Paused)
|
||||
}
|
||||
|
||||
// OnDownloadStop 下载停止
|
||||
func (notifier *Notifier) OnDownloadStop(events []rpc.Event) {
|
||||
notifier.Notify(events, Canceled)
|
||||
}
|
||||
|
||||
// OnDownloadComplete 下载完成
|
||||
func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) {
|
||||
notifier.Notify(events, Complete)
|
||||
}
|
||||
|
||||
// OnDownloadError 下载出错
|
||||
func (notifier *Notifier) OnDownloadError(events []rpc.Event) {
|
||||
notifier.Notify(events, Error)
|
||||
}
|
||||
|
||||
// OnBtDownloadComplete BT下载完成
|
||||
func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) {
|
||||
notifier.Notify(events, Complete)
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/aria2/rpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNotifier_Notify(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
notifier2 := &Notifier{}
|
||||
notifyChan := make(chan StatusEvent, 10)
|
||||
notifier2.Subscribe(notifyChan, "1")
|
||||
|
||||
// 未订阅
|
||||
{
|
||||
notifier2.Notify([]rpc.Event{rpc.Event{Gid: ""}}, 1)
|
||||
asserts.Len(notifyChan, 0)
|
||||
}
|
||||
|
||||
// 订阅
|
||||
{
|
||||
notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1)
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
|
||||
notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}})
|
||||
asserts.Len(notifyChan, 1)
|
||||
<-notifyChan
|
||||
}
|
||||
}
|
||||
@@ -2,20 +2,25 @@ package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
|
||||
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
|
||||
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
|
||||
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
|
||||
)
|
||||
|
||||
// General 通用的认证接口
|
||||
@@ -29,9 +34,8 @@ type Auth interface {
|
||||
Check(body string, sign string) error
|
||||
}
|
||||
|
||||
// SignRequest 对PUT\POST等复杂HTTP请求签名,如果请求Header中
|
||||
// 包含 X-Policy, 则此请求会被认定为上传请求,只会对URI部分和
|
||||
// Policy部分进行签名。其他请求则会对URI和Body部分进行签名。
|
||||
// SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
|
||||
// 请求正文、`X-Cr-`开头的header进行签名
|
||||
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
|
||||
// 处理有效期
|
||||
if expires > 0 {
|
||||
@@ -53,27 +57,38 @@ func CheckRequest(instance Auth, r *http.Request) error {
|
||||
ok bool
|
||||
)
|
||||
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
||||
return ErrAuthFailed
|
||||
return ErrAuthHeaderMissing
|
||||
}
|
||||
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
||||
|
||||
return instance.Check(getSignContent(r), sign[0])
|
||||
}
|
||||
|
||||
// getSignContent 根据请求Header中是否包含X-Policy判断是否为上传请求,
|
||||
// 返回待签名/验证的字符串
|
||||
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果 Header 中包含 `X-Policy`,
|
||||
// 则不对正文签名。返回待签名/验证的字符串
|
||||
func getSignContent(r *http.Request) (rawSignString string) {
|
||||
if policy, ok := r.Header["X-Policy"]; ok {
|
||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, policy[0], "")
|
||||
} else {
|
||||
var body = []byte{}
|
||||
// 读取所有body正文
|
||||
var body = []byte{}
|
||||
if _, ok := r.Header["X-Cr-Policy"]; !ok {
|
||||
if r.Body != nil {
|
||||
body, _ = ioutil.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, "", string(body))
|
||||
}
|
||||
|
||||
// 决定要签名的header
|
||||
var signedHeader []string
|
||||
for k, _ := range r.Header {
|
||||
if strings.HasPrefix(k, "X-Cr-") && k != "X-Cr-Filename" {
|
||||
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
|
||||
}
|
||||
}
|
||||
sort.Strings(signedHeader)
|
||||
|
||||
// 读取所有待签名Header
|
||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
|
||||
|
||||
return rawSignString
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSignURI(t *testing.T) {
|
||||
@@ -69,7 +70,7 @@ func TestSignRequest(t *testing.T) {
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req.Header["X-Policy"] = []string{"I am Policy"}
|
||||
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
|
||||
req = SignRequest(General, req, 10)
|
||||
asserts.NotEmpty(req.Header["Authorization"])
|
||||
}
|
||||
@@ -79,6 +80,19 @@ func TestCheckRequest(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 缺少请求头
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
"POST",
|
||||
"http://127.0.0.1/api/v3/upload",
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.Error(err)
|
||||
asserts.Equal(ErrAuthHeaderMissing, err)
|
||||
}
|
||||
|
||||
// 非上传请求 验证成功
|
||||
{
|
||||
req, err := http.NewRequest(
|
||||
@@ -100,7 +114,7 @@ func TestCheckRequest(t *testing.T) {
|
||||
strings.NewReader("I am body."),
|
||||
)
|
||||
asserts.NoError(err)
|
||||
req.Header["X-Policy"] = []string{"I am Policy"}
|
||||
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
|
||||
req = SignRequest(General, req, 0)
|
||||
err = CheckRequest(General, req)
|
||||
asserts.NoError(err)
|
||||
|
||||
@@ -33,7 +33,7 @@ func (auth HMACAuth) Check(body string, sign string) error {
|
||||
signSlice := strings.Split(sign, ":")
|
||||
// 如果未携带expires字段
|
||||
if signSlice[len(signSlice)-1] == "" {
|
||||
return ErrAuthFailed
|
||||
return ErrExpiresMissing
|
||||
}
|
||||
|
||||
// 验证是否过期
|
||||
|
||||
@@ -3,15 +3,16 @@ package auth
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/duo-labs/webauthn/webauthn"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
|
||||
15
pkg/balancer/balancer.go
Normal file
15
pkg/balancer/balancer.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package balancer
|
||||
|
||||
type Balancer interface {
|
||||
NextPeer(nodes interface{}) (error, interface{})
|
||||
}
|
||||
|
||||
// NewBalancer 根据策略标识返回新的负载均衡器
|
||||
func NewBalancer(strategy string) Balancer {
|
||||
switch strategy {
|
||||
case "RoundRobin":
|
||||
return &RoundRobin{}
|
||||
default:
|
||||
return &RoundRobin{}
|
||||
}
|
||||
}
|
||||
12
pkg/balancer/balancer_test.go
Normal file
12
pkg/balancer/balancer_test.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewBalancer(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
a.NotNil(NewBalancer(""))
|
||||
a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
|
||||
}
|
||||
8
pkg/balancer/errors.go
Normal file
8
pkg/balancer/errors.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package balancer
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInputNotSlice = errors.New("Input value is not silice")
|
||||
ErrNoAvaliableNode = errors.New("No nodes avaliable")
|
||||
)
|
||||
30
pkg/balancer/roundrobin.go
Normal file
30
pkg/balancer/roundrobin.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type RoundRobin struct {
|
||||
current uint64
|
||||
}
|
||||
|
||||
// NextPeer 返回轮盘的下一节点
|
||||
func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
|
||||
v := reflect.ValueOf(nodes)
|
||||
if v.Kind() != reflect.Slice {
|
||||
return ErrInputNotSlice, nil
|
||||
}
|
||||
|
||||
if v.Len() == 0 {
|
||||
return ErrNoAvaliableNode, nil
|
||||
}
|
||||
|
||||
next := r.NextIndex(v.Len())
|
||||
return nil, v.Index(next).Interface()
|
||||
}
|
||||
|
||||
// NextIndex 返回下一个节点下标
|
||||
func (r *RoundRobin) NextIndex(total int) int {
|
||||
return int(atomic.AddUint64(&r.current, uint64(1)) % uint64(total))
|
||||
}
|
||||
42
pkg/balancer/roundrobin_test.go
Normal file
42
pkg/balancer/roundrobin_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRoundRobin_NextIndex(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := &RoundRobin{}
|
||||
total := 5
|
||||
for i := 1; i < total; i++ {
|
||||
a.Equal(i, r.NextIndex(total))
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
a.Equal(i, r.NextIndex(total))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobin_NextPeer(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := &RoundRobin{}
|
||||
|
||||
// not slice
|
||||
{
|
||||
err, _ := r.NextPeer("s")
|
||||
a.Equal(ErrInputNotSlice, err)
|
||||
}
|
||||
|
||||
// no nodes
|
||||
{
|
||||
err, _ := r.NextPeer([]string{})
|
||||
a.Equal(ErrNoAvaliableNode, err)
|
||||
}
|
||||
|
||||
// pass
|
||||
{
|
||||
err, res := r.NextPeer([]string{"a"})
|
||||
a.NoError(err)
|
||||
a.Equal("a", res.(string))
|
||||
}
|
||||
}
|
||||
4
pkg/cache/driver.go
vendored
4
pkg/cache/driver.go
vendored
@@ -1,7 +1,7 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ func Init() {
|
||||
if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode {
|
||||
Store = NewRedisStore(
|
||||
10,
|
||||
"tcp",
|
||||
conf.RedisConfig.Network,
|
||||
conf.RedisConfig.Server,
|
||||
conf.RedisConfig.Password,
|
||||
conf.RedisConfig.DB,
|
||||
|
||||
3
pkg/cache/memo.go
vendored
3
pkg/cache/memo.go
vendored
@@ -1,9 +1,10 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// MemoStore 内存存储驱动
|
||||
|
||||
5
pkg/cache/redis.go
vendored
5
pkg/cache/redis.go
vendored
@@ -3,10 +3,11 @@ package cache
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
)
|
||||
|
||||
// RedisStore redis存储驱动
|
||||
|
||||
209
pkg/cluster/controller.go
Normal file
209
pkg/cluster/controller.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/jinzhu/gorm"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var DefaultController Controller
|
||||
|
||||
// Controller controls communications between master and slave
|
||||
type Controller interface {
|
||||
// Handle heartbeat sent from master
|
||||
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
|
||||
|
||||
// Get Aria2 Instance by master node ID
|
||||
GetAria2Instance(string) (common.Aria2, error)
|
||||
|
||||
// Send event change message to master node
|
||||
SendNotification(string, string, mq.Message) error
|
||||
|
||||
// Submit async task into task pool
|
||||
SubmitTask(string, interface{}, string, func(interface{})) error
|
||||
|
||||
// Get master node info
|
||||
GetMasterInfo(string) (*MasterInfo, error)
|
||||
|
||||
// Get master OneDrive policy credential
|
||||
GetOneDriveToken(string, uint) (string, error)
|
||||
}
|
||||
|
||||
type slaveController struct {
|
||||
masters map[string]MasterInfo
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// info of master node
|
||||
type MasterInfo struct {
|
||||
ID string
|
||||
TTL int
|
||||
URL *url.URL
|
||||
// used to invoke aria2 rpc calls
|
||||
Instance Node
|
||||
Client request.Client
|
||||
|
||||
jobTracker map[string]bool
|
||||
}
|
||||
|
||||
func InitController() {
|
||||
DefaultController = &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
gob.Register(rpc.StatusInfo{})
|
||||
}
|
||||
|
||||
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
req.Node.AfterFind()
|
||||
|
||||
// close old node if exist
|
||||
origin, ok := c.masters[req.SiteID]
|
||||
|
||||
if (ok && req.IsUpdate) || !ok {
|
||||
if ok {
|
||||
origin.Instance.Kill()
|
||||
}
|
||||
|
||||
masterUrl, err := url.Parse(req.SiteURL)
|
||||
if err != nil {
|
||||
return serializer.NodePingResp{}, err
|
||||
}
|
||||
|
||||
c.masters[req.SiteID] = MasterInfo{
|
||||
ID: req.SiteID,
|
||||
URL: masterUrl,
|
||||
TTL: req.CredentialTTL,
|
||||
Client: request.NewClient(
|
||||
request.WithEndpoint(masterUrl.String()),
|
||||
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
|
||||
request.WithCredential(auth.HMACAuth{
|
||||
SecretKey: []byte(req.Node.MasterKey),
|
||||
}, int64(req.CredentialTTL)),
|
||||
),
|
||||
jobTracker: make(map[string]bool),
|
||||
Instance: NewNodeFromDBModel(&model.Node{
|
||||
Model: gorm.Model{ID: req.Node.ID},
|
||||
MasterKey: req.Node.MasterKey,
|
||||
Type: model.MasterNodeType,
|
||||
Aria2Enabled: req.Node.Aria2Enabled,
|
||||
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return node.Instance.GetAria2Instance(), nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
body := bytes.Buffer{}
|
||||
enc := gob.NewEncoder(&body)
|
||||
if err := enc.Encode(&msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"PUT",
|
||||
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
|
||||
&body,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
if _, ok := node.jobTracker[hash]; ok {
|
||||
// 任务已存在,直接返回
|
||||
return nil
|
||||
}
|
||||
|
||||
node.jobTracker[hash] = true
|
||||
submitter(job)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetMasterInfo 获取主机节点信息
|
||||
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetOneDriveToken 获取主机OneDrive凭证
|
||||
func (c *slaveController) GetOneDriveToken(id string, policyID uint) (string, error) {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"GET",
|
||||
fmt.Sprintf("/api/v3/slave/credential/onedrive/%d", policyID),
|
||||
nil,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return "", ErrMasterNotFound
|
||||
}
|
||||
385
pkg/cluster/controller_test.go
Normal file
385
pkg/cluster/controller_test.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInitController(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
InitController()
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlaveController_HandleHeartBeat(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
|
||||
// first heart beat
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "2",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
|
||||
a.Len(c.masters, 2)
|
||||
}
|
||||
|
||||
// second heart beat, no fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Empty(c.masters["1"].URL)
|
||||
}
|
||||
|
||||
// second heart beat, fresh
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: "http://127.0.0.1",
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.NoError(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
|
||||
// second heart beat, fresh, url illegal
|
||||
{
|
||||
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
|
||||
SiteID: "1",
|
||||
IsUpdate: true,
|
||||
SiteURL: string([]byte{0x7f}),
|
||||
Node: &model.Node{},
|
||||
})
|
||||
a.Error(err)
|
||||
a.Len(c.masters, 2)
|
||||
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
|
||||
}
|
||||
}
|
||||
|
||||
type nodeMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n nodeMock) Init(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsFeatureEnabled(feature string) bool {
|
||||
args := n.Called(feature)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
n.Called(callback)
|
||||
}
|
||||
|
||||
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
args := n.Called(req)
|
||||
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (n nodeMock) IsActive() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) GetAria2Instance() common.Aria2 {
|
||||
args := n.Called()
|
||||
return args.Get(0).(common.Aria2)
|
||||
}
|
||||
|
||||
func (n nodeMock) ID() uint {
|
||||
args := n.Called()
|
||||
return args.Get(0).(uint)
|
||||
}
|
||||
|
||||
func (n nodeMock) Kill() {
|
||||
n.Called()
|
||||
}
|
||||
|
||||
func (n nodeMock) IsMater() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n nodeMock) MasterAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) SlaveAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n nodeMock) DBModel() *model.Node {
|
||||
args := n.Called()
|
||||
return args.Get(0).(*model.Node)
|
||||
}
|
||||
|
||||
func TestSlaveController_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
mockNode := &nodeMock{}
|
||||
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Instance: mockNode},
|
||||
},
|
||||
}
|
||||
|
||||
// node node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("2")
|
||||
a.Nil(res)
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
}
|
||||
|
||||
// node found
|
||||
{
|
||||
res, err := c.GetAria2Instance("1")
|
||||
a.NotNil(res)
|
||||
a.NoError(err)
|
||||
mockNode.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type requestMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||
}
|
||||
|
||||
func TestSlaveController_SendNotification(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
|
||||
}
|
||||
|
||||
// gob encode error
|
||||
{
|
||||
type randomType struct{}
|
||||
a.Error(c.SendNotification("1", "", mq.Message{
|
||||
Content: randomType{},
|
||||
}))
|
||||
}
|
||||
|
||||
// return none 200
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Error(c.SendNotification("1", "s1", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_SubmitTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {
|
||||
jobTracker: map[string]bool{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
submitted := false
|
||||
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
|
||||
submitted = true
|
||||
}))
|
||||
a.True(submitted)
|
||||
}
|
||||
|
||||
// job already submitted
|
||||
{
|
||||
submitted := false
|
||||
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
|
||||
submitted = true
|
||||
}))
|
||||
a.False(submitted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_GetMasterInfo(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
res, err := c.GetMasterInfo("2")
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
a.Nil(res)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
res, err := c.GetMasterInfo("1")
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlaveController_GetOneDriveToken(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {},
|
||||
},
|
||||
}
|
||||
|
||||
// node not exit
|
||||
{
|
||||
res, err := c.GetOneDriveToken("2", 1)
|
||||
a.Equal(ErrMasterNotFound, err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
// return none 200
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{StatusCode: http.StatusConflict},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetOneDriveToken("1", 1)
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// master return error
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetOneDriveToken("1", 1)
|
||||
a.Equal(1, err.(serializer.AppError).Code)
|
||||
a.Empty(res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mockRequest := &requestMock{}
|
||||
mockRequest.On("Request", "GET", "/api/v3/slave/credential/onedrive/1", testMock.Anything, testMock.Anything).Return(&request.Response{
|
||||
Response: &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
|
||||
},
|
||||
})
|
||||
c := &slaveController{
|
||||
masters: map[string]MasterInfo{
|
||||
"1": {Client: mockRequest},
|
||||
},
|
||||
}
|
||||
res, err := c.GetOneDriveToken("1", 1)
|
||||
a.NoError(err)
|
||||
a.Equal("expected", res)
|
||||
mockRequest.AssertExpectations(t)
|
||||
}
|
||||
|
||||
}
|
||||
12
pkg/cluster/errors.go
Normal file
12
pkg/cluster/errors.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
|
||||
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
|
||||
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
|
||||
)
|
||||
272
pkg/cluster/master.go
Normal file
272
pkg/cluster/master.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
deleteTempFileDuration = 60 * time.Second
|
||||
statusRetryDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
type MasterNode struct {
|
||||
Model *model.Node
|
||||
aria2RPC rpcService
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// RPCService 通过RPC服务的Aria2任务管理器
|
||||
type rpcService struct {
|
||||
Caller rpc.Client
|
||||
Initialized bool
|
||||
|
||||
retryDuration time.Duration
|
||||
deletePaddingDuration time.Duration
|
||||
parent *MasterNode
|
||||
options *clientOptions
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *MasterNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.aria2RPC.parent = node
|
||||
node.aria2RPC.retryDuration = statusRetryDuration
|
||||
node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
|
||||
node.lock.Unlock()
|
||||
|
||||
node.lock.RLock()
|
||||
if node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return
|
||||
}
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
func (node *MasterNode) ID() uint {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model.ID
|
||||
}
|
||||
|
||||
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
return &serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (node *MasterNode) MasterAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
}
|
||||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *MasterNode) IsActive() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Kill 结束aria2请求
|
||||
func (node *MasterNode) Kill() {
|
||||
if node.aria2RPC.Caller != nil {
|
||||
node.aria2RPC.Caller.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取主机Aria2实例
|
||||
func (node *MasterNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
if !node.aria2RPC.Initialized {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
defer node.lock.RUnlock()
|
||||
return &node.aria2RPC
|
||||
}
|
||||
|
||||
func (node *MasterNode) IsMater() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (node *MasterNode) DBModel() *model.Node {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model
|
||||
}
|
||||
|
||||
func (r *rpcService) Init() error {
|
||||
r.parent.lock.Lock()
|
||||
defer r.parent.lock.Unlock()
|
||||
r.Initialized = false
|
||||
|
||||
// 客户端已存在,则关闭先前连接
|
||||
if r.Caller != nil {
|
||||
r.Caller.Close()
|
||||
}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析主机 Aria2 RPC 服务地址,%s", err)
|
||||
return err
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
|
||||
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法解析主机 Aria2 配置,%s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.options = &clientOptions{
|
||||
Options: globalOptions,
|
||||
}
|
||||
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
|
||||
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
|
||||
|
||||
r.Caller = caller
|
||||
r.Initialized = err == nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
||||
r.parent.lock.RLock()
|
||||
// 生成存储路径
|
||||
guid, _ := uuid.NewV4()
|
||||
path := filepath.Join(
|
||||
r.parent.Model.Aria2OptionsSerialized.TempPath,
|
||||
"aria2",
|
||||
guid.String(),
|
||||
)
|
||||
r.parent.lock.RUnlock()
|
||||
|
||||
// 创建下载任务
|
||||
options := map[string]interface{}{
|
||||
"dir": path,
|
||||
}
|
||||
for k, v := range r.options.Options {
|
||||
options[k] = v
|
||||
}
|
||||
for k, v := range groupOptions {
|
||||
options[k] = v
|
||||
}
|
||||
|
||||
gid, err := r.Caller.AddURI(task.Source, options)
|
||||
if err != nil || gid == "" {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return gid, nil
|
||||
}
|
||||
|
||||
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
res, err := r.Caller.TellStatus(task.GID)
|
||||
if err != nil {
|
||||
// 失败后重试
|
||||
util.Log().Debug("无法获取离线下载状态,%s,稍后重试", err)
|
||||
time.Sleep(r.retryDuration)
|
||||
res, err = r.Caller.TellStatus(task.GID)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *rpcService) Cancel(task *model.Download) error {
|
||||
// 取消下载任务
|
||||
_, err := r.Caller.Remove(task.GID)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) Select(task *model.Download, files []int) error {
|
||||
var selected = make([]string, len(files))
|
||||
for i := 0; i < len(files); i++ {
|
||||
selected[i] = strconv.Itoa(files[i])
|
||||
}
|
||||
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) GetConfig() model.Aria2Option {
|
||||
r.parent.lock.RLock()
|
||||
defer r.parent.lock.RUnlock()
|
||||
|
||||
return r.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func (s *rpcService) DeleteTempFile(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
// 避免被aria2占用,异步执行删除
|
||||
go func(d time.Duration, src string) {
|
||||
time.Sleep(d)
|
||||
err := os.RemoveAll(src)
|
||||
if err != nil {
|
||||
util.Log().Warning("无法删除离线下载临时目录[%s], %s", src, err)
|
||||
}
|
||||
}(s.deletePaddingDuration, task.Parent)
|
||||
|
||||
return nil
|
||||
}
|
||||
186
pkg/cluster/master_test.go
Normal file
186
pkg/cluster/master_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMasterNode_Init(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{}
|
||||
m.Init(&model.Node{Status: model.NodeSuspend})
|
||||
a.Equal(model.NodeSuspend, m.DBModel().Status)
|
||||
m.Init(&model.Node{Aria2Enabled: true})
|
||||
}
|
||||
|
||||
func TestMasterNode_DummyMethods(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Model.ID = 5
|
||||
a.Equal(m.Model.ID, m.ID())
|
||||
|
||||
res, err := m.Ping(&serializer.NodePingReq{})
|
||||
a.NoError(err)
|
||||
a.NotNil(res)
|
||||
|
||||
a.True(m.IsActive())
|
||||
a.True(m.IsMater())
|
||||
|
||||
m.SubscribeStatusChange(func(isActive bool, id uint) {})
|
||||
}
|
||||
|
||||
func TestMasterNode_IsFeatureEnabled(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.False(m.IsFeatureEnabled("aria2"))
|
||||
a.False(m.IsFeatureEnabled("random"))
|
||||
m.Model.Aria2Enabled = true
|
||||
a.True(m.IsFeatureEnabled("aria2"))
|
||||
}
|
||||
|
||||
func TestMasterNode_AuthInstance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
a.NotNil(m.MasterAuthInstance())
|
||||
a.NotNil(m.SlaveAuthInstance())
|
||||
}
|
||||
|
||||
func TestMasterNode_Kill(t *testing.T) {
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
}
|
||||
|
||||
m.Kill()
|
||||
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
m.Kill()
|
||||
}
|
||||
|
||||
func TestMasterNode_GetAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{},
|
||||
aria2RPC: rpcService{},
|
||||
}
|
||||
|
||||
m.aria2RPC.parent = m
|
||||
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.Model.Aria2Enabled = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
m.aria2RPC.Initialized = true
|
||||
a.NotNil(m.GetAria2Instance())
|
||||
}
|
||||
|
||||
func TestRpcService_Init(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{
|
||||
Aria2OptionsSerialized: model.Aria2Option{
|
||||
Options: "{",
|
||||
},
|
||||
},
|
||||
aria2RPC: rpcService{},
|
||||
}
|
||||
m.aria2RPC.parent = m
|
||||
|
||||
// failed to decode address
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f})
|
||||
a.Error(m.aria2RPC.Init())
|
||||
}
|
||||
|
||||
// failed to decode options
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = ""
|
||||
a.Error(m.aria2RPC.Init())
|
||||
}
|
||||
|
||||
// failed to initialized
|
||||
{
|
||||
m.Model.Aria2OptionsSerialized.Server = ""
|
||||
m.Model.Aria2OptionsSerialized.Options = "{}"
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
a.Error(m.aria2RPC.Init())
|
||||
a.False(m.aria2RPC.Initialized)
|
||||
}
|
||||
}
|
||||
|
||||
func getTestRPCNode() *MasterNode {
|
||||
m := &MasterNode{
|
||||
Model: &model.Node{
|
||||
Aria2OptionsSerialized: model.Aria2Option{},
|
||||
},
|
||||
aria2RPC: rpcService{
|
||||
options: &clientOptions{
|
||||
Options: map[string]interface{}{"1": "1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
m.aria2RPC.parent = m
|
||||
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
|
||||
m.aria2RPC.Caller = caller
|
||||
return m
|
||||
}
|
||||
|
||||
func TestRpcService_CreateTask(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"})
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
func TestRpcService_Status(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
res, err := m.aria2RPC.Status(&model.Download{})
|
||||
a.Error(err)
|
||||
a.Empty(res)
|
||||
}
|
||||
|
||||
func TestRpcService_Cancel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
a.Error(m.aria2RPC.Cancel(&model.Download{}))
|
||||
}
|
||||
|
||||
func TestRpcService_Select(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
|
||||
a.NotNil(m.aria2RPC.GetConfig())
|
||||
a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3}))
|
||||
}
|
||||
|
||||
func TestRpcService_DeleteTempFile(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
m := getTestRPCNode()
|
||||
fdName := "TestRpcService_DeleteTempFile"
|
||||
a.NoError(os.Mkdir(fdName, 0644))
|
||||
|
||||
a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName}))
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
a.False(util.Exists(fdName))
|
||||
}
|
||||
60
pkg/cluster/node.go
Normal file
60
pkg/cluster/node.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
// Init a node from database model
|
||||
Init(node *model.Node)
|
||||
|
||||
// Check if given feature is enabled
|
||||
IsFeatureEnabled(feature string) bool
|
||||
|
||||
// Subscribe node status change to a callback function
|
||||
SubscribeStatusChange(callback func(isActive bool, id uint))
|
||||
|
||||
// Ping the node
|
||||
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
|
||||
|
||||
// Returns if the node is active
|
||||
IsActive() bool
|
||||
|
||||
// Get instances for aria2 calls
|
||||
GetAria2Instance() common.Aria2
|
||||
|
||||
// Returns unique id of this node
|
||||
ID() uint
|
||||
|
||||
// Kill node and recycle resources
|
||||
Kill()
|
||||
|
||||
// Returns if current node is master node
|
||||
IsMater() bool
|
||||
|
||||
// Get auth instance used to check RPC call from slave to master
|
||||
MasterAuthInstance() auth.Auth
|
||||
|
||||
// Get auth instance used to check RPC call from master to slave
|
||||
SlaveAuthInstance() auth.Auth
|
||||
|
||||
// Get node DB model
|
||||
DBModel() *model.Node
|
||||
}
|
||||
|
||||
// Create new node from DB model
|
||||
func NewNodeFromDBModel(node *model.Node) Node {
|
||||
switch node.Type {
|
||||
case model.SlaveNodeType:
|
||||
slave := &SlaveNode{}
|
||||
slave.Init(node)
|
||||
return slave
|
||||
default:
|
||||
master := &MasterNode{}
|
||||
master.Init(node)
|
||||
return master
|
||||
}
|
||||
}
|
||||
17
pkg/cluster/node_test.go
Normal file
17
pkg/cluster/node_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewNodeFromDBModel(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{
|
||||
Type: model.SlaveNodeType,
|
||||
}))
|
||||
a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{
|
||||
Type: model.MasterNodeType,
|
||||
}))
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user