Compare commits

..

2 Commits
3.7.1 ... 3.4.3

Author SHA1 Message Date
HFO4
a910a20457 chore: build manually 2022-04-16 14:27:36 +08:00
HFO4
c4fc4a803d update version number 2022-04-16 14:21:33 +08:00
237 changed files with 6273 additions and 12024 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
custom: ["https://cloudreve.org/pricing"]
custom: ["https://cloudreve.org/buy.php"]

61
.github/stale.yml vendored
View File

@@ -1,61 +0,0 @@
# Configuration for probot-stale - https://github.com/probot/stale
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 360
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
daysUntilClose: 30
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
onlyLabels: []
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
exemptLabels:
- pinned
- security
- "[Status] Maybe Later"
# Set to true to ignore issues in a project (defaults to false)
exemptProjects: true
# Set to true to ignore issues in a milestone (defaults to false)
exemptMilestones: true
# Set to true to ignore issues with an assignee (defaults to false)
exemptAssignees: true
# Label to use when marking as stale
staleLabel: wontfix
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
# Comment to post when removing the stale label.
# unmarkComment: >
# Your comment here.
# Comment to post when closing a stale Issue or Pull Request.
# closeComment: >
# Your comment here.
# Limit the number of actions per hour, from 1-30. Default is 30
limitPerRun: 30
# Limit to only `issues` or `pulls`
# only: issues
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
# pulls:
# daysUntilStale: 30
# markComment: >
# This pull request has been automatically marked as stale because it has not had
# recent activity. It will be closed if no further activity occurs. Thank you
# for your contributions.
# issues:
# exemptLabels:
# - confirmed

View File

@@ -7,25 +7,53 @@ jobs:
name: Build
runs-on: ubuntu-18.04
steps:
- name: Set up Go 1.20
uses: actions/setup-go@v2
with:
go-version: "1.20"
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v2
with:
clean: false
submodules: "recursive"
- run: |
git fetch --prune --unshallow --tags
- name: Set up Go 1.13
uses: actions/setup-go@v1
with:
go-version: 1.13
id: go
- name: Build and Release
uses: goreleaser/goreleaser-action@v4
with:
distribution: goreleaser
version: latest
args: release --clean --skip-validate
env:
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
with:
clean: false
submodules: 'recursive'
- run: |
git fetch --prune --unshallow --tags
- name: Get dependencies and build
run: |
go get github.com/rakyll/statik
export PATH=$PATH:~/go/bin/
statik -src=models -f
sudo apt-get update
sudo apt-get -y install gcc-mingw-w64-x86-64
sudo apt-get -y install gcc-arm-linux-gnueabihf libc6-dev-armhf-cross
sudo apt-get -y install gcc-aarch64-linux-gnu libc6-dev-arm64-cross
chmod +x ./build.sh
./build.sh -r b
- name: Upload binary files (windows_amd64)
uses: actions/upload-artifact@v2
with:
name: cloudreve_windows_amd64
path: release/cloudreve*windows_amd64.*
- name: Upload binary files (linux_amd64)
uses: actions/upload-artifact@v2
with:
name: cloudreve_linux_amd64
path: release/cloudreve*linux_amd64.*
- name: Upload binary files (linux_arm)
uses: actions/upload-artifact@v2
with:
name: cloudreve_linux_arm
path: release/cloudreve*linux_arm.*
- name: Upload binary files (linux_arm64)
uses: actions/upload-artifact@v2
with:
name: cloudreve_linux_arm64
path: release/cloudreve*linux_arm64.*

View File

@@ -1,57 +0,0 @@
name: Build and push docker image
on:
push:
tags:
- 3.* # triggered on every push with tag 3.*
workflow_dispatch: # or just on button clicked
jobs:
docker-build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- run: git fetch --prune --unshallow
- name: Setup Environments
id: envs
run: |
CLOUDREVE_LATEST_TAG=$(git describe --tags --abbrev=0)
DOCKER_IMAGE="cloudreve/cloudreve"
echo "RELEASE_VERSION=${GITHUB_REF#refs}"
TAGS="${DOCKER_IMAGE}:latest,${DOCKER_IMAGE}:${CLOUDREVE_LATEST_TAG}"
echo "CLOUDREVE_LATEST_TAG:${CLOUDREVE_LATEST_TAG}"
echo ::set-output name=tags::${TAGS}
- name: Setup QEMU Emulator
uses: docker/setup-qemu-action@master
with:
platforms: all
- name: Setup Docker Buildx Command
id: buildx
uses: docker/setup-buildx-action@master
- name: Login to Dockerhub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build Docker Image and Push
id: docker_build
uses: docker/build-push-action@v2
with:
push: true
builder: ${{ steps.buildx.outputs.name }}
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64,linux/arm/v7
tags: ${{ steps.envs.outputs.tags }}
- name: Update Docker Hub Description
uses: peter-evans/dockerhub-description@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
repository: cloudreve/cloudreve
short-description: ${{ github.event.repository.description }}
- name: Image Digest
run: echo ${{ steps.docker_build.outputs.digest }}

View File

@@ -2,34 +2,46 @@ name: Test
on:
pull_request:
branches:
branches:
- master
push:
branches: [master]
branches: [ master ]
jobs:
test:
name: Test
runs-on: ubuntu-18.04
steps:
- name: Set up Go 1.18
uses: actions/setup-go@v2
with:
go-version: "1.18"
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v2
with:
submodules: "recursive"
- name: Set up Go 1.13
uses: actions/setup-go@v1
with:
go-version: 1.13
id: go
- name: Build static files
run: |
mkdir assets/build
touch assets/build/test.html
- name: Check out code into the Go module directory
uses: actions/checkout@v2
with:
submodules: 'recursive'
- name: Test
run: go test -coverprofile=coverage.txt -covermode=atomic ./...
- name: Get dependencies
run: |
go get github.com/rakyll/statik
export PATH=$PATH:~/go/bin/
statik -src=models -f
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v2
- name: Test
run: go test -coverprofile=coverage.txt -covermode=atomic ./...
- name: Upload binary files (linux_arm)
uses: actions/upload-artifact@v2
with:
name: cloudreve_linux_arm
path: release/cloudreve*linux_arm.*
- name: Upload binary files (linux_arm64)
uses: actions/upload-artifact@v2
with:
name: cloudreve_linux_arm64
path: release/cloudreve*linux_arm64.*

3
.gitignore vendored
View File

@@ -8,7 +8,6 @@ cloudreve
*.db
*.bin
/release/
assets.zip
# Test binary, build with `go test -c`
*.test
@@ -29,5 +28,3 @@ version.lock
conf/conf.ini
/statik/
.vscode/
dist/

View File

@@ -1,121 +0,0 @@
env:
- CI=false
- GENERATE_SOURCEMAP=false
before:
hooks:
- go mod tidy
- sh -c "cd assets && rm -rf build && yarn install --network-timeout 1000000 && yarn run build && cd ../ && zip -r - assets/build >assets.zip"
builds:
-
env:
- CGO_ENABLED=0
binary: cloudreve
ldflags:
- -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion={{.Tag}}' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit={{.ShortCommit}}'
goos:
- linux
- windows
- darwin
goarch:
- amd64
- arm
- arm64
goarm:
- 5
- 6
- 7
ignore:
- goos: windows
goarm: 5
- goos: windows
goarm: 6
- goos: windows
goarm: 7
archives:
- format: tar.gz
# this name template makes the OS and Arch compatible with the results of uname.
name_template: >-
cloudreve_{{.Tag}}_{{- .Os }}_{{ .Arch }}
{{- if .Arm }}v{{ .Arm }}{{ end }}
# use zip for windows archives
format_overrides:
- goos: windows
format: zip
checksum:
name_template: 'checksums.txt'
snapshot:
name_template: "{{ incpatch .Version }}-next"
changelog:
sort: asc
filters:
exclude:
- '^docs:'
- '^test:'
release:
draft: true
prerelease: auto
target_commitish: '{{ .Commit }}'
name_template: "{{.Version}}"
dockers:
-
dockerfile: Dockerfile
use: buildx
build_flag_templates:
- "--platform=linux/amd64"
goos: linux
goarch: amd64
goamd64: v1
image_templates:
- "cloudreve/cloudreve:{{ .Tag }}-amd64"
-
dockerfile: Dockerfile
use: buildx
build_flag_templates:
- "--platform=linux/arm64"
goos: linux
goarch: arm64
image_templates:
- "cloudreve/cloudreve:{{ .Tag }}-arm64"
-
dockerfile: Dockerfile
use: buildx
build_flag_templates:
- "--platform=linux/arm/v6"
goos: linux
goarch: arm
goarm: '6'
image_templates:
- "cloudreve/cloudreve:{{ .Tag }}-armv6"
-
dockerfile: Dockerfile
use: buildx
build_flag_templates:
- "--platform=linux/arm/v7"
goos: linux
goarch: arm
goarm: '7'
image_templates:
- "cloudreve/cloudreve:{{ .Tag }}-armv7"
docker_manifests:
- name_template: "cloudreve/cloudreve:latest"
image_templates:
- "cloudreve/cloudreve:{{ .Tag }}-amd64"
- "cloudreve/cloudreve:{{ .Tag }}-arm64"
- "cloudreve/cloudreve:{{ .Tag }}-armv6"
- "cloudreve/cloudreve:{{ .Tag }}-armv7"
- name_template: "cloudreve/cloudreve:{{ .Tag }}"
image_templates:
- "cloudreve/cloudreve:{{ .Tag }}-amd64"
- "cloudreve/cloudreve:{{ .Tag }}-arm64"
- "cloudreve/cloudreve:{{ .Tag }}-armv6"
- "cloudreve/cloudreve:{{ .Tag }}-armv7"

30
.travis.yml Normal file
View File

@@ -0,0 +1,30 @@
language: go
go:
- 1.13.x
node_js: "12.16.3"
git:
depth: 1
install:
- go get github.com/rakyll/statik
before_script:
- statik -src=models -f
script:
- go test -coverprofile=coverage.txt -covermode=atomic ./...
after_success:
- bash <(curl -s https://codecov.io/bash)
before_deploy:
- sudo apt-get update
- sudo apt-get -y install gcc-mingw-w64-x86-64
- sudo apt-get -y install gcc-arm-linux-gnueabihf libc6-dev-armhf-cross
- sudo apt-get -y install gcc-aarch64-linux-gnu libc6-dev-arm64-cross
- chmod +x ./build.sh
- ./build.sh -r b
deploy:
provider: releases
api_key: $GITHUB_TOKEN
file_glob: true
file: release/*
draft: true
skip_cleanup: true
on:
tags: true

View File

@@ -1,17 +1,72 @@
FROM alpine:latest
# build frontend
FROM node:lts-buster AS fe-builder
WORKDIR /cloudreve
COPY cloudreve ./cloudreve
COPY ./assets /assets
RUN apk update \
&& apk add --no-cache tzdata \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& echo "Asia/Shanghai" > /etc/timezone \
&& chmod +x ./cloudreve \
&& mkdir -p /data/aria2 \
&& chmod -R 766 /data/aria2
WORKDIR /assets
EXPOSE 5212
VOLUME ["/cloudreve/uploads", "/cloudreve/avatar", "/data"]
# If encountered problems like JavaScript heap out of memory, please uncomment the following options
ENV NODE_OPTIONS --max_old_space_size=4096
ENTRYPOINT ["./cloudreve"]
# yarn repo connection is unstable, adjust the network timeout to 10 min.
RUN set -ex \
&& yarn install --network-timeout 600000 \
&& yarn run build
# build backend
FROM golang:1.15.1-alpine3.12 AS be-builder
ENV GO111MODULE on
COPY . /go/src/github.com/cloudreve/Cloudreve/v3
COPY --from=fe-builder /assets/build/ /go/src/github.com/cloudreve/Cloudreve/v3/assets/build/
WORKDIR /go/src/github.com/cloudreve/Cloudreve/v3
RUN set -ex \
&& apk upgrade \
&& apk add gcc libc-dev git \
&& export COMMIT_SHA=$(git rev-parse --short HEAD) \
&& export VERSION=$(git describe --tags) \
&& (cd && go get github.com/rakyll/statik) \
&& statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico -f \
&& go install -ldflags "-X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=${VERSION}' \
-X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=${COMMIT_SHA}'\
-w -s"
# build final image
FROM alpine:3.12 AS dist
LABEL maintainer="mritd <mritd@linux.com>"
# we use the Asia/Shanghai timezone by default, you can be modified
# by `docker build --build-arg=TZ=Other_Timezone ...`
ARG TZ="Asia/Shanghai"
ENV TZ ${TZ}
COPY --from=be-builder /go/bin/Cloudreve /cloudreve/cloudreve
COPY docker-bootstrap.sh /cloudreve/bootstrap.sh
RUN apk upgrade \
&& apk add bash tzdata aria2 \
&& ln -s /cloudreve/cloudreve /usr/bin/cloudreve \
&& ln -sf /usr/share/zoneinfo/${TZ} /etc/localtime \
&& echo ${TZ} > /etc/timezone \
&& rm -rf /var/cache/apk/* \
&& mkdir /etc/cloudreve \
&& ln -s /etc/cloudreve/cloureve.db /cloudreve/cloudreve.db \
&& ln -s /etc/cloudreve/conf.ini /cloudreve/conf.ini
# cloudreve use tcp 5212 port by default
EXPOSE 5212/tcp
# cloudreve stores all files(including executable file) in the `/cloudreve`
# directory by default; users should mount the configfile to the `/etc/cloudreve`
# directory by themselves for persistence considerations, and the data storage
# directory recommends using `/data` directory.
VOLUME /etc/cloudreve
VOLUME /data
ENTRYPOINT ["sh", "/cloudreve/bootstrap.sh"]

135
README.md
View File

@@ -1,5 +1,3 @@
[中文版本](https://github.com/cloudreve/Cloudreve/blob/master/README_zh-CN.md)
<h1 align="center">
<br>
<a href="https://cloudreve.org/" alt="logo" ><img src="https://raw.githubusercontent.com/cloudreve/frontend/master/public/static/img/logo192.png" width="150"/></a>
@@ -7,98 +5,131 @@
Cloudreve
<br>
</h1>
<h4 align="center">Self-hosted file management system with muilt-cloud support.</h4>
<h4 align="center">支持多家云存储驱动的公有云文件系统.</h4>
<p align="center">
<a href="https://github.com/cloudreve/Cloudreve/actions/workflows/test.yml">
<img src="https://img.shields.io/github/actions/workflow/status/cloudreve/Cloudreve/test.yml?branch=master&style=flat-square"
alt="GitHub Test Workflow">
<a href="https://travis-ci.com/github/cloudreve/Cloudreve/">
<img src="https://img.shields.io/travis/com/cloudreve/Cloudreve?style=flat-square"
alt="travis">
</a>
<a href="https://codecov.io/gh/cloudreve/Cloudreve"><img src="https://img.shields.io/codecov/c/github/cloudreve/Cloudreve?style=flat-square"></a>
<a href="https://goreportcard.com/report/github.com/cloudreve/Cloudreve">
<img src="https://goreportcard.com/badge/github.com/cloudreve/Cloudreve?style=flat-square">
</a>
<a href="https://github.com/cloudreve/Cloudreve/releases">
<img src="https://img.shields.io/github/v/release/cloudreve/Cloudreve?include_prereleases&style=flat-square" />
</a>
<a href="https://hub.docker.com/r/cloudreve/cloudreve">
<img src="https://img.shields.io/docker/image-size/cloudreve/cloudreve?style=flat-square"/>
<img src="https://img.shields.io/github/v/release/cloudreve/Cloudreve?include_prereleases&style=flat-square">
</a>
</p>
<p align="center">
<a href="https://cloudreve.org">Homepage</a> •
<a href="https://demo.cloudreve.org">Demo</a> •
<a href="https://forum.cloudreve.org/">Discussion</a> •
<a href="https://docs.cloudreve.org/v/en/">Documents</a> •
<a href="https://github.com/cloudreve/Cloudreve/releases">Download</a> •
<a href="https://t.me/cloudreve_official">Telegram Group</a>
<a href="#scroll-License">License</a>
<a href="https://demo.cloudreve.org">演示站</a> •
<a href="https://forum.cloudreve.org/">讨论社区</a> •
<a href="https://docs.cloudreve.org/">文档</a> •
<a href="https://github.com/cloudreve/Cloudreve/releases">下载</a> •
<a href="https://t.me/cloudreve_official">Telegram 群组</a> •
<a href="#scroll-许可证">许可证</a>
</p>
![Screenshot](https://raw.githubusercontent.com/cloudreve/docs/master/images/homepage.png)
## :sparkles: Features
## :sparkles: 特性
* :cloud: Support storing files into Local storage, Remote storage, Qiniu, Aliyun OSS, Tencent COS, Upyun, OneDrive, S3 compatible API.
* :outbox_tray: Upload/Download in directly transmission with speed limiting support.
* 💾 Integrate with Aria2 to download files offline, use multiple download nodes to share the load.
* 📚 Compress/Extract files, download files in batch.
* 💻 WebDAV support covering all storage providers.
* :zap:Drag&Drop to upload files or folders, with streaming upload processing.
* :card_file_box: Drag & Drop to manage your files.
* :family_woman_girl_boy: Multi-users with multi-groups.
* :link: Create share links for files and folders with expiration date.
* :eye_speech_bubble: Preview videos, images, audios, ePub files online; edit texts, Office documents online.
* :art: Customize theme colors, dark mode, PWA application, SPA, i18n.
* :rocket: All-In-One packing, with all features out-of-the-box.
* :cloud: 支持本机、从机、七牛、阿里云 OSS、腾讯云 COS、又拍云、OneDrive (包括世纪互联版) 作为存储端
* :outbox_tray: 上传/下载 支持客户端直传,支持下载限速
* 💾 可对接 Aria2 离线下载,可使用多个从机机点分担下载任务
* 📚 在线 压缩/解压缩、多文件打包下载
* 💻 覆盖全部存储策略的 WebDAV 协议支持
* :zap: 拖拽上传、目录上传、流式上传处理
* :card_file_box: 文件拖拽管理
* :family_woman_girl_boy: 多用户、用户组
* :link: 创建文件、目录的分享链接,可设定自动过期
* :eye_speech_bubble: 视频、图像、音频、文本、Office 文档在线预览
* :art: 自定义配色、黑暗模式、PWA 应用、全站单页应用
* :rocket: All-In-One 打包,开箱即用
* 🌈 ... ...
## :hammer_and_wrench: Deploy
## :hammer_and_wrench: 部署
Download the main binary for your target machine OS, CPU architecture and run it directly.
下载适用于您目标机器操作系统、CPU架构的主程序直接运行即可。
```shell
# Extract Cloudreve binary
# 解压程序包
tar -zxvf cloudreve_VERSION_OS_ARCH.tar.gz
# Grant execute permission
# 赋予执行权限
chmod +x ./cloudreve
# Start Cloudreve
# 启动 Cloudreve
./cloudreve
```
The above is a minimum deploy example, you can refer to [Getting started](https://docs.cloudreve.org/v/en/getting-started/install) for a completed deployment.
以上为最简单的部署示例,您可以参考 [文档 - 起步](https://docs.cloudreve.org/) 进行更为完善的部署。
## :gear: Build
## :gear: 构建
You need to have `Go >= 1.18`, `node.js`, `yarn`, `zip`, [goreleaser](https://goreleaser.com/intro/) and other necessary dependencies before you can build it yourself.
自行构建前需要拥有 `Go >= 1.13``yarn`等必要依赖。
#### Install goreleaser
```shell
go install github.com/goreleaser/goreleaser@latest
```
#### Clone the code
#### 克隆代码
```shell
git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git
```
#### Compile
#### 构建静态资源
```shell
goreleaser build --clean --single-target --snapshot
# 进入前端子模块
cd assets
# 安装依赖
yarn install
# 开始构建
yarn run build
```
## :alembic: Stacks
#### 嵌入静态资源
* [Go](https://golang.org/) + [Gin](https://github.com/gin-gonic/gin)
```shell
# 回到项目主目录
cd ../
# 安装 statik, 用于嵌入静态资源
go get github.com/rakyll/statik
# 开始嵌入
statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico -f
```
#### 编译项目
```shell
# 获得当前版本号、Commit
export COMMIT_SHA=$(git rev-parse --short HEAD)
export VERSION=$(git describe --tags)
# 开始编译
go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
```
你也可以使用项目根目录下的`build.sh`快速开始构建:
```shell
./build.sh [-a] [-c] [-b] [-r]
a - 构建静态资源
c - 编译二进制文件
b - 构建前端 + 编译二进制文件
r - 交叉编译构建用于release的版本
```
## :alembic: 技术栈
* [Go ](https://golang.org/) + [Gin](https://github.com/gin-gonic/gin)
* [React](https://github.com/facebook/react) + [Redux](https://github.com/reduxjs/redux) + [Material-UI](https://github.com/mui-org/material-ui)
## :scroll: License
## :scroll: 许可证
GPL V3
---
> GitHub [@HFO4](https://github.com/HFO4) &nbsp;&middot;&nbsp;
> Twitter [@abslant00](https://twitter.com/abslant00)

View File

@@ -1,104 +0,0 @@
[English Version](https://github.com/cloudreve/Cloudreve/blob/master/README.md)
<h1 align="center">
<br>
<a href="https://cloudreve.org/" alt="logo" ><img src="https://raw.githubusercontent.com/cloudreve/frontend/master/public/static/img/logo192.png" width="150"/></a>
<br>
Cloudreve
<br>
</h1>
<h4 align="center">支持多家云存储驱动的公有云文件系统.</h4>
<p align="center">
<a href="https://github.com/cloudreve/Cloudreve/actions/workflows/test.yml">
<img src="https://img.shields.io/github/actions/workflow/status/cloudreve/Cloudreve/test.yml?branch=master&style=flat-square"
alt="GitHub Test Workflow">
</a>
<a href="https://codecov.io/gh/cloudreve/Cloudreve"><img src="https://img.shields.io/codecov/c/github/cloudreve/Cloudreve?style=flat-square"></a>
<a href="https://goreportcard.com/report/github.com/cloudreve/Cloudreve">
<img src="https://goreportcard.com/badge/github.com/cloudreve/Cloudreve?style=flat-square">
</a>
<a href="https://github.com/cloudreve/Cloudreve/releases">
<img src="https://img.shields.io/github/v/release/cloudreve/Cloudreve?include_prereleases&style=flat-square" />
</a>
<a href="https://hub.docker.com/r/cloudreve/cloudreve">
<img src="https://img.shields.io/docker/image-size/cloudreve/cloudreve?style=flat-square"/>
</a>
</p>
<p align="center">
<a href="https://cloudreve.org">主页</a> •
<a href="https://demo.cloudreve.org">演示站</a> •
<a href="https://forum.cloudreve.org/">讨论社区</a> •
<a href="https://docs.cloudreve.org/">文档</a> •
<a href="https://github.com/cloudreve/Cloudreve/releases">下载</a> •
<a href="https://t.me/cloudreve_official">Telegram 群组</a> •
<a href="#scroll-许可证">许可证</a>
</p>
![Screenshot](https://raw.githubusercontent.com/cloudreve/docs/master/images/homepage.png)
## :sparkles: 特性
* :cloud: 支持本机、从机、七牛、阿里云 OSS、腾讯云 COS、又拍云、OneDrive (包括世纪互联版) 、S3兼容协议 作为存储端
* :outbox_tray: 上传/下载 支持客户端直传,支持下载限速
* 💾 可对接 Aria2 离线下载,可使用多个从机节点分担下载任务
* 📚 在线 压缩/解压缩、多文件打包下载
* 💻 覆盖全部存储策略的 WebDAV 协议支持
* :zap: 拖拽上传、目录上传、流式上传处理
* :card_file_box: 文件拖拽管理
* :family_woman_girl_boy: 多用户、用户组、多存储策略
* :link: 创建文件、目录的分享链接,可设定自动过期
* :eye_speech_bubble: 视频、图像、音频、 ePub 在线预览文本、Office 文档在线编辑
* :art: 自定义配色、黑暗模式、PWA 应用、全站单页应用、国际化支持
* :rocket: All-In-One 打包,开箱即用
* 🌈 ... ...
## :hammer_and_wrench: 部署
下载适用于您目标机器操作系统、CPU架构的主程序直接运行即可。
```shell
# 解压程序包
tar -zxvf cloudreve_VERSION_OS_ARCH.tar.gz
# 赋予执行权限
chmod +x ./cloudreve
# 启动 Cloudreve
./cloudreve
```
以上为最简单的部署示例,您可以参考 [文档 - 起步](https://docs.cloudreve.org/) 进行更为完善的部署。
## :gear: 构建
自行构建前需要拥有 `Go >= 1.18``node.js``yarn``zip`, [goreleaser](https://goreleaser.com/intro/) 等必要依赖。
#### 安装 goreleaser
```shell
go install github.com/goreleaser/goreleaser@latest
```
#### 克隆代码
```shell
git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git
```
#### 编译项目
```shell
goreleaser build --clean --single-target --snapshot
```
## :alembic: 技术栈
* [Go](https://golang.org/) + [Gin](https://github.com/gin-gonic/gin)
* [React](https://github.com/facebook/react) + [Redux](https://github.com/reduxjs/redux) + [Material-UI](https://github.com/mui-org/material-ui)
## :scroll: 许可证
GPL V3

2
assets

Submodule assets updated: 9f847f466c...8e5b5c2d29

Binary file not shown.

View File

@@ -1,432 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package embed provides access to files embedded in the running Go program.
//
// Go source files that import "embed" can use the //go:embed directive
// to initialize a variable of type string, []byte, or FS with the contents of
// files read from the package directory or subdirectories at compile time.
//
// For example, here are three ways to embed a file named hello.txt
// and then print its contents at run time.
//
// Embedding one file into a string:
//
// import _ "embed"
//
// //go:embed hello.txt
// var s string
// print(s)
//
// Embedding one file into a slice of bytes:
//
// import _ "embed"
//
// //go:embed hello.txt
// var b []byte
// print(string(b))
//
// Embedded one or more files into a file system:
//
// import "embed"
//
// //go:embed hello.txt
// var f embed.FS
// data, _ := f.ReadFile("hello.txt")
// print(string(data))
//
// # Directives
//
// A //go:embed directive above a variable declaration specifies which files to embed,
// using one or more path.Match patterns.
//
// The directive must immediately precede a line containing the declaration of a single variable.
// Only blank lines and // line comments are permitted between the directive and the declaration.
//
// The type of the variable must be a string type, or a slice of a byte type,
// or FS (or an alias of FS).
//
// For example:
//
// package server
//
// import "embed"
//
// // content holds our static web server content.
// //go:embed image/* template/*
// //go:embed html/index.html
// var content embed.FS
//
// The Go build system will recognize the directives and arrange for the declared variable
// (in the example above, content) to be populated with the matching files from the file system.
//
// The //go:embed directive accepts multiple space-separated patterns for
// brevity, but it can also be repeated, to avoid very long lines when there are
// many patterns. The patterns are interpreted relative to the package directory
// containing the source file. The path separator is a forward slash, even on
// Windows systems. Patterns may not contain . or .. or empty path elements,
// nor may they begin or end with a slash. To match everything in the current
// directory, use * instead of .. To allow for naming files with spaces in
// their names, patterns can be written as Go double-quoted or back-quoted
// string literals.
//
// If a pattern names a directory, all files in the subtree rooted at that directory are
// embedded (recursively), except that files with names beginning with . or _
// are excluded. So the variable in the above example is almost equivalent to:
//
// // content is our static web server content.
// //go:embed image template html/index.html
// var content embed.FS
//
// The difference is that image/* embeds image/.tempfile while image does not.
// Neither embeds image/dir/.tempfile.
//
// If a pattern begins with the prefix all:, then the rule for walking directories is changed
// to include those files beginning with . or _. For example, all:image embeds
// both image/.tempfile and image/dir/.tempfile.
//
// The //go:embed directive can be used with both exported and unexported variables,
// depending on whether the package wants to make the data available to other packages.
// It can only be used with variables at package scope, not with local variables.
//
// Patterns must not match files outside the package's module, such as .git/* or symbolic links.
// Patterns must not match files whose names include the special punctuation characters " * < > ? ` ' | / \ and :.
// Matches for empty directories are ignored. After that, each pattern in a //go:embed line
// must match at least one file or non-empty directory.
//
// If any patterns are invalid or have invalid matches, the build will fail.
//
// # Strings and Bytes
//
// The //go:embed line for a variable of type string or []byte can have only a single pattern,
// and that pattern can match only a single file. The string or []byte is initialized with
// the contents of that file.
//
// The //go:embed directive requires importing "embed", even when using a string or []byte.
// In source files that don't refer to embed.FS, use a blank import (import _ "embed").
//
// # File Systems
//
// For embedding a single file, a variable of type string or []byte is often best.
// The FS type enables embedding a tree of files, such as a directory of static
// web server content, as in the example above.
//
// FS implements the io/fs package's FS interface, so it can be used with any package that
// understands file systems, including net/http, text/template, and html/template.
//
// For example, given the content variable in the example above, we can write:
//
// http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(content))))
//
// template.ParseFS(content, "*.tmpl")
//
// # Tools
//
// To support tools that analyze Go packages, the patterns found in //go:embed lines
// are available in “go list” output. See the EmbedPatterns, TestEmbedPatterns,
// and XTestEmbedPatterns fields in the “go help list” output.
package bootstrap
import (
"errors"
"io"
"io/fs"
"time"
)
// An FS is a read-only collection of files, usually initialized with a //go:embed directive.
// When declared without a //go:embed directive, an FS is an empty file system.
//
// An FS is a read-only value, so it is safe to use from multiple goroutines
// simultaneously and also safe to assign values of type FS to each other.
//
// FS implements fs.FS, so it can be used with any package that understands
// file system interfaces, including net/http, text/template, and html/template.
//
// See the package documentation for more details about initializing an FS.
type FS struct {
// The compiler knows the layout of this struct.
// See cmd/compile/internal/staticdata's WriteEmbed.
//
// The files list is sorted by name but not by simple string comparison.
// Instead, each file's name takes the form "dir/elem" or "dir/elem/".
// The optional trailing slash indicates that the file is itself a directory.
// The files list is sorted first by dir (if dir is missing, it is taken to be ".")
// and then by base, so this list of files:
//
// p
// q/
// q/r
// q/s/
// q/s/t
// q/s/u
// q/v
// w
//
// is actually sorted as:
//
// p # dir=. elem=p
// q/ # dir=. elem=q
// w/ # dir=. elem=w
// q/r # dir=q elem=r
// q/s/ # dir=q elem=s
// q/v # dir=q elem=v
// q/s/t # dir=q/s elem=t
// q/s/u # dir=q/s elem=u
//
// This order brings directory contents together in contiguous sections
// of the list, allowing a directory read to use binary search to find
// the relevant sequence of entries.
files *[]file
}
// split splits the name into dir and elem as described in the
// comment in the FS struct above. isDir reports whether the
// final trailing slash was present, indicating that name is a directory.
func split(name string) (dir, elem string, isDir bool) {
if name[len(name)-1] == '/' {
isDir = true
name = name[:len(name)-1]
}
i := len(name) - 1
for i >= 0 && name[i] != '/' {
i--
}
if i < 0 {
return ".", name, isDir
}
return name[:i], name[i+1:], isDir
}
// trimSlash trims a trailing slash from name, if present,
// returning the possibly shortened name.
func trimSlash(name string) string {
if len(name) > 0 && name[len(name)-1] == '/' {
return name[:len(name)-1]
}
return name
}
var (
_ fs.ReadDirFS = FS{}
_ fs.ReadFileFS = FS{}
)
// A file is a single file in the FS.
// It implements fs.FileInfo and fs.DirEntry.
type file struct {
// The compiler knows the layout of this struct.
// See cmd/compile/internal/staticdata's WriteEmbed.
name string
data string
hash [16]byte // truncated SHA256 hash
}
var (
_ fs.FileInfo = (*file)(nil)
_ fs.DirEntry = (*file)(nil)
)
func (f *file) Name() string { _, elem, _ := split(f.name); return elem }
func (f *file) Size() int64 { return int64(len(f.data)) }
func (f *file) ModTime() time.Time { return time.Time{} }
func (f *file) IsDir() bool { _, _, isDir := split(f.name); return isDir }
func (f *file) Sys() any { return nil }
func (f *file) Type() fs.FileMode { return f.Mode().Type() }
func (f *file) Info() (fs.FileInfo, error) { return f, nil }
func (f *file) Mode() fs.FileMode {
if f.IsDir() {
return fs.ModeDir | 0555
}
return 0444
}
// dotFile is a file for the root directory,
// which is omitted from the files list in a FS.
var dotFile = &file{name: "./"}
// lookup returns the named file, or nil if it is not present.
func (f FS) lookup(name string) *file {
if !fs.ValidPath(name) {
// The compiler should never emit a file with an invalid name,
// so this check is not strictly necessary (if name is invalid,
// we shouldn't find a match below), but it's a good backstop anyway.
return nil
}
if name == "." {
return dotFile
}
if f.files == nil {
return nil
}
// Binary search to find where name would be in the list,
// and then check if name is at that position.
dir, elem, _ := split(name)
files := *f.files
i := sortSearch(len(files), func(i int) bool {
idir, ielem, _ := split(files[i].name)
return idir > dir || idir == dir && ielem >= elem
})
if i < len(files) && trimSlash(files[i].name) == name {
return &files[i]
}
return nil
}
// readDir returns the list of files corresponding to the directory dir.
func (f FS) readDir(dir string) []file {
if f.files == nil {
return nil
}
// Binary search to find where dir starts and ends in the list
// and then return that slice of the list.
files := *f.files
i := sortSearch(len(files), func(i int) bool {
idir, _, _ := split(files[i].name)
return idir >= dir
})
j := sortSearch(len(files), func(j int) bool {
jdir, _, _ := split(files[j].name)
return jdir > dir
})
return files[i:j]
}
// Open opens the named file for reading and returns it as an fs.File.
//
// The returned file implements io.Seeker when the file is not a directory.
func (f FS) Open(name string) (fs.File, error) {
file := f.lookup(name)
if file == nil {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
if file.IsDir() {
return &openDir{file, f.readDir(name), 0}, nil
}
return &openFile{file, 0}, nil
}
// ReadDir reads and returns the entire named directory.
func (f FS) ReadDir(name string) ([]fs.DirEntry, error) {
file, err := f.Open(name)
if err != nil {
return nil, err
}
dir, ok := file.(*openDir)
if !ok {
return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("not a directory")}
}
list := make([]fs.DirEntry, len(dir.files))
for i := range list {
list[i] = &dir.files[i]
}
return list, nil
}
// ReadFile reads and returns the content of the named file.
func (f FS) ReadFile(name string) ([]byte, error) {
file, err := f.Open(name)
if err != nil {
return nil, err
}
ofile, ok := file.(*openFile)
if !ok {
return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("is a directory")}
}
return []byte(ofile.f.data), nil
}
// An openFile is a regular file open for reading.
type openFile struct {
f *file // the file itself
offset int64 // current read offset
}
var (
_ io.Seeker = (*openFile)(nil)
)
func (f *openFile) Close() error { return nil }
func (f *openFile) Stat() (fs.FileInfo, error) { return f.f, nil }
func (f *openFile) Read(b []byte) (int, error) {
if f.offset >= int64(len(f.f.data)) {
return 0, io.EOF
}
if f.offset < 0 {
return 0, &fs.PathError{Op: "read", Path: f.f.name, Err: fs.ErrInvalid}
}
n := copy(b, f.f.data[f.offset:])
f.offset += int64(n)
return n, nil
}
func (f *openFile) Seek(offset int64, whence int) (int64, error) {
switch whence {
case 0:
// offset += 0
case 1:
offset += f.offset
case 2:
offset += int64(len(f.f.data))
}
if offset < 0 || offset > int64(len(f.f.data)) {
return 0, &fs.PathError{Op: "seek", Path: f.f.name, Err: fs.ErrInvalid}
}
f.offset = offset
return offset, nil
}
// An openDir is a directory open for reading.
type openDir struct {
f *file // the directory file itself
files []file // the directory contents
offset int // the read offset, an index into the files slice
}
func (d *openDir) Close() error { return nil }
func (d *openDir) Stat() (fs.FileInfo, error) { return d.f, nil }
func (d *openDir) Read([]byte) (int, error) {
return 0, &fs.PathError{Op: "read", Path: d.f.name, Err: errors.New("is a directory")}
}
func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) {
n := len(d.files) - d.offset
if n == 0 {
if count <= 0 {
return nil, nil
}
return nil, io.EOF
}
if count > 0 && n > count {
n = count
}
list := make([]fs.DirEntry, n)
for i := range list {
list[i] = &d.files[d.offset+i]
}
d.offset += n
return list, nil
}
// sortSearch is like sort.Search, avoiding an import.
func sortSearch(n int, f func(int) bool) int {
// Define f(-1) == false and f(n) == true.
// Invariant: f(i-1) == false, f(j) == true.
i, j := 0, n
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if !f(h) {
i = h + 1 // preserves f(i-1) == false
} else {
j = h // preserves f(j) == true
}
}
// i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i.
return i
}

View File

@@ -1,75 +0,0 @@
package bootstrap
import (
"archive/zip"
"crypto/sha256"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/pkg/errors"
"io"
"io/fs"
"sort"
"strings"
)
func NewFS(zipContent string) fs.FS {
zipReader, err := zip.NewReader(strings.NewReader(zipContent), int64(len(zipContent)))
if err != nil {
util.Log().Panic("Static resource is not a valid zip file: %s", err)
}
var files []file
err = fs.WalkDir(zipReader, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return errors.Errorf("无法获取[%s]的信息, %s, 跳过...", path, err)
}
if path == "." {
return nil
}
var f file
if d.IsDir() {
f.name = path + "/"
} else {
f.name = path
rc, err := zipReader.Open(path)
if err != nil {
return errors.Errorf("无法打开文件[%s], %s, 跳过...", path, err)
}
defer rc.Close()
data, err := io.ReadAll(rc)
if err != nil {
return errors.Errorf("无法读取文件[%s], %s, 跳过...", path, err)
}
f.data = string(data)
hash := sha256.Sum256(data)
for i := range f.hash {
f.hash[i] = ^hash[i]
}
}
files = append(files, f)
return nil
})
if err != nil {
util.Log().Panic("初始化静态资源失败: %s", err)
}
sort.Slice(files, func(i, j int) bool {
fi, fj := files[i], files[j]
di, ei, _ := split(fi.name)
dj, ej, _ := split(fj.name)
if di != dj {
return di < dj
}
return ei < ej
})
var embedFS FS
embedFS.files = &files
return embedFS
}

View File

@@ -12,13 +12,11 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
"io/fs"
)
// Init 初始化启动
func Init(path string, statics fs.FS) {
func Init(path string) {
InitApplication()
conf.Init(path)
// Debug 关闭时,切换为生产模式
@@ -39,7 +37,7 @@ func Init(path string, statics fs.FS) {
{
"both",
func() {
cache.Init(conf.SystemConfig.Mode == "slave")
cache.Init()
},
},
{
@@ -81,7 +79,7 @@ func Init(path string, statics fs.FS) {
{
"master",
func() {
InitStatic(statics)
InitStatic()
},
},
{
@@ -96,16 +94,19 @@ func Init(path string, statics fs.FS) {
auth.Init()
},
},
{
"master",
func() {
wopi.Init()
},
},
}
for _, dependency := range dependencies {
if dependency.mode == conf.SystemConfig.Mode || dependency.mode == "both" {
switch dependency.mode {
case "master":
if conf.SystemConfig.Mode == "master" {
dependency.factory()
}
case "slave":
if conf.SystemConfig.Mode == "slave" {
dependency.factory()
}
default:
dependency.factory()
}
}

View File

@@ -10,9 +10,9 @@ func RunScript(name string) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := invoker.RunDBScript(name, ctx); err != nil {
util.Log().Error("Failed to execute database script: %s", err)
util.Log().Error("数据库脚本执行失败: %s", err)
return
}
util.Log().Info("Finish executing database script %q.", name)
util.Log().Info("数据库脚本 [%s] 执行完毕", name)
}

View File

@@ -1,19 +1,17 @@
package bootstrap
import (
"bufio"
"encoding/json"
"io"
"io/fs"
"io/ioutil"
"net/http"
"path/filepath"
"github.com/pkg/errors"
"path"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
_ "github.com/cloudreve/Cloudreve/v3/statik"
"github.com/gin-contrib/static"
"github.com/rakyll/statik/fs"
)
const StaticFolder = "statics"
@@ -37,100 +35,124 @@ func (b *GinFS) Open(name string) (http.File, error) {
// Exists 文件是否存在
func (b *GinFS) Exists(prefix string, filepath string) bool {
if _, err := b.FS.Open(filepath); err != nil {
return false
}
return true
}
// InitStatic 初始化静态资源文件
func InitStatic(statics fs.FS) {
func InitStatic() {
var err error
if util.Exists(util.RelativePath(StaticFolder)) {
util.Log().Info("Folder with name \"statics\" already exists, it will be used to serve static files.")
util.Log().Info("检测到 statics 目录存在,将使用此目录下的静态资源文件")
StaticFS = static.LocalFile(util.RelativePath("statics"), false)
} else {
// 初始化静态资源
embedFS, err := fs.Sub(statics, "assets/build")
// 检查静态资源的版本
f, err := StaticFS.Open("version.json")
if err != nil {
util.Log().Panic("Failed to initialize static resources: %s", err)
util.Log().Warning("静态资源版本标识文件不存在,请重新构建或删除 statics 目录")
return
}
StaticFS = &GinFS{
FS: http.FS(embedFS),
b, err := ioutil.ReadAll(f)
if err != nil {
util.Log().Warning("无法读取静态资源文件版本,请重新构建或删除 statics 目录")
return
}
var v staticVersion
if err := json.Unmarshal(b, &v); err != nil {
util.Log().Warning("无法解析静态资源文件版本, %s", err)
return
}
staticName := "cloudreve-frontend"
if conf.IsPro == "true" {
staticName += "-pro"
}
if v.Name != staticName {
util.Log().Warning("静态资源版本不匹配,请重新构建或删除 statics 目录")
return
}
if v.Version != conf.RequiredStaticVersion {
util.Log().Warning("静态资源版本不匹配 [当前 %s, 需要: %s],请重新构建或删除 statics 目录", v.Version, conf.RequiredStaticVersion)
return
}
} else {
StaticFS = &GinFS{}
StaticFS.(*GinFS).FS, err = fs.New()
if err != nil {
util.Log().Panic("无法初始化静态资源, %s", err)
}
}
// 检查静态资源的版本
f, err := StaticFS.Open("version.json")
if err != nil {
util.Log().Warning("Missing version identifier file in static resources, please delete \"statics\" folder and rebuild it.")
return
}
b, err := io.ReadAll(f)
if err != nil {
util.Log().Warning("Failed to read version identifier file in static resources, please delete \"statics\" folder and rebuild it.")
return
}
var v staticVersion
if err := json.Unmarshal(b, &v); err != nil {
util.Log().Warning("Failed to parse version identifier file in static resources: %s", err)
return
}
staticName := "cloudreve-frontend"
if conf.IsPro == "true" {
staticName += "-pro"
}
if v.Name != staticName {
util.Log().Warning("Static resource version mismatch, please delete \"statics\" folder and rebuild it.")
return
}
if v.Version != conf.RequiredStaticVersion {
util.Log().Warning("Static resource version mismatch [Current %s, Desired: %s]please delete \"statics\" folder and rebuild it.", v.Version, conf.RequiredStaticVersion)
return
}
}
// Eject 抽离内置静态资源
func Eject(statics fs.FS) {
// 初始化静态资源
embedFS, err := fs.Sub(statics, "assets/build")
func Eject() {
staticFS, err := fs.New()
if err != nil {
util.Log().Panic("Failed to initialize static resources: %s", err)
util.Log().Panic("无法初始化静态资源, %s", err)
}
var walk func(relPath string, d fs.DirEntry, err error) error
walk = func(relPath string, d fs.DirEntry, err error) error {
root, err := staticFS.Open("/")
if err != nil {
util.Log().Panic("根目录不存在, %s", err)
}
var walk func(relPath string, object http.File)
walk = func(relPath string, object http.File) {
stat, err := object.Stat()
if err != nil {
return errors.Errorf("Failed to read info of %q: %s, skipping...", relPath, err)
util.Log().Error("无法获取[%s]的信息, %s, 跳过...", relPath, err)
return
}
if !d.IsDir() {
if !stat.IsDir() {
// 写入文件
out, err := util.CreatNestedFile(filepath.Join(util.RelativePath(""), StaticFolder, relPath))
out, err := util.CreatNestedFile(util.RelativePath(StaticFolder + relPath))
defer out.Close()
if err != nil {
return errors.Errorf("Failed to create file %q: %s, skipping...", relPath, err)
util.Log().Error("无法创建文件[%s], %s, 跳过...", relPath, err)
return
}
util.Log().Info("Ejecting %q...", relPath)
obj, _ := embedFS.Open(relPath)
if _, err := io.Copy(out, bufio.NewReader(obj)); err != nil {
return errors.Errorf("Cannot write file %q: %s, skipping...", relPath, err)
util.Log().Info("导出 [%s]...", relPath)
if _, err := io.Copy(out, object); err != nil {
util.Log().Error("无法写入文件[%s], %s, 跳过...", relPath, err)
return
}
} else {
// 列出目录
objects, err := object.Readdir(0)
if err != nil {
util.Log().Error("无法步入子目录[%s], %s, 跳过...", relPath, err)
return
}
// 递归遍历子目录
for _, newObject := range objects {
newPath := path.Join(relPath, newObject.Name())
newRoot, err := staticFS.Open(newPath)
if err != nil {
util.Log().Error("无法打开对象[%s], %s, 跳过...", newPath, err)
continue
}
walk(newPath, newRoot)
}
}
return nil
}
// util.Log().Info("开始导出内置静态资源...")
err = fs.WalkDir(embedFS, ".", walk)
if err != nil {
util.Log().Error("Error occurs while ejecting static resources: %s", err)
return
}
util.Log().Info("Finish ejecting static resources.")
util.Log().Info("开始导出内置静态资源...")
walk("/", root)
util.Log().Info("内置静态资源导出完成")
}

133
build.sh Executable file
View File

@@ -0,0 +1,133 @@
#!/bin/bash
REPO=$(cd $(dirname $0); pwd)
COMMIT_SHA=$(git rev-parse --short HEAD)
VERSION=$(git describe --tags)
ASSETS="false"
BINARY="false"
RELEASE="false"
debugInfo () {
echo "Repo: $REPO"
echo "Build assets: $ASSETS"
echo "Build binary: $BINARY"
echo "Release: $RELEASE"
echo "Version: $VERSION"
echo "Commit: $COMMIT_SHA"
}
buildAssets () {
cd $REPO
rm -rf assets/build
rm -f statik/statik.go
export CI=false
cd $REPO/assets
yarn install
yarn run build
if ! [ -x "$(command -v statik)" ]; then
export CGO_ENABLED=0
go get github.com/rakyll/statik
fi
cd $REPO
statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico,*.ttf -f
}
buildBinary () {
cd $REPO
go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
}
_build() {
local osarch=$1
IFS=/ read -r -a arr <<<"$osarch"
os="${arr[0]}"
arch="${arr[1]}"
gcc="${arr[2]}"
# Go build to build the binary.
export GOOS=$os
export GOARCH=$arch
export CC=$gcc
export CGO_ENABLED=1
if [ -n "$VERSION" ]; then
out="release/cloudreve_${VERSION}_${os}_${arch}"
else
out="release/cloudreve_${COMMIT_SHA}_${os}_${arch}"
fi
go build -a -o "${out}" -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
if [ "$os" = "windows" ]; then
mv $out release/cloudreve.exe
zip -j -q "${out}.zip" release/cloudreve.exe
rm -f "release/cloudreve.exe"
else
mv $out release/cloudreve
tar -zcvf "${out}.tar.gz" -C release cloudreve
rm -f "release/cloudreve"
fi
}
release(){
cd $REPO
## List of architectures and OS to test coss compilation.
SUPPORTED_OSARCH="linux/amd64/gcc linux/arm/arm-linux-gnueabihf-gcc windows/amd64/x86_64-w64-mingw32-gcc linux/arm64/aarch64-linux-gnu-gcc"
echo "Release builds for OS/Arch/CC: ${SUPPORTED_OSARCH}"
for each_osarch in ${SUPPORTED_OSARCH}; do
_build "${each_osarch}"
done
}
usage() {
echo "Usage: $0 [-a] [-c] [-b] [-r]" 1>&2;
exit 1;
}
while getopts "bacr:d" o; do
case "${o}" in
b)
ASSETS="true"
BINARY="true"
;;
a)
ASSETS="true"
;;
c)
BINARY="true"
;;
r)
ASSETS="true"
RELEASE="true"
;;
d)
DEBUG="true"
;;
*)
usage
;;
esac
done
shift $((OPTIND-1))
if [ "$DEBUG" = "true" ]; then
debugInfo
fi
if [ "$ASSETS" = "true" ]; then
buildAssets
fi
if [ "$BINARY" = "true" ]; then
buildBinary
fi
if [ "$RELEASE" = "true" ]; then
release
fi

15
docker-bootstrap.sh Normal file
View File

@@ -0,0 +1,15 @@
#!/bin/sh
GREEN='\033[0;32m'
RESET='\033[0m'
if [ ! -f /etc/cloudreve/aria2c.conf ]; then
echo -e "[${GREEN}aria2c${RESET}] aria2c config not found. Generating..."
secret=$(tr -dc A-Za-z0-9 </dev/urandom | head -c 13)
echo -e "[${GREEN}aria2c${RESET}] Generated port: 6800, secret: $secret"
cat <<EOF > /etc/cloudreve/aria2c.conf
enable-rpc=true
rpc-listen-port=6800
rpc-secret=$secret
EOF
fi
aria2c --conf-path /etc/cloudreve/aria2c.conf -D
cloudreve

View File

@@ -1,33 +0,0 @@
version: "3.8"
services:
cloudreve:
container_name: cloudreve
image: cloudreve/cloudreve:latest
restart: unless-stopped
ports:
- "5212:5212"
volumes:
- temp_data:/data
- ./cloudreve/uploads:/cloudreve/uploads
- ./cloudreve/conf.ini:/cloudreve/conf.ini
- ./cloudreve/cloudreve.db:/cloudreve/cloudreve.db
- ./cloudreve/avatar:/cloudreve/avatar
depends_on:
- aria2
aria2:
container_name: aria2
image: p3terx/aria2-pro # third party image, please keep notice what you are doing
restart: unless-stopped
environment:
- RPC_SECRET=your_aria_rpc_token # aria rpc token, customize your own
- RPC_PORT=6800
volumes:
- ./aria2/config:/config
- temp_data:/data
volumes:
temp_data:
driver: local
driver_opts:
type: none
device: $PWD/data
o: bind

163
go.mod
View File

@@ -1,172 +1,47 @@
module github.com/cloudreve/Cloudreve/v3
go 1.18
go 1.13
require (
github.com/DATA-DOG/go-sqlmock v1.3.3
github.com/HFO4/aliyun-oss-go-sdk v2.2.3+incompatible
github.com/aliyun/aliyun-oss-go-sdk v2.0.5+incompatible
github.com/aws/aws-sdk-go v1.31.5
github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499
github.com/fatih/color v1.9.0
github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect
github.com/duo-labs/webauthn v0.0.0-20191119193225-4bf9a0f776d4
github.com/fatih/color v1.7.0
github.com/gin-contrib/cors v1.3.0
github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/sessions v0.0.1
github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2
github.com/gin-gonic/gin v1.8.1
github.com/glebarez/go-sqlite v1.20.3
github.com/gin-gonic/gin v1.5.0
github.com/go-ini/ini v1.50.0
github.com/go-mail/mail v2.3.1+incompatible
github.com/go-playground/validator/v10 v10.11.0
github.com/gofrs/uuid v4.0.0+incompatible
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/go-querystring v1.0.0
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/go-version v1.3.0
github.com/gorilla/websocket v1.4.1
github.com/hashicorp/go-version v1.2.0
github.com/jinzhu/gorm v1.9.11
github.com/juju/ratelimit v1.0.1
github.com/mholt/archiver/v4 v4.0.0-alpha.6
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.2.0
github.com/qiniu/go-sdk/v7 v7.11.1
github.com/qiniu/api.v7/v7 v7.4.0
github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1
github.com/rakyll/statik v0.1.7
github.com/robfig/cron/v3 v3.0.1
github.com/smartystreets/goconvey v1.6.4 // indirect
github.com/speps/go-hashids v2.0.0+incompatible
github.com/stretchr/testify v1.7.2
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.393
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf v1.0.393
github.com/stretchr/testify v1.5.1
github.com/tencentcloud/tencentcloud-sdk-go v3.0.125+incompatible
github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac
github.com/upyun/go-sdk v2.1.0+incompatible
golang.org/x/image v0.0.0-20211028202545-6944b10bf410
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
)
require (
cloud.google.com/go v0.81.0 // indirect
github.com/andybalholm/brotli v1.0.4 // indirect
github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bgentry/speakeasy v0.1.0 // indirect
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect
github.com/cespare/xxhash/v2 v2.1.1 // indirect
github.com/cloudflare/cfssl v1.6.1 // indirect
github.com/cncf/udpa/go v0.0.0-20210322005330-6414d713912e // indirect
github.com/coreos/go-semver v0.3.0 // indirect
github.com/coreos/go-systemd/v22 v22.3.2 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/denisenkom/go-mssqldb v0.0.0-20190515213511-eb9f6a1743f3 // indirect
github.com/dsnet/compress v0.0.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d // indirect
github.com/envoyproxy/protoc-gen-validate v0.6.1 // indirect
github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect
github.com/fullstorydev/grpcurl v1.8.1 // indirect
github.com/fxamacker/cbor/v2 v2.4.0 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.9.8 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.1.0 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/mock v1.5.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/certificate-transparency-go v1.1.2-0.20210511102531-373a877eec92 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jhump/protoreflect v1.8.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jmespath/go-jmespath v0.3.0 // indirect
github.com/jonboulle/clockwork v0.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.15.1 // indirect
github.com/klauspost/pgzip v1.2.5 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/lib/pq v1.10.3 // indirect
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/mattn/go-runewidth v0.0.12 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mozillazg/go-httpheader v0.2.1 // indirect
github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/pelletier/go-toml/v2 v2.0.2 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.10.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.24.0 // indirect
github.com/prometheus/procfs v0.6.0 // indirect
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/satori/go.uuid v1.2.0 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/cobra v1.1.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.2.0 // indirect
github.com/therootcompany/xz v1.0.1 // indirect
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
github.com/ugorji/go/codec v1.2.7 // indirect
github.com/ulikunitz/xz v0.5.10 // indirect
github.com/urfave/cli v1.22.5 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
go.etcd.io/bbolt v1.3.5 // indirect
go.etcd.io/etcd/api/v3 v3.5.0-alpha.0 // indirect
go.etcd.io/etcd/client/v2 v2.305.0-alpha.0 // indirect
go.etcd.io/etcd/client/v3 v3.5.0-alpha.0 // indirect
go.etcd.io/etcd/etcdctl/v3 v3.5.0-alpha.0 // indirect
go.etcd.io/etcd/pkg/v3 v3.5.0-alpha.0 // indirect
go.etcd.io/etcd/raft/v3 v3.5.0-alpha.0 // indirect
go.etcd.io/etcd/server/v3 v3.5.0-alpha.0 // indirect
go.etcd.io/etcd/tests/v3 v3.5.0-alpha.0 // indirect
go.etcd.io/etcd/v3 v3.5.0-alpha.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.16.0 // indirect
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
golang.org/x/mod v0.4.2 // indirect
golang.org/x/net v0.0.0-20220630215102-69896b714898 // indirect
golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.4.0 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20210510173355-fb37daa5cd7a // indirect
google.golang.org/grpc v1.37.0 // indirect
google.golang.org/protobuf v1.28.0 // indirect
golang.org/x/text v0.3.6
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect
gopkg.in/go-playground/validator.v9 v9.29.1
gopkg.in/ini.v1 v1.51.0 // indirect
gopkg.in/mail.v2 v2.3.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.2 // indirect
modernc.org/mathutil v1.5.0 // indirect
modernc.org/memory v1.5.0 // indirect
modernc.org/sqlite v1.20.3 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect
)

1261
go.sum

File diff suppressed because it is too large Load Diff

113
main.go
View File

@@ -1,19 +1,9 @@
package main
import (
"context"
_ "embed"
"flag"
"io/fs"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/cloudreve/Cloudreve/v3/bootstrap"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/routers"
@@ -25,32 +15,18 @@ var (
scriptName string
)
//go:embed assets.zip
var staticZip string
var staticFS fs.FS
func init() {
flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "Path to the config file.")
flag.BoolVar(&isEject, "eject", false, "Eject all embedded static files.")
flag.StringVar(&scriptName, "database-script", "", "Name of database util script.")
flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "配置文件路径")
flag.BoolVar(&isEject, "eject", false, "导出内置静态资源")
flag.StringVar(&scriptName, "database-script", "", "运行内置数据库助手脚本")
flag.Parse()
staticFS = bootstrap.NewFS(staticZip)
bootstrap.Init(confPath, staticFS)
bootstrap.Init(confPath)
}
func main() {
// 关闭数据库连接
defer func() {
if model.DB != nil {
model.DB.Close()
}
}()
if isEject {
// 开始导出内置静态资源文件
bootstrap.Eject(staticFS)
bootstrap.Eject()
return
}
@@ -61,82 +37,29 @@ func main() {
}
api := routers.InitRouter()
api.TrustedPlatform = conf.SystemConfig.ProxyHeader
server := &http.Server{Handler: api}
// 收到信号后关闭服务器
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
sig := <-sigChan
util.Log().Info("Signal %s received, shutting down server...", sig)
ctx := context.Background()
if conf.SystemConfig.GracePeriod != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second)
defer cancel()
}
err := server.Shutdown(ctx)
if err != nil {
util.Log().Error("Failed to shutdown server: %s", err)
}
}()
// 如果启用了SSL
if conf.SSLConfig.CertPath != "" {
util.Log().Info("Listening to %q", conf.SSLConfig.Listen)
server.Addr = conf.SSLConfig.Listen
if err := server.ListenAndServeTLS(conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil {
util.Log().Error("Failed to listen to %q: %s", conf.SSLConfig.Listen, err)
return
}
go func() {
util.Log().Info("开始监听 %s", conf.SSLConfig.Listen)
if err := api.RunTLS(conf.SSLConfig.Listen,
conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SSLConfig.Listen, err)
}
}()
}
// 如果启用了Unix
if conf.UnixConfig.Listen != "" {
// delete socket file before listening
if _, err := os.Stat(conf.UnixConfig.Listen); err == nil {
if err = os.Remove(conf.UnixConfig.Listen); err != nil {
util.Log().Error("Failed to delete socket file: %s", err)
return
}
}
util.Log().Info("Listening to %q", conf.UnixConfig.Listen)
if err := RunUnix(server); err != nil {
util.Log().Error("Failed to listen to %q: %s", conf.UnixConfig.Listen, err)
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
if err := api.RunUnix(conf.UnixConfig.Listen); err != nil {
util.Log().Error("无法监听[%s]%s", conf.UnixConfig.Listen, err)
}
return
}
util.Log().Info("Listening to %q", conf.SystemConfig.Listen)
server.Addr = conf.SystemConfig.Listen
if err := server.ListenAndServe(); err != nil {
util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err)
util.Log().Info("开始监听 %s", conf.SystemConfig.Listen)
if err := api.Run(conf.SystemConfig.Listen); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SystemConfig.Listen, err)
}
}
func RunUnix(server *http.Server) error {
listener, err := net.Listen("unix", conf.UnixConfig.Listen)
if err != nil {
return err
}
defer listener.Close()
defer os.Remove(conf.UnixConfig.Listen)
if conf.UnixConfig.Perm > 0 {
err = os.Chmod(conf.UnixConfig.Listen, os.FileMode(conf.UnixConfig.Perm))
if err != nil {
util.Log().Warning(
"Failed to set permission to %q for socket file %q: %s",
conf.UnixConfig.Perm,
conf.UnixConfig.Listen,
err,
)
}
}
return server.Serve(listener)
}

View File

@@ -5,25 +5,20 @@ import (
"context"
"crypto/md5"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"io/ioutil"
"net/http"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
const (
CallbackFailedStatusCode = http.StatusUnauthorized
"github.com/qiniu/api.v7/v7/auth/qbox"
)
// SignRequired 验证请求签名
@@ -122,60 +117,48 @@ func WebDAVAuth() gin.HandlerFunc {
}
}
// 对上传会话进行验证
func UseUploadSession(policyType string) gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp := uploadCallbackCheck(c, policyType)
if resp.Code != 0 {
c.JSON(CallbackFailedStatusCode, resp)
c.Abort()
return
}
c.Next()
}
}
// uploadCallbackCheck 对上传回调请求的 callback key 进行验证,如果成功则返回上传用户
func uploadCallbackCheck(c *gin.Context, policyType string) serializer.Response {
func uploadCallbackCheck(c *gin.Context) (serializer.Response, *model.User) {
// 验证 Callback Key
sessionID := c.Param("sessionID")
if sessionID == "" {
return serializer.ParamErr("Session ID cannot be empty", nil)
callbackKey := c.Param("key")
if callbackKey == "" {
return serializer.ParamErr("Callback Key 不能为空", nil), nil
}
callbackSessionRaw, exist := cache.Get(filesystem.UploadSessionCachePrefix + sessionID)
callbackSessionRaw, exist := cache.Get("callback_" + callbackKey)
if !exist {
return serializer.Err(serializer.CodeUploadSessionExpired, "上传会话不存在或已过期", nil)
return serializer.ParamErr("回调会话不存在或已过期", nil), nil
}
callbackSession := callbackSessionRaw.(serializer.UploadSession)
c.Set(filesystem.UploadSessionCtx, &callbackSession)
if callbackSession.Policy.Type != policyType {
return serializer.Err(serializer.CodePolicyNotAllowed, "", nil)
}
c.Set("callbackSession", &callbackSession)
// 清理回调会话
_ = cache.Deletes([]string{sessionID}, filesystem.UploadSessionCachePrefix)
_ = cache.Deletes([]string{callbackKey}, "callback_")
// 查找用户
user, err := model.GetActiveUserByID(callbackSession.UID)
if err != nil {
return serializer.Err(serializer.CodeUserNotFound, "", err)
return serializer.Err(serializer.CodeCheckLogin, "找不到用户", err), nil
}
c.Set(filesystem.UserCtx, &user)
return serializer.Response{}
c.Set("user", &user)
return serializer.Response{}, &user
}
// RemoteCallbackAuth 远程回调签名验证
func RemoteCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp, user := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(200, resp)
c.Abort()
return
}
// 验证签名
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
authInstance := auth.HMACAuth{SecretKey: []byte(session.Policy.SecretKey)}
authInstance := auth.HMACAuth{SecretKey: []byte(user.Policy.SecretKey)}
if err := auth.CheckRequest(authInstance, c.Request); err != nil {
c.JSON(CallbackFailedStatusCode, serializer.Err(serializer.CodeCredentialInvalid, err.Error(), err))
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, err.Error(), err))
c.Abort()
return
}
@@ -188,20 +171,25 @@ func RemoteCallbackAuth() gin.HandlerFunc {
// QiniuCallbackAuth 七牛回调签名验证
func QiniuCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
// 验证回调是否来自qiniu
mac := qbox.NewMac(session.Policy.AccessKey, session.Policy.SecretKey)
ok, err := mac.VerifyCallback(c.Request)
if err != nil {
util.Log().Debug("Failed to verify callback request: %s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."})
// 验证key并查找用户
resp, user := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort()
return
}
// 验证回调是否来自qiniu
mac := qbox.NewMac(user.Policy.AccessKey, user.Policy.SecretKey)
ok, err := mac.VerifyCallback(c.Request)
if err != nil {
util.Log().Debug("无法验证回调请求,%s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "无法验证回调请求"})
c.Abort()
return
}
if !ok {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Invalid signature."})
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名无效"})
c.Abort()
return
}
@@ -213,10 +201,18 @@ func QiniuCallbackAuth() gin.HandlerFunc {
// OSSCallbackAuth 阿里云OSS回调签名验证
func OSSCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort()
return
}
err := oss.VerifyCallbackSignature(c.Request)
if err != nil {
util.Log().Debug("Failed to verify callback request: %s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Failed to verify callback request."})
util.Log().Debug("回调签名验证失败,%s", err)
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "回调签名验证失败"})
c.Abort()
return
}
@@ -228,7 +224,13 @@ func OSSCallbackAuth() gin.HandlerFunc {
// UpyunCallbackAuth 又拍云回调签名验证
func UpyunCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
session := c.MustGet(filesystem.UploadSessionCtx).(*serializer.UploadSession)
// 验证key并查找用户
resp, user := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort()
return
}
// 获取请求正文
body, err := ioutil.ReadAll(c.Request.Body)
@@ -242,7 +244,7 @@ func UpyunCallbackAuth() gin.HandlerFunc {
c.Request.Body = ioutil.NopCloser(bytes.NewReader(body))
// 准备验证Upyun回调签名
handler := upyun.Driver{Policy: &session.Policy}
handler := upyun.Driver{Policy: &user.Policy}
contentMD5 := c.Request.Header.Get("Content-Md5")
date := c.Request.Header.Get("Date")
actualSignature := c.Request.Header.Get("Authorization")
@@ -250,7 +252,7 @@ func UpyunCallbackAuth() gin.HandlerFunc {
// 计算正文MD5
actualContentMD5 := fmt.Sprintf("%x", md5.Sum(body))
if actualContentMD5 != contentMD5 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5 mismatch."})
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "MD5不一致"})
c.Abort()
return
}
@@ -265,7 +267,7 @@ func UpyunCallbackAuth() gin.HandlerFunc {
// 对比签名
if signature != actualSignature {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "Signature not match"})
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: "鉴权失败"})
c.Abort()
return
}
@@ -275,10 +277,50 @@ func UpyunCallbackAuth() gin.HandlerFunc {
}
// OneDriveCallbackAuth OneDrive回调签名验证
// TODO 解耦
func OneDriveCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort()
return
}
// 发送回调结束信号
mq.GlobalMQ.Publish(c.Param("sessionID"), mq.Message{})
onedrive.FinishCallback(c.Param("key"))
c.Next()
}
}
// COSCallbackAuth 腾讯云COS回调签名验证
// TODO 解耦 测试
func COSCallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort()
return
}
c.Next()
}
}
// S3CallbackAuth Amazon S3回调签名验证
func S3CallbackAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 验证key并查找用户
resp, _ := uploadCallbackCheck(c)
if resp.Code != 0 {
c.JSON(401, serializer.GeneralUploadCallbackFailed{Error: resp.Msg})
c.Abort()
return
}
c.Next()
}
@@ -289,7 +331,7 @@ func IsAdmin() gin.HandlerFunc {
return func(c *gin.Context) {
user, _ := c.Get("user")
if user.(*model.User).Group.ID != 1 && user.(*model.User).ID != 1 {
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "", nil))
c.JSON(200, serializer.Err(serializer.CodeAdminRequired, "您不是管理组成员", nil))
c.Abort()
return
}

View File

@@ -3,24 +3,21 @@ package middleware
import (
"database/sql"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/qiniu/api.v7/v7/auth/qbox"
"github.com/stretchr/testify/assert"
)
@@ -226,31 +223,19 @@ func TestWebDAVAuth(t *testing.T) {
}
func TestUseUploadSession(t *testing.T) {
func TestRemoteCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := UseUploadSession("local")
// sessionID 为空
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.True(c.IsAborted())
}
AuthFunc := RemoteCallbackAuth()
// 成功
{
cache.Set(
filesystem.UploadSessionCachePrefix+"testCallBackRemote",
"callback_testCallBackRemote",
serializer.UploadSession{
UID: 1,
PolicyID: 513,
VirtualPath: "/",
Policy: model.Policy{Type: "local"},
},
0,
)
@@ -263,7 +248,7 @@ func TestUseUploadSession(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testCallBackRemote"},
{"key", "testCallBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
@@ -272,96 +257,80 @@ func TestUseUploadSession(t *testing.T) {
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
func TestUploadCallbackCheck(t *testing.T) {
a := assert.New(t)
rec := httptest.NewRecorder()
// 上传会话不存在
// Callback Key 不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testSessionNotExist"},
{"key", "testCallBackRemote"},
}
res := uploadCallbackCheck(c, "local")
a.Contains("上传会话不存在或已过期", res.Msg)
}
// 上传策略不一致
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testPolicyNotMatch"},
}
cache.Set(
filesystem.UploadSessionCachePrefix+"testPolicyNotMatch",
serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{Type: "remote"},
},
0,
)
res := uploadCallbackCheck(c, "local")
a.Contains("Policy not supported", res.Msg)
}
// 用户不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testUserNotExist"},
}
cache.Set(
filesystem.UploadSessionCachePrefix+"testUserNotExist",
serializer.UploadSession{
UID: 313,
VirtualPath: "/",
Policy: model.Policy{Type: "remote"},
},
0,
)
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}))
res := uploadCallbackCheck(c, "remote")
a.Contains("找不到用户", res.Msg)
a.NoError(mock.ExpectationsWereMet())
_, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist")
a.False(ok)
}
}
func TestRemoteCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := RemoteCallbackAuth()
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{SecretKey: "123"},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.False(c.IsAborted())
asserts.True(c.IsAborted())
}
// 用户不存在
{
cache.Set(
"callback_testCallBackRemote",
serializer.UploadSession{
UID: 1,
PolicyID: 550,
VirtualPath: "/",
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 签名错误
{
cache.Set(
"callback_testCallBackRemote",
serializer.UploadSession{
UID: 1,
PolicyID: 514,
VirtualPath: "/",
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[514]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123"))
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{SecretKey: "123"},
})
c.Params = []gin.Param{
{"key", "testCallBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// Callback Key 为空
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
}
@@ -371,17 +340,39 @@ func TestQiniuCallbackAuth(t *testing.T) {
rec := httptest.NewRecorder()
AuthFunc := QiniuCallbackAuth()
// 成功
// Callback Key 相关验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
c.Params = []gin.Param{
{"key", "testQiniuBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testQiniuBackRemote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
cache.Set(
"callback_testCallBackQiniu",
serializer.UploadSession{
UID: 1,
PolicyID: 515,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[515]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackQiniu"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
mac := qbox.NewMac("123", "123")
token, err := mac.SignRequest(c.Request)
@@ -394,21 +385,33 @@ func TestQiniuCallbackAuth(t *testing.T) {
// 验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
cache.Set(
"callback_testCallBackQiniu",
serializer.UploadSession{
UID: 1,
PolicyID: 516,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[516]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackQiniu"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
mac := qbox.NewMac("123", "1213")
mac := qbox.NewMac("123", "123")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
c.Request.Header["Authorization"] = []string{"QBox " + token + " "}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
}
@@ -418,41 +421,76 @@ func TestOSSCallbackAuth(t *testing.T) {
rec := httptest.NewRecorder()
AuthFunc := OSSCallbackAuth()
// 签名验证失败
// Callback Key 相关验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
c.Params = []gin.Param{
{"key", "testOSSBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testQiniuBackRemote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 签名验证失败
{
cache.Set(
"callback_testCallBackOSS",
serializer.UploadSession{
UID: 1,
PolicyID: 517,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[517]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackOSS"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil)
mac := qbox.NewMac("123", "123")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
cache.Set(
"callback_TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH",
serializer.UploadSession{
UID: 1,
PolicyID: 518,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[518]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(`{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`)))
c.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="}
c.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
@@ -469,71 +507,130 @@ func TestUpyunCallbackAuth(t *testing.T) {
rec := httptest.NewRecorder()
AuthFunc := UpyunCallbackAuth()
// 无法获取请求正文
// Callback Key 相关验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
c.Params = []gin.Param{
{"key", "testUpyunBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testUpyunBackRemote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 无法获取请求正文
{
cache.Set(
"callback_testCallBackUpyun",
serializer.UploadSession{
UID: 1,
PolicyID: 509,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[519]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(fakeRead("")))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 正文MD5不一致
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
cache.Set(
"callback_testCallBackUpyun",
serializer.UploadSession{
UID: 1,
PolicyID: 510,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[520]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"123"}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 签名不一致
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
cache.Set(
"callback_testCallBackUpyun",
serializer.UploadSession{
UID: 1,
PolicyID: 511,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[521]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
cache.Set(
"callback_testCallBackUpyun",
serializer.UploadSession{
UID: 1,
PolicyID: 512,
VirtualPath: "/",
},
})
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[522]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
c.Request.Header["Authorization"] = []string{"UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM="}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
@@ -543,28 +640,87 @@ func TestOneDriveCallbackAuth(t *testing.T) {
rec := httptest.NewRecorder()
AuthFunc := OneDriveCallbackAuth()
// 成功
// Callback Key 相关验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "TestOneDriveCallbackAuth"},
{"key", "testUpyunBackRemote"},
}
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1")))
res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1)
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testUpyunBackRemote", nil)
AuthFunc(c)
select {
case <-res:
case <-time.After(time.Millisecond * 500):
asserts.Fail("mq message should be published")
asserts.True(c.IsAborted())
}
// 成功
{
cache.Set(
"callback_testCallBackUpyun",
serializer.UploadSession{
UID: 1,
PolicyID: 512,
VirtualPath: "/",
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[657]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
func TestCOSCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := COSCallbackAuth()
// Callback Key 相关验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testUpyunBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testUpyunBackRemote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
cache.Set(
"callback_testCallBackUpyun",
serializer.UploadSession{
UID: 1,
PolicyID: 512,
VirtualPath: "/",
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
@@ -603,3 +759,46 @@ func TestIsAdmin(t *testing.T) {
asserts.False(c.IsAborted())
}
}
func TestS3CallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := S3CallbackAuth()
// Callback Key 相关验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testUpyunBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testUpyunBackRemote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
cache.Set(
"callback_testCallBackUpyun",
serializer.UploadSession{
UID: 1,
PolicyID: 512,
VirtualPath: "/",
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"key", "testCallBackUpyun"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
AuthFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
}

View File

@@ -24,11 +24,6 @@ type req struct {
Randstr string `json:"randstr"`
}
const (
captchaNotMatch = "CAPTCHA not match."
captchaRefresh = "Verification failed, please refresh the page and retry."
)
// CaptchaRequired 验证请求签名
func CaptchaRequired(configName string) gin.HandlerFunc {
return func(c *gin.Context) {
@@ -48,7 +43,7 @@ func CaptchaRequired(configName string) gin.HandlerFunc {
bodyCopy := new(bytes.Buffer)
_, err := io.Copy(bodyCopy, c.Request.Body)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
c.JSON(200, serializer.ParamErr("验证码错误", err))
c.Abort()
return
}
@@ -56,7 +51,7 @@ func CaptchaRequired(configName string) gin.HandlerFunc {
bodyData := bodyCopy.Bytes()
err = json.Unmarshal(bodyData, &service)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
c.JSON(200, serializer.ParamErr("验证码错误", err))
c.Abort()
return
}
@@ -67,7 +62,7 @@ func CaptchaRequired(configName string) gin.HandlerFunc {
captchaID := util.GetSession(c, "captchaID")
util.DeleteSession(c, "captchaID")
if captchaID == nil || !base64Captcha.VerifyCaptcha(captchaID.(string), service.CaptchaCode) {
c.JSON(200, serializer.Err(serializer.CodeCaptchaError, captchaNotMatch, err))
c.JSON(200, serializer.ParamErr("验证码错误", nil))
c.Abort()
return
}
@@ -76,15 +71,15 @@ func CaptchaRequired(configName string) gin.HandlerFunc {
case "recaptcha":
reCAPTCHA, err := recaptcha.NewReCAPTCHA(options["captcha_ReCaptchaSecret"], recaptcha.V2, 10*time.Second)
if err != nil {
util.Log().Warning("reCAPTCHA verification failed, %s", err)
util.Log().Warning("reCAPTCHA 验证错误, %s", err)
c.Abort()
break
}
err = reCAPTCHA.Verify(service.CaptchaCode)
if err != nil {
util.Log().Warning("reCAPTCHA verification failed, %s", err)
c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil))
util.Log().Warning("reCAPTCHA 验证错误, %s", err)
c.JSON(200, serializer.ParamErr("验证失败,请刷新网页后再次验证", nil))
c.Abort()
return
}
@@ -108,13 +103,13 @@ func CaptchaRequired(configName string) gin.HandlerFunc {
request.UserIp = common.StringPtr(c.ClientIP())
response, err := client.DescribeCaptchaResult(request)
if err != nil {
util.Log().Warning("TCaptcha verification failed, %s", err)
util.Log().Warning("TCaptcha 验证错误, %s", err)
c.Abort()
break
}
if *response.Response.CaptchaCode != int64(1) {
c.JSON(200, serializer.Err(serializer.CodeCaptchaRefreshNeeded, captchaRefresh, nil))
c.JSON(200, serializer.ParamErr("验证失败,请刷新网页后再次验证", nil))
c.Abort()
return
}

View File

@@ -1,7 +1,6 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
@@ -11,9 +10,9 @@ import (
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
func MasterMetadata() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("MasterSiteID", c.GetHeader(auth.CrHeaderPrefix+"Site-Id"))
c.Set("MasterSiteURL", c.GetHeader(auth.CrHeaderPrefix+"Site-Url"))
c.Set("MasterVersion", c.GetHeader(auth.CrHeaderPrefix+"Cloudreve-Version"))
c.Set("MasterSiteID", c.GetHeader("X-Cr-Site-Id"))
c.Set("MasterSiteURL", c.GetHeader("X-Cr-Site-Url"))
c.Set("MasterVersion", c.GetHeader("X-Cr-Cloudreve-Version"))
c.Next()
}
}
@@ -25,7 +24,7 @@ func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc
// 获取对应主机节点的从机Aria2实例
caller, err := clusterController.GetAria2Instance(siteID.(string))
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotSet, "Failed to get Aria2 instance", err))
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
c.Abort()
return
}
@@ -35,23 +34,23 @@ func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc
return
}
c.JSON(200, serializer.ParamErr("Unknown master node ID", nil))
c.JSON(200, serializer.ParamErr("未知的主机节点ID", nil))
c.Abort()
}
}
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader(auth.CrHeaderPrefix+"Node-Id"), 10, 64)
nodeID, err := strconv.ParseUint(c.GetHeader("X-Cr-Node-Id"), 10, 64)
if err != nil {
c.JSON(200, serializer.ParamErr("Unknown master node ID", err))
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}
slaveNode := nodePool.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("Unknown master node ID", err))
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}

View File

@@ -1,77 +0,0 @@
package middleware
import (
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
"net/http"
)
// HashID 将给定对象的HashID转换为真实ID
func HashID(IDType int) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Param("id") != "" {
id, err := hashid.DecodeHashID(c.Param("id"), IDType)
if err == nil {
c.Set("object_id", id)
c.Next()
return
}
c.JSON(200, serializer.ParamErr("Failed to parse object ID", nil))
c.Abort()
return
}
c.Next()
}
}
// IsFunctionEnabled 当功能未开启时阻止访问
func IsFunctionEnabled(key string) gin.HandlerFunc {
return func(c *gin.Context) {
if !model.IsTrueVal(model.GetSettingByName(key)) {
c.JSON(200, serializer.Err(serializer.CodeFeatureNotEnabled, "This feature is not enabled", nil))
c.Abort()
return
}
c.Next()
}
}
// CacheControl 屏蔽客户端缓存
func CacheControl() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Cache-Control", "private, no-cache")
}
}
func Sandbox() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Content-Security-Policy", "sandbox")
}
}
// StaticResourceCache 使用静态资源缓存策略
func StaticResourceCache() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", model.GetIntSetting("public_resource_maxage", 86400)))
}
}
// MobileRequestOnly
func MobileRequestOnly() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader(auth.CrHeaderPrefix+"ios") == "" {
c.Redirect(http.StatusMovedPermanently, model.GetSiteURL().String())
c.Abort()
return
}
c.Next()
}
}

View File

@@ -1,30 +0,0 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
)
// ValidateSourceLink validates if the perm source link is a valid redirect link
func ValidateSourceLink() gin.HandlerFunc {
return func(c *gin.Context) {
linkID, ok := c.Get("object_id")
if !ok {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
c.Abort()
return
}
sourceLink, err := model.GetSourceLinkByID(linkID)
if err != nil || sourceLink.File.ID == 0 || sourceLink.File.Name != c.Param("name") {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
c.Abort()
return
}
sourceLink.Downloaded()
c.Set("source_link", sourceLink)
c.Next()
}
}

View File

@@ -1,57 +0,0 @@
package middleware
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestValidateSourceLink(t *testing.T) {
a := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ValidateSourceLink()
// ID 不存在
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
a.True(c.IsAborted())
}
// SourceLink 不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 原文件不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1))
testFunc(c)
a.False(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
}

View File

@@ -23,13 +23,13 @@ func FrontendFileHandler() gin.HandlerFunc {
// 读取index.html
file, err := bootstrap.StaticFS.Open("/index.html")
if err != nil {
util.Log().Warning("Static file \"index.html\" does not exist, it might affect the display of the homepage.")
util.Log().Warning("静态文件[index.html]不存在,可能会影响首页展示")
return ignoreFunc
}
fileContentBytes, err := ioutil.ReadAll(file)
if err != nil {
util.Log().Warning("Cannot read static file \"index.html\", it might affect the display of the homepage.")
util.Log().Warning("静态文件[index.html]读取失败,可能会影响首页展示")
return ignoreFunc
}
fileContent := string(fileContentBytes)
@@ -39,11 +39,7 @@ func FrontendFileHandler() gin.HandlerFunc {
path := c.Request.URL.Path
// API 跳过
if strings.HasPrefix(path, "/api") ||
strings.HasPrefix(path, "/custom") ||
strings.HasPrefix(path, "/dav") ||
strings.HasPrefix(path, "/f") ||
path == "/manifest.json" {
if strings.HasPrefix(path, "/api") || strings.HasPrefix(path, "/custom") || strings.HasPrefix(path, "/dav") || path == "/manifest.json" {
c.Next()
return
}
@@ -66,10 +62,6 @@ func FrontendFileHandler() gin.HandlerFunc {
return
}
if path == "/service-worker.js" {
c.Header("Cache-Control", "public, no-cache")
}
// 存在的静态文件
fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()

40
middleware/option.go Normal file
View File

@@ -0,0 +1,40 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
)
// HashID 将给定对象的HashID转换为真实ID
func HashID(IDType int) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Param("id") != "" {
id, err := hashid.DecodeHashID(c.Param("id"), IDType)
if err == nil {
c.Set("object_id", id)
c.Next()
return
}
c.JSON(200, serializer.ParamErr("无法解析对象ID", nil))
c.Abort()
return
}
c.Next()
}
}
// IsFunctionEnabled 当功能未开启时阻止访问
func IsFunctionEnabled(key string) gin.HandlerFunc {
return func(c *gin.Context) {
if !model.IsTrueVal(model.GetSettingByName(key)) {
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "未开启此功能", nil))
c.Abort()
return
}
c.Next()
}
}

View File

@@ -76,30 +76,3 @@ func TestIsFunctionEnabled(t *testing.T) {
}
}
func TestCacheControl(t *testing.T) {
a := assert.New(t)
TestFunc := CacheControl()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache")
}
func TestSandbox(t *testing.T) {
a := assert.New(t)
TestFunc := Sandbox()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox")
}
func TestStaticResourceCache(t *testing.T) {
a := assert.New(t)
TestFunc := StaticResourceCache()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age")
}

View File

@@ -1,9 +1,6 @@
package middleware
import (
"net/http"
"strings"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
@@ -23,35 +20,17 @@ func Session(secret string) gin.HandlerFunc {
var err error
Store, err = redis.NewStoreWithDB(10, conf.RedisConfig.Network, conf.RedisConfig.Server, conf.RedisConfig.Password, conf.RedisConfig.DB, []byte(secret))
if err != nil {
util.Log().Panic("Failed to connect to Redis%s", err)
util.Log().Panic("无法连接到 Redis%s", err)
}
util.Log().Info("Connect to Redis server %q.", conf.RedisConfig.Server)
util.Log().Info("已连接到 Redis 服务器:%s", conf.RedisConfig.Server)
} else {
Store = memstore.NewStore([]byte(secret))
}
sameSiteMode := http.SameSiteDefaultMode
switch strings.ToLower(conf.CORSConfig.SameSite) {
case "default":
sameSiteMode = http.SameSiteDefaultMode
case "none":
sameSiteMode = http.SameSiteNoneMode
case "strict":
sameSiteMode = http.SameSiteStrictMode
case "lax":
sameSiteMode = http.SameSiteLaxMode
}
// Also set Secure: true if using SSL, you should though
Store.Options(sessions.Options{
HttpOnly: true,
MaxAge: 60 * 86400,
Path: "/",
SameSite: sameSiteMode,
Secure: conf.CORSConfig.Secure,
})
// TODO:same-site policy
Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"})
return sessions.Sessions("cloudreve-session", Store)
}
@@ -71,7 +50,7 @@ func CSRFCheck() gin.HandlerFunc {
return
}
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "Invalid origin", nil))
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "来源非法", nil))
c.Abort()
}
}

View File

@@ -16,14 +16,14 @@ func ShareOwner() gin.HandlerFunc {
if userCtx, ok := c.Get("user"); ok {
user = userCtx.(*model.User)
} else {
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, "", nil))
c.JSON(200, serializer.Err(serializer.CodeCheckLogin, "请先登录", nil))
c.Abort()
return
}
if share, ok := c.Get("share"); ok {
if share.(*model.Share).Creator().ID != user.ID {
c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil))
c.JSON(200, serializer.Err(serializer.CodeNotFound, "分享不存在", nil))
c.Abort()
return
}
@@ -46,7 +46,7 @@ func ShareAvailable() gin.HandlerFunc {
share := model.GetShareByHashID(c.Param("id"))
if share == nil || !share.IsAvailable() {
c.JSON(200, serializer.Err(serializer.CodeShareLinkNotFound, "", nil))
c.JSON(200, serializer.Err(serializer.CodeNotFound, "分享不存在或已失效", nil))
c.Abort()
return
}
@@ -65,7 +65,7 @@ func ShareCanPreview() gin.HandlerFunc {
c.Next()
return
}
c.JSON(200, serializer.Err(serializer.CodeDisabledSharePreview, "",
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "此分享无法预览",
nil))
c.Abort()
return
@@ -85,7 +85,7 @@ func CheckShareUnlocked() gin.HandlerFunc {
unlocked := util.GetSession(c, sessionKey) != nil
if !unlocked {
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr,
"", nil))
"无权访问此分享", nil))
c.Abort()
return
}
@@ -109,7 +109,7 @@ func BeforeShareDownload() gin.HandlerFunc {
// 检查用户是否可以下载此分享的文件
err := share.CanBeDownloadBy(user)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, err.Error(),
nil))
c.Abort()
return
@@ -118,7 +118,7 @@ func BeforeShareDownload() gin.HandlerFunc {
// 对积分、下载次数进行更新
err = share.DownloadBy(user, c)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, err.Error(),
nil))
c.Abort()
return

View File

@@ -1,70 +0,0 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
"net/http"
"strings"
)
const (
WopiSessionCtx = "wopi_session"
)
// WopiWriteAccess validates if write access is obtained.
func WopiWriteAccess() gin.HandlerFunc {
return func(c *gin.Context) {
session := c.MustGet(WopiSessionCtx).(*wopi.SessionCache)
if session.Action != wopi.ActionEdit {
c.Status(http.StatusNotFound)
c.Header(wopi.ServerErrorHeader, "read-only access")
c.Abort()
return
}
c.Next()
}
}
func WopiAccessValidation(w wopi.Client, store cache.Driver) gin.HandlerFunc {
return func(c *gin.Context) {
accessToken := strings.Split(c.Query(wopi.AccessTokenQuery), ".")
if len(accessToken) != 2 {
c.Status(http.StatusForbidden)
c.Header(wopi.ServerErrorHeader, "malformed access token")
c.Abort()
return
}
sessionRaw, exist := store.Get(wopi.SessionCachePrefix + accessToken[0])
if !exist {
c.Status(http.StatusForbidden)
c.Header(wopi.ServerErrorHeader, "invalid access token")
c.Abort()
return
}
session := sessionRaw.(wopi.SessionCache)
user, err := model.GetActiveUserByID(session.UserID)
if err != nil {
c.Status(http.StatusInternalServerError)
c.Header(wopi.ServerErrorHeader, "user not found")
c.Abort()
return
}
fileID := c.MustGet("object_id").(uint)
if fileID != session.FileID {
c.Status(http.StatusInternalServerError)
c.Header(wopi.ServerErrorHeader, "file not found")
c.Abort()
return
}
c.Set("user", &user)
c.Set(WopiSessionCtx, &session)
c.Next()
}
}

View File

@@ -1,112 +0,0 @@
package middleware
import (
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestWopiWriteAccess(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := WopiWriteAccess()
// deny preview only session
{
c, _ := gin.CreateTestContext(rec)
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview})
testFunc(c)
asserts.True(c.IsAborted())
}
// pass
{
c, _ := gin.CreateTestContext(rec)
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit})
testFunc(c)
asserts.False(c.IsAborted())
}
}
func TestWopiAccessValidation(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
mockWopi := &wopimock.WopiClientMock{}
mockCache := cache.NewMemoStore()
testFunc := WopiAccessValidation(mockWopi, mockCache)
// malformed access token
{
c, _ := gin.CreateTestContext(rec)
c.AddParam(wopi.AccessTokenQuery, "000")
testFunc(c)
asserts.True(c.IsAborted())
}
// session key not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
testFunc(c)
asserts.True(c.IsAborted())
}
// user key not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
testFunc(c)
asserts.True(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
// file not found
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
c.Set("object_id", uint(0))
testFunc(c)
asserts.True(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
// all pass
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
c.Set("object_id", uint(1))
testFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotPanics(func() {
c.MustGet(WopiSessionCtx)
})
asserts.NotPanics(func() {
c.MustGet("user")
})
}
}

View File

@@ -1,122 +0,0 @@
package model
import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
)
var defaultSettings = []Setting{
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
{Name: "register_enabled", Value: `1`, Type: "register"},
{Name: "default_group", Value: `2`, Type: "register"},
{Name: "siteKeywords", Value: `Cloudreve, cloud storage`, Type: "basic"},
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
{Name: "siteTitle", Value: `Inclusive cloud storage for everyone`, Type: "basic"},
{Name: "siteScript", Value: ``, Type: "basic"},
{Name: "siteID", Value: uuid.Must(uuid.NewV4()).String(), Type: "basic"},
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},
{Name: "smtpHost", Value: `smtp.mxhichina.com`, Type: "mail"},
{Name: "smtpPort", Value: `25`, Type: "mail"},
{Name: "replyTo", Value: `abslant@126.com`, Type: "mail"},
{Name: "smtpUser", Value: `no-reply@acg.blue`, Type: "mail"},
{Name: "smtpPass", Value: ``, Type: "mail"},
{Name: "smtpEncryption", Value: `0`, Type: "mail"},
{Name: "maxEditSize", Value: `52428800`, Type: "file_edit"},
{Name: "archive_timeout", Value: `600`, Type: "timeout"},
{Name: "download_timeout", Value: `600`, Type: "timeout"},
{Name: "preview_timeout", Value: `600`, Type: "timeout"},
{Name: "doc_preview_timeout", Value: `600`, Type: "timeout"},
{Name: "upload_session_timeout", Value: `86400`, Type: "timeout"},
{Name: "slave_api_timeout", Value: `60`, Type: "timeout"},
{Name: "slave_node_retry", Value: `3`, Type: "slave"},
{Name: "slave_ping_interval", Value: `60`, Type: "slave"},
{Name: "slave_recover_interval", Value: `120`, Type: "slave"},
{Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"},
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
{Name: "chunk_retries", Value: `5`, Type: "retry"},
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
{Name: "use_temp_chunk_buffer", Value: `1`, Type: "upload"},
{Name: "login_captcha", Value: `0`, Type: "login"},
{Name: "reg_captcha", Value: `0`, Type: "login"},
{Name: "email_active", Value: `0`, Type: "register"},
{Name: "mail_activation_template", Value: `<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>激活您的账户</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #009688; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">激活{siteTitle}账户</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您注册{siteTitle},请点击下方按钮完成账户激活。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{activationUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #009688; margin: 0; border-color: #009688; border-style: solid; border-width: 10px 20px;">激活账户</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>`, Type: "mail_template"},
{Name: "forget_captcha", Value: `0`, Type: "login"},
{Name: "mail_reset_pwd_template", Value: `<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>重设密码</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #2196F3; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">重设{siteTitle}密码</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{resetUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #2196F3; margin: 0; border-color: #2196F3; border-style: solid; border-width: 10px 20px;">重设密码</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>`, Type: "mail_template"},
{Name: "db_version_" + conf.RequiredDBVersion, Value: `installed`, Type: "version"},
{Name: "hot_share_num", Value: `10`, Type: "share"},
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
{Name: "max_worker_num", Value: `10`, Type: "task"},
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
{Name: "temp_path", Value: "temp", Type: "path"},
{Name: "avatar_path", Value: "avatar", Type: "path"},
{Name: "avatar_size", Value: "2097152", Type: "avatar"},
{Name: "avatar_size_l", Value: "200", Type: "avatar"},
{Name: "avatar_size_m", Value: "130", Type: "avatar"},
{Name: "avatar_size_s", Value: "50", Type: "avatar"},
{Name: "home_view_method", Value: "icon", Type: "view"},
{Name: "share_view_method", Value: "list", Type: "view"},
{Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"},
{Name: "cron_recycle_upload_session", Value: "@every 1h30m", Type: "cron"},
{Name: "authn_enabled", Value: "0", Type: "authn"},
{Name: "captcha_type", Value: "normal", Type: "captcha"},
{Name: "captcha_height", Value: "60", Type: "captcha"},
{Name: "captcha_width", Value: "240", Type: "captcha"},
{Name: "captcha_mode", Value: "3", Type: "captcha"},
{Name: "captcha_ComplexOfNoiseText", Value: "0", Type: "captcha"},
{Name: "captcha_ComplexOfNoiseDot", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowHollowLine", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowNoiseDot", Value: "1", Type: "captcha"},
{Name: "captcha_IsShowNoiseText", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowSlimeLine", Value: "1", Type: "captcha"},
{Name: "captcha_IsShowSineLine", Value: "0", Type: "captcha"},
{Name: "captcha_CaptchaLen", Value: "6", Type: "captcha"},
{Name: "captcha_ReCaptchaKey", Value: "defaultKey", Type: "captcha"},
{Name: "captcha_ReCaptchaSecret", Value: "defaultSecret", Type: "captcha"},
{Name: "captcha_TCaptcha_CaptchaAppId", Value: "", Type: "captcha"},
{Name: "captcha_TCaptcha_AppSecretKey", Value: "", Type: "captcha"},
{Name: "captcha_TCaptcha_SecretId", Value: "", Type: "captcha"},
{Name: "captcha_TCaptcha_SecretKey", Value: "", Type: "captcha"},
{Name: "thumb_width", Value: "400", Type: "thumb"},
{Name: "thumb_height", Value: "300", Type: "thumb"},
{Name: "thumb_file_suffix", Value: "._thumb", Type: "thumb"},
{Name: "thumb_max_task_count", Value: "-1", Type: "thumb"},
{Name: "thumb_encode_method", Value: "jpg", Type: "thumb"},
{Name: "thumb_gc_after_gen", Value: "0", Type: "thumb"},
{Name: "thumb_encode_quality", Value: "85", Type: "thumb"},
{Name: "pwa_small_icon", Value: "/static/img/favicon.ico", Type: "pwa"},
{Name: "pwa_medium_icon", Value: "/static/img/logo192.png", Type: "pwa"},
{Name: "pwa_large_icon", Value: "/static/img/logo512.png", Type: "pwa"},
{Name: "pwa_display", Value: "standalone", Type: "pwa"},
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
{Name: "show_app_promotion", Value: "1", Type: "mobile"},
{Name: "public_resource_maxage", Value: "86400", Type: "timeout"},
{Name: "wopi_enabled", Value: "0", Type: "wopi"},
{Name: "wopi_endpoint", Value: "", Type: "wopi"},
{Name: "wopi_max_size", Value: "52428800", Type: "wopi"},
{Name: "wopi_session_timeout", Value: "36000", Type: "wopi"},
}

View File

@@ -1,288 +0,0 @@
package dialects
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/jinzhu/gorm"
)
var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}
type commonDialect struct {
db gorm.SQLCommon
DefaultForeignKeyNamer
}
func (commonDialect) GetName() string {
return "common"
}
func (s *commonDialect) SetDB(db gorm.SQLCommon) {
s.db = db
}
func (commonDialect) BindVar(i int) string {
return "$$$" // ?
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) fieldCanAutoIncrement(field *gorm.StructField) bool {
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
return strings.ToLower(value) != "false"
}
return field.IsPrimaryKey
}
func (s *commonDialect) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
sqlType = "INTEGER AUTO_INCREMENT"
} else {
sqlType = "INTEGER"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
sqlType = "BIGINT AUTO_INCREMENT"
} else {
sqlType = "BIGINT"
}
case reflect.Float32, reflect.Float64:
sqlType = "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
} else {
sqlType = "VARCHAR(65532)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "TIMESTAMP"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("BINARY(%d)", size)
} else {
sqlType = "BINARY(65532)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
return
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (commonDialect) DefaultValueStr() string {
return "DEFAULT VALUES"
}
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
keyName = keyNameRegex.ReplaceAllString(keyName, "_")
return keyName
}
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
return indexName, columnName
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}
type sqlite struct {
commonDialect
}
func init() {
gorm.RegisterDialect("sqlite", &sqlite{})
}
func (sqlite) GetName() string {
return "sqlite"
}
// Get Data Type for Sqlite Dialect
func (s *sqlite) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "real"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime"
}
default:
if IsByteArrayOrSlice(dataValue) {
sqlType = "blob"
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');", columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite) CurrentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

View File

@@ -32,7 +32,6 @@ type Download struct {
// 数据库忽略字段
StatusInfo rpc.StatusInfo `gorm:"-"`
Task *Task `gorm:"-"`
NodeName string `gorm:"-"`
}
// AfterFind 找到下载任务后的钩子处理Status结构
@@ -61,7 +60,7 @@ func (task *Download) BeforeSave() (err error) {
// Create 创建离线下载记录
func (task *Download) Create() (uint, error) {
if err := DB.Create(task).Error; err != nil {
util.Log().Warning("Failed to insert download record: %s", err)
util.Log().Warning("无法插入离线下载记录, %s", err)
return 0, err
}
return task.ID, nil
@@ -70,7 +69,7 @@ func (task *Download) Create() (uint, error) {
// Save 更新
func (task *Download) Save() error {
if err := DB.Save(task).Error; err != nil {
util.Log().Warning("Failed to update download record: %s", err)
util.Log().Warning("无法更新离线下载记录, %s", err)
return err
}
return nil

View File

@@ -2,9 +2,6 @@ package model
import (
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"path"
"time"
@@ -16,22 +13,19 @@ import (
type File struct {
// 表字段
gorm.Model
Name string `gorm:"unique_index:idx_only_one"`
SourceName string `gorm:"type:text"`
UserID uint `gorm:"index:user_id;unique_index:idx_only_one"`
Size uint64
PicInfo string
FolderID uint `gorm:"index:folder_id;unique_index:idx_only_one"`
PolicyID uint
UploadSessionID *string `gorm:"index:session_id;unique_index:session_only_one"`
Metadata string `gorm:"type:text"`
Name string `gorm:"unique_index:idx_only_one"`
SourceName string `gorm:"type:text"`
UserID uint `gorm:"index:user_id;unique_index:idx_only_one"`
Size uint64
PicInfo string
FolderID uint `gorm:"index:folder_id;unique_index:idx_only_one"`
PolicyID uint
// 关联模型
Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"`
// 数据库忽略字段
Position string `gorm:"-"`
MetadataSerialized map[string]string `gorm:"-"`
Position string `gorm:"-"`
}
func init() {
@@ -40,40 +34,12 @@ func init() {
}
// Create 创建文件记录
func (file *File) Create() error {
tx := DB.Begin()
if err := tx.Create(file).Error; err != nil {
util.Log().Warning("Failed to insert file record: %s", err)
tx.Rollback()
return err
func (file *File) Create() (uint, error) {
if err := DB.Create(file).Error; err != nil {
util.Log().Warning("无法插入文件记录, %s", err)
return 0, err
}
user := &User{}
user.ID = file.UserID
if err := user.ChangeStorage(tx, "+", file.Size); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// AfterFind 找到文件后的钩子
func (file *File) AfterFind() (err error) {
// 反序列化文件元数据
if file.Metadata != "" {
err = json.Unmarshal([]byte(file.Metadata), &file.MetadataSerialized)
}
return
}
// BeforeSave Save策略前的钩子
func (file *File) BeforeSave() (err error) {
metaValue, err := json.Marshal(&file.MetadataSerialized)
file.Metadata = string(metaValue)
return err
return file.ID, nil
}
// GetChildFile 查找目录下名为name的子文件
@@ -103,23 +69,19 @@ func (folder *Folder) GetChildFiles() ([]File, error) {
// GetFilesByIDs 根据文件ID批量获取文件,
// UID为0表示忽略用户只根据文件ID检索
func GetFilesByIDs(ids []uint, uid uint) ([]File, error) {
return GetFilesByIDsFromTX(DB, ids, uid)
}
func GetFilesByIDsFromTX(tx *gorm.DB, ids []uint, uid uint) ([]File, error) {
var files []File
var result *gorm.DB
if uid == 0 {
result = tx.Where("id in (?)", ids).Find(&files)
result = DB.Where("id in (?)", ids).Find(&files)
} else {
result = tx.Where("id in (?) AND user_id = ?", ids, uid).Find(&files)
result = DB.Where("id in (?) AND user_id = ?", ids, uid).Find(&files)
}
return files, result.Error
}
// GetFilesByKeywords 根据关键字搜索文件,
// UID为0表示忽略用户只根据文件ID检索. 如果 parents 非空, 则只限制在 parent 包含的目录下搜索
func GetFilesByKeywords(uid uint, parents []uint, keywords ...interface{}) ([]File, error) {
// UID为0表示忽略用户只根据文件ID检索
func GetFilesByKeywords(uid uint, keywords ...interface{}) ([]File, error) {
var (
files []File
result = DB
@@ -137,11 +99,6 @@ func GetFilesByKeywords(uid uint, parents []uint, keywords ...interface{}) ([]Fi
if uid != 0 {
result = result.Where("user_id = ?", uid)
}
if len(parents) > 0 {
result = result.Where("folder_id in (?)", parents)
}
result = result.Where("("+conditions+")", keywords...).Find(&files)
return files, result.Error
@@ -149,7 +106,7 @@ func GetFilesByKeywords(uid uint, parents []uint, keywords ...interface{}) ([]Fi
// GetChildFilesOfFolders 批量检索目录子文件
func GetChildFilesOfFolders(folders *[]Folder) ([]File, error) {
// 将所有待检索目录ID抽离以便检索文件
// 将所有待删除目录ID抽离以便检索文件
folderIDs := make([]uint, 0, len(*folders))
for _, value := range *folders {
folderIDs = append(folderIDs, value.ID)
@@ -161,19 +118,6 @@ func GetChildFilesOfFolders(folders *[]Folder) ([]File, error) {
return files, result.Error
}
// GetUploadPlaceholderFiles 获取所有上传占位文件
// UID为0表示忽略用户
func GetUploadPlaceholderFiles(uid uint) []*File {
query := DB
if uid != 0 {
query = query.Where("user_id = ?", uid)
}
var files []*File
query.Where("upload_session_id is not NULL").Find(&files)
return files
}
// GetPolicy 获取文件所属策略
func (file *File) GetPolicy() *Policy {
if file.Policy.Model.ID == 0 {
@@ -187,20 +131,15 @@ func RemoveFilesWithSoftLinks(files []File) ([]File, error) {
// 结果值
filteredFiles := make([]File, 0)
if len(files) == 0 {
return filteredFiles, nil
}
// 查询软链接的文件
filesWithSoftLinks := make([]File, 0)
for _, file := range files {
var softLinkFile File
res := DB.
Where("source_name = ? and policy_id = ? and id != ?", file.SourceName, file.PolicyID, file.ID).
First(&softLinkFile)
if res.Error == nil {
filesWithSoftLinks = append(filesWithSoftLinks, softLinkFile)
}
var filesWithSoftLinks []File
tx := DB
for _, value := range files {
tx = tx.Or("source_name = ? and policy_id = ? and id != ?", value.SourceName, value.PolicyID, value.ID)
}
result := tx.Find(&filesWithSoftLinks)
if result.Error != nil {
return nil, result.Error
}
// 过滤具有软连接的文件
@@ -227,40 +166,10 @@ func RemoveFilesWithSoftLinks(files []File) ([]File, error) {
}
// DeleteFiles 批量删除文件记录并归还容量
func DeleteFiles(files []*File, uid uint) error {
tx := DB.Begin()
user := &User{}
user.ID = uid
var size uint64
for _, file := range files {
if uid > 0 && file.UserID != uid {
tx.Rollback()
return errors.New("user id not consistent")
}
result := tx.Unscoped().Where("size = ?", file.Size).Delete(file)
if result.Error != nil {
tx.Rollback()
return result.Error
}
if result.RowsAffected == 0 {
tx.Rollback()
return errors.New("file size is dirty")
}
size += file.Size
}
if uid > 0 {
if err := user.ChangeStorage(tx, "-", size); err != nil {
tx.Rollback()
return err
}
}
return tx.Commit().Error
// DeleteFileByIDs 根据给定ID批量删除文件记录
func DeleteFileByIDs(ids []uint) error {
result := DB.Where("id in (?)", ids).Unscoped().Delete(&File{})
return result.Error
}
// GetFilesByParentIDs 根据父目录ID查找文件
@@ -270,53 +179,19 @@ func GetFilesByParentIDs(ids []uint, uid uint) ([]File, error) {
return files, result.Error
}
// GetFilesByUploadSession 查找上传会话对应的文件
func GetFilesByUploadSession(sessionID string, uid uint) (*File, error) {
file := File{}
result := DB.Where("user_id = ? and upload_session_id = ?", uid, sessionID).Find(&file)
return &file, result.Error
}
// Rename 重命名文件
func (file *File) Rename(new string) error {
return DB.Model(&file).UpdateColumn("name", new).Error
return DB.Model(&file).Update("name", new).Error
}
// UpdatePicInfo 更新文件的图像信息
func (file *File) UpdatePicInfo(value string) error {
return DB.Model(&file).Set("gorm:association_autoupdate", false).UpdateColumns(File{PicInfo: value}).Error
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("pic_info", value).Error
}
// UpdateSize 更新文件的大小信息
// TODO: 全局锁
func (file *File) UpdateSize(value uint64) error {
tx := DB.Begin()
var sizeDelta uint64
operator := "+"
user := User{}
user.ID = file.UserID
if value > file.Size {
sizeDelta = value - file.Size
} else {
operator = "-"
sizeDelta = file.Size - value
}
if res := tx.Model(&file).
Where("size = ?", file.Size).
Set("gorm:association_autoupdate", false).
Update("size", value); res.Error != nil {
tx.Rollback()
return res.Error
}
if err := user.ChangeStorage(tx, operator, sizeDelta); err != nil {
tx.Rollback()
return err
}
file.Size = value
return tx.Commit().Error
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("size", value).Error
}
// UpdateSourceName 更新文件的源文件名
@@ -324,43 +199,6 @@ func (file *File) UpdateSourceName(value string) error {
return DB.Model(&file).Set("gorm:association_autoupdate", false).Update("source_name", value).Error
}
func (file *File) PopChunkToFile(lastModified *time.Time, picInfo string) error {
file.UploadSessionID = nil
if lastModified != nil {
file.UpdatedAt = *lastModified
}
return DB.Model(file).UpdateColumns(map[string]interface{}{
"upload_session_id": file.UploadSessionID,
"updated_at": file.UpdatedAt,
"pic_info": picInfo,
}).Error
}
// CanCopy 返回文件是否可被复制
func (file *File) CanCopy() bool {
return file.UploadSessionID == nil
}
// CreateOrGetSourceLink creates a SourceLink model. If the given model exists, the existing
// model will be returned.
func (file *File) CreateOrGetSourceLink() (*SourceLink, error) {
res := &SourceLink{}
err := DB.Set("gorm:auto_preload", true).Where("file_id = ?", file.ID).Find(&res).Error
if err == nil && res.ID > 0 {
return res, nil
}
res.FileID = file.ID
res.Name = file.Name
if err := DB.Save(res).Error; err != nil {
return nil, fmt.Errorf("failed to insert SourceLink: %w", err)
}
res.File = *file
return res, nil
}
/*
实现 webdav.FileInfo 接口
*/

View File

@@ -2,12 +2,11 @@ package model
import (
"errors"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestFile_Create(t *testing.T) {
@@ -16,62 +15,22 @@ func TestFile_Create(t *testing.T) {
Name: "123",
}
// 无法插入文件记录
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := file.Create()
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
}
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
fileID, err := file.Create()
asserts.NoError(err)
asserts.Equal(uint(5), fileID)
asserts.Equal(uint(5), file.ID)
asserts.NoError(mock.ExpectationsWereMet())
// 无法更新用户容量
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := file.Create()
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
}
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
fileID, err = file.Create()
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
err := file.Create()
asserts.NoError(err)
asserts.Equal(uint(5), file.ID)
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestFile_AfterFind(t *testing.T) {
a := assert.New(t)
file := File{
Name: "123",
Metadata: "{\"name\":\"123\"}",
}
a.NoError(file.AfterFind())
a.Equal("123", file.MetadataSerialized["name"])
}
func TestFile_BeforeSave(t *testing.T) {
a := assert.New(t)
file := File{
Name: "123",
MetadataSerialized: map[string]string{
"name": "123",
},
}
a.NoError(file.BeforeSave())
a.Equal("{\"name\":\"123\"}", file.Metadata)
}
func TestFolder_GetChildFile(t *testing.T) {
@@ -216,17 +175,6 @@ func TestGetChildFilesOfFolders(t *testing.T) {
}
}
func TestGetUploadPlaceholderFiles(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)upload_session_id(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
files := GetUploadPlaceholderFiles(1)
a.NoError(mock.ExpectationsWereMet())
a.Len(files, 1)
}
func TestFile_GetPolicy(t *testing.T) {
asserts := assert.New(t)
@@ -257,19 +205,6 @@ func TestFile_GetPolicy(t *testing.T) {
}
}
func TestRemoveFilesWithSoftLinks_EmptyArg(t *testing.T) {
asserts := assert.New(t)
// 传入空
{
mock.ExpectQuery("SELECT(.+)files(.+)")
file, err := RemoveFilesWithSoftLinks([]File{})
asserts.Error(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(len(file), 0)
DB.Find(&File{})
}
}
func TestRemoveFilesWithSoftLinks(t *testing.T) {
asserts := assert.New(t)
files := []File{
@@ -285,34 +220,30 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
},
}
// 传入空文件列表
{
file, err := RemoveFilesWithSoftLinks([]File{})
asserts.NoError(err)
asserts.Empty(file)
}
// 全都没有
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(files, file)
}
// 查询出错
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WillReturnError(errors.New("error"))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Nil(file)
}
// 第二个是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"),
@@ -322,18 +253,14 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
asserts.NoError(err)
asserts.Equal(files[:1], file)
}
// 第一个是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 23, "1.txt"),
)
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
@@ -342,16 +269,11 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
// 全部是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 23, "1.txt"),
)
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"),
AddRow(3, 24, "2.txt").
AddRow(4, 23, "1.txt"),
)
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
@@ -360,75 +282,28 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
}
}
func TestDeleteFiles(t *testing.T) {
a := assert.New(t)
func TestDeleteFileByIDs(t *testing.T) {
asserts := assert.New(t)
// uid 不一致
{
err := DeleteFiles([]*File{{UserID: 2}}, 1)
a.Contains("user id not consistent", err.Error())
}
// 删除失败
// 出错
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := DeleteFiles([]*File{{UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.Error(err)
err := DeleteFileByIDs([]uint{1, 2, 3})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
// 无法变更用户容量
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := DeleteFiles([]*File{{UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.Error(err)
}
// 文件脏读
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(1, 0))
mock.ExpectRollback()
err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.Error(err)
a.Contains("file size is dirty", err.Error())
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)").WithArgs(uint64(3), sqlmock.AnyArg(), uint(1)).WillReturnResult(sqlmock.NewResult(1, 1))
WillReturnResult(sqlmock.NewResult(0, 3))
mock.ExpectCommit()
err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.NoError(err)
}
// 成功, 关联用户不存在
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 0)
a.NoError(mock.ExpectationsWereMet())
a.NoError(err)
err := DeleteFileByIDs([]uint{1, 2, 3})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
}
@@ -449,27 +324,13 @@ func TestGetFilesByParentIDs(t *testing.T) {
asserts.Len(files, 3)
}
func TestGetFilesByUploadSession(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, "sessionID").
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}).AddRow(4, "4.txt"))
files, err := GetFilesByUploadSession("sessionID", 1)
a.NoError(err)
a.NoError(mock.ExpectationsWereMet())
a.Equal("4.txt", files.Name)
}
func TestFile_Updates(t *testing.T) {
asserts := assert.New(t)
file := File{Model: gorm.Model{ID: 1}}
// rename
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("newName", 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)").WithArgs("newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.Rename("newName")
asserts.NoError(mock.ExpectationsWereMet())
@@ -479,91 +340,22 @@ func TestFile_Updates(t *testing.T) {
// UpdatePicInfo
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs("1,1", 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)").WithArgs(10, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.UpdateSize(10)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// UpdatePicInfo
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs("1,1", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.UpdatePicInfo("1,1")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// UpdateSourceName
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs("newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.UpdateSourceName("newName")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
}
func TestFile_UpdateSize(t *testing.T) {
a := assert.New(t)
// 增加成功
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(11, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)+(.+)").WithArgs(uint64(1), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(file.UpdateSize(11))
a.NoError(mock.ExpectationsWereMet())
}
// 减少成功
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(file.UpdateSize(8))
a.NoError(mock.ExpectationsWereMet())
}
// 文件更新失败
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(8, sqlmock.AnyArg(), 10).WillReturnError(errors.New("error"))
mock.ExpectRollback()
a.Error(file.UpdateSize(8))
a.NoError(mock.ExpectationsWereMet())
}
// 用户容量更新失败
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnError(errors.New("error"))
mock.ExpectRollback()
a.Error(file.UpdateSize(8))
a.NoError(mock.ExpectationsWereMet())
}
}
func TestFile_PopChunkToFile(t *testing.T) {
a := assert.New(t)
timeNow := time.Now()
file := File{}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(file.PopChunkToFile(&timeNow, "1,1"))
}
func TestFile_CanCopy(t *testing.T) {
a := assert.New(t)
file := File{}
a.True(file.CanCopy())
file.UploadSessionID = &file.Name
a.False(file.CanCopy())
}
func TestFile_FileInfoInterface(t *testing.T) {
@@ -600,7 +392,7 @@ func TestGetFilesByKeywords(t *testing.T) {
// 未指定用户
{
mock.ExpectQuery("SELECT(.+)").WithArgs("k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetFilesByKeywords(0, nil, "k1", "k2")
res, err := GetFilesByKeywords(0, "k1", "k2")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(res, 1)
@@ -609,59 +401,9 @@ func TestGetFilesByKeywords(t *testing.T) {
// 指定用户
{
mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetFilesByKeywords(1, nil, "k1", "k2")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(res, 1)
}
// 指定父目录
{
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 12, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetFilesByKeywords(1, []uint{12}, "k1", "k2")
res, err := GetFilesByKeywords(1, "k1", "k2")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(res, 1)
}
}
func TestFile_CreateOrGetSourceLink(t *testing.T) {
a := assert.New(t)
file := &File{}
file.ID = 1
// 已存在,返回老的 SourceLink
{
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
res, err := file.CreateOrGetSourceLink()
a.NoError(err)
a.EqualValues(2, res.ID)
a.NoError(mock.ExpectationsWereMet())
}
// 不存在,插入失败
{
expectedErr := errors.New("error")
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr)
mock.ExpectRollback()
res, err := file.CreateOrGetSourceLink()
a.Nil(res)
a.ErrorIs(err, expectedErr)
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
res, err := file.CreateOrGetSourceLink()
a.NoError(err)
a.EqualValues(2, res.ID)
a.EqualValues(file.ID, res.File.ID)
a.NoError(mock.ExpectationsWereMet())
}
}

View File

@@ -23,12 +23,10 @@ type Folder struct {
// Create 创建目录
func (folder *Folder) Create() (uint, error) {
if err := DB.FirstOrCreate(folder, *folder).Error; err != nil {
folder.Model = gorm.Model{}
err2 := DB.First(folder, *folder).Error
return folder.ID, err2
if err := DB.Create(folder).Error; err != nil {
util.Log().Warning("无法插入目录记录, %s", err)
return 0, err
}
return folder.ID, nil
}
@@ -160,11 +158,6 @@ func (folder *Folder) MoveOrCopyFileTo(files []uint, dstFolder *Folder, isCopy b
// 复制文件记录
for _, oldFile := range originFiles {
if !oldFile.CanCopy() {
util.Log().Warning("Cannot copy file %q because it's being uploaded now, skipping...", oldFile.Name)
continue
}
oldFile.Model = gorm.Model{}
oldFile.FolderID = dstFolder.ID
oldFile.UserID = dstFolder.OwnerID
@@ -224,8 +217,8 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6
} else if IDCache, ok := newIDCache[*folder.ParentID]; ok {
newID = IDCache
} else {
util.Log().Warning("Failed to get parent folder %q", *folder.ParentID)
return size, errors.New("Failed to get parent folder")
util.Log().Warning("无法取得新的父目录:%d", folder.ParentID)
return size, errors.New("无法取得新的父目录")
}
// 插入新的目录记录
@@ -253,11 +246,6 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6
// 复制文件记录
for _, oldFile := range originFiles {
if !oldFile.CanCopy() {
util.Log().Warning("Cannot copy file %q because it's being uploaded now, skipping...", oldFile.Name)
continue
}
oldFile.Model = gorm.Model{}
oldFile.FolderID = newIDCache[oldFile.FolderID]
oldFile.UserID = dstFolder.OwnerID
@@ -275,13 +263,6 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6
// MoveFolderTo 将folder目录下的dirs子目录复制或移动到dstFolder
// 返回此过程中增加的容量
func (folder *Folder) MoveFolderTo(dirs []uint, dstFolder *Folder) error {
// 如果目标位置为待移动的目录,会导致 parent 为自己
// 造成死循环且无法被除搜索以外的组件展示
if folder.OwnerID == dstFolder.OwnerID && util.ContainsUint(dirs, dstFolder.ID) {
return errors.New("cannot move a folder into itself")
}
// 更改顶级要移动目录的父目录指向
err := DB.Model(Folder{}).Where(
"id in (?) and owner_id = ? and parent_id = ?",
@@ -298,7 +279,10 @@ func (folder *Folder) MoveFolderTo(dirs []uint, dstFolder *Folder) error {
// Rename 重命名目录
func (folder *Folder) Rename(new string) error {
return DB.Model(&folder).UpdateColumn("name", new).Error
if err := DB.Model(&folder).Update("name", new).Error; err != nil {
return err
}
return nil
}
/*

View File

@@ -17,8 +17,7 @@ func TestFolder_Create(t *testing.T) {
Name: "new folder",
}
// 不存在,插入成功
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
// 插入成功
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
@@ -28,21 +27,12 @@ func TestFolder_Create(t *testing.T) {
asserts.NoError(mock.ExpectationsWereMet())
// 插入失败
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
fid, err = folder.Create()
asserts.NoError(err)
asserts.Equal(uint(1), fid)
asserts.NoError(mock.ExpectationsWereMet())
// 存在,直接返回
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5))
fid, err = folder.Create()
asserts.NoError(err)
asserts.Equal(uint(5), fid)
asserts.Error(err)
asserts.Equal(uint(0), fid)
asserts.NoError(mock.ExpectationsWereMet())
}
@@ -106,7 +96,7 @@ func TestFolder_GetChildFolder(t *testing.T) {
}
func TestGetRecursiveChildFolderSQLite(t *testing.T) {
conf.DatabaseConfig.Type = "sqlite"
conf.DatabaseConfig.Type = "sqlite3"
asserts := assert.New(t)
// 测试目录结构
@@ -222,14 +212,12 @@ func TestFolder_MoveOrCopyFileTo(t *testing.T) {
WithArgs(
1,
2,
3,
1,
1,
).WillReturnRows(
sqlmock.NewRows([]string{"id", "size", "upload_session_id"}).
AddRow(1, 10, nil).
AddRow(2, 20, nil).
AddRow(2, 20, &folder.Name),
sqlmock.NewRows([]string{"id", "size"}).
AddRow(1, 10).
AddRow(2, 20),
)
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
@@ -238,7 +226,7 @@ func TestFolder_MoveOrCopyFileTo(t *testing.T) {
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
storage, err := folder.MoveOrCopyFileTo(
[]uint{1, 2, 3},
[]uint{1, 2},
&dstFolder,
true,
)
@@ -347,7 +335,7 @@ func TestFolder_CopyFolderTo(t *testing.T) {
// 测试复制目录结构
// test(2)(5)
// 1(3)(6) 2.txt
// 3(4)(7) 4.txt 5.txt(上传中)
// 3(4)(7) 4.txt
// 正常情况 成功
{
@@ -372,10 +360,9 @@ func TestFolder_CopyFolderTo(t *testing.T) {
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 4).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name", "folder_id", "size", "upload_session_id"}).
AddRow(1, "2.txt", 2, 10, nil).
AddRow(2, "3.txt", 3, 20, nil).
AddRow(3, "5.txt", 3, 20, &dstFolder.Name),
sqlmock.NewRows([]string{"id", "name", "folder_id", "size"}).
AddRow(1, "2.txt", 2, 10).
AddRow(2, "3.txt", 3, 20),
)
// 复制子文件
@@ -506,8 +493,7 @@ func TestFolder_MoveOrCopyFolderTo_Move(t *testing.T) {
}
// 目标目录
dstFolder := Folder{
Model: gorm.Model{ID: 10},
OwnerID: 1,
Model: gorm.Model{ID: 10},
}
// 成功
@@ -521,12 +507,6 @@ func TestFolder_MoveOrCopyFolderTo_Move(t *testing.T) {
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// 移动自己到自己内部,失败
{
err := parFolder.MoveFolderTo([]uint{10, 2}, &dstFolder)
asserts.Error(err)
}
}
func TestFolder_FileInfoInterface(t *testing.T) {
@@ -584,39 +564,3 @@ func TestTraceRoot(t *testing.T) {
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestFolder_Rename(t *testing.T) {
asserts := assert.New(t)
folder := Folder{
Model: gorm.Model{
ID: 1,
},
Name: "test_name",
OwnerID: 1,
Position: "/test",
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
WithArgs("test_name_new", 1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := folder.Rename("test_name_new")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// 出现错误
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
WithArgs("test_name_new", 1).
WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := folder.Rename("test_name_new")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}

View File

@@ -14,7 +14,7 @@ type Group struct {
ShareEnabled bool
WebDAVEnabled bool
SpeedLimit int
Options string `json:"-" gorm:"size:4294967295"`
Options string `json:"-",gorm:"type:text"`
// 数据库忽略字段
PolicyList []uint `gorm:"-"`
@@ -23,18 +23,14 @@ type Group struct {
// GroupOption 用户组其他配置
type GroupOption struct {
ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载
ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩
CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小
DecompressSize uint64 `json:"decompress_size,omitempty"`
OneTimeDownload bool `json:"one_time_download,omitempty"`
ShareDownload bool `json:"share_download,omitempty"`
Aria2 bool `json:"aria2,omitempty"` // 离线下载
Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置
SourceBatchSize int `json:"source_batch,omitempty"`
RedirectedSource bool `json:"redirected_source,omitempty"`
Aria2BatchSize int `json:"aria2_batch,omitempty"`
AdvanceDelete bool `json:"advance_delete,omitempty"`
ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载
ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩
CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小
DecompressSize uint64 `json:"decompress_size,omitempty"`
OneTimeDownload bool `json:"one_time_download,omitempty"`
ShareDownload bool `json:"share_download,omitempty"`
Aria2 bool `json:"aria2,omitempty"` // 离线下载
Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置
}
// GetGroupByID 用ID获取用户组
@@ -68,7 +64,7 @@ func (group *Group) BeforeSave() (err error) {
return err
}
// SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段
//SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段
// TODO 完善测试
func (group *Group) SerializePolicyList() (err error) {
policies, err := json.Marshal(&group.PolicyList)

View File

@@ -9,11 +9,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
_ "github.com/cloudreve/Cloudreve/v3/models/dialects"
_ "github.com/glebarez/go-sqlite"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
// DB 数据库链接单例
@@ -21,59 +20,44 @@ var DB *gorm.DB
// Init 初始化 MySQL 链接
func Init() {
util.Log().Info("Initializing database connection...")
util.Log().Info("初始化数据库连接")
var (
db *gorm.DB
err error
confDBType string = conf.DatabaseConfig.Type
db *gorm.DB
err error
)
// 兼容已有配置中的 "sqlite3" 配置项
if confDBType == "sqlite3" {
confDBType = "sqlite"
}
if gin.Mode() == gin.TestMode {
// 测试模式下,使用内存数据库
db, err = gorm.Open("sqlite", ":memory:")
db, err = gorm.Open("sqlite3", ":memory:")
} else {
switch confDBType {
case "UNSET", "sqlite":
// 未指定数据库或者明确指定为 sqlite 时,使用 SQLite 数据库
db, err = gorm.Open("sqlite", util.RelativePath(conf.DatabaseConfig.DBFile))
switch conf.DatabaseConfig.Type {
case "UNSET", "sqlite", "sqlite3":
// 未指定数据库或者明确指定为 sqlite 时,使用 SQLite3 数据库
db, err = gorm.Open("sqlite3", util.RelativePath(conf.DatabaseConfig.DBFile))
case "postgres":
db, err = gorm.Open(confDBType, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
conf.DatabaseConfig.Host,
conf.DatabaseConfig.User,
conf.DatabaseConfig.Password,
conf.DatabaseConfig.Name,
conf.DatabaseConfig.Port))
case "mysql", "mssql":
var host string
if conf.DatabaseConfig.UnixSocket {
host = fmt.Sprintf("unix(%s)",
conf.DatabaseConfig.Host)
} else {
host = fmt.Sprintf("(%s:%d)",
conf.DatabaseConfig.Host,
conf.DatabaseConfig.Port)
}
db, err = gorm.Open(confDBType, fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local",
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("%s:%s@(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
conf.DatabaseConfig.User,
conf.DatabaseConfig.Password,
host,
conf.DatabaseConfig.Host,
conf.DatabaseConfig.Port,
conf.DatabaseConfig.Name,
conf.DatabaseConfig.Charset))
default:
util.Log().Panic("Unsupported database type %q.", confDBType)
util.Log().Panic("不支持数据库类型: %s", conf.DatabaseConfig.Type)
}
}
//db.SetLogger(util.Log())
if err != nil {
util.Log().Panic("Failed to connect to database: %s", err)
util.Log().Panic("连接数据库不成功, %s", err)
}
// 处理表前缀
@@ -89,13 +73,10 @@ func Init() {
}
//设置连接池
//空闲
db.DB().SetMaxIdleConns(50)
if confDBType == "sqlite" || confDBType == "UNSET" {
db.DB().SetMaxOpenConns(1)
} else {
db.DB().SetMaxOpenConns(100)
}
//打开
db.DB().SetMaxOpenConns(100)
//超时
db.DB().SetConnMaxLifetime(time.Second * 30)

View File

@@ -7,6 +7,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/fatih/color"
"github.com/gofrs/uuid"
"github.com/hashicorp/go-version"
"github.com/jinzhu/gorm"
"sort"
@@ -19,16 +20,16 @@ func needMigration() bool {
return DB.Where("name = ?", "db_version_"+conf.RequiredDBVersion).First(&setting).Error != nil
}
// 执行数据迁移
//执行数据迁移
func migration() {
// 确认是否需要执行迁移
if !needMigration() {
util.Log().Info("Database version fulfilled, skip schema migration.")
util.Log().Info("数据库版本匹配,跳过数据库迁移")
return
}
util.Log().Info("Start initializing database schema...")
util.Log().Info("开始进行数据库初始化...")
// 清除所有缓存
if instance, ok := cache.Store.(*cache.RedisStore); ok {
@@ -41,7 +42,7 @@ func migration() {
}
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{})
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{})
// 创建初始存储策略
addDefaultPolicy()
@@ -61,7 +62,7 @@ func migration() {
// 执行数据库升级脚本
execUpgradeScripts()
util.Log().Info("Finish initializing database schema.")
util.Log().Info("数据库初始化结束")
}
@@ -70,24 +71,127 @@ func addDefaultPolicy() {
// 未找到初始存储策略时,则创建
if gorm.IsRecordNotFoundError(err) {
defaultPolicy := Policy{
Name: "Default storage policy",
Name: "默认存储策略",
Type: "local",
MaxSize: 0,
AutoRename: true,
DirNameRule: "uploads/{uid}/{path}",
FileNameRule: "{uid}_{randomkey8}_{originname}",
IsOriginLinkEnable: false,
OptionsSerialized: PolicyOption{
ChunkSize: 25 << 20, // 25MB
},
}
if err := DB.Create(&defaultPolicy).Error; err != nil {
util.Log().Panic("Failed to create default storage policy: %s", err)
util.Log().Panic("无法创建初始存储策略, %s", err)
}
}
}
func addDefaultSettings() {
siteID, _ := uuid.NewV4()
defaultSettings := []Setting{
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
{Name: "siteICPId", Value: ``, Type: "basic"},
{Name: "register_enabled", Value: `1`, Type: "register"},
{Name: "default_group", Value: `2`, Type: "register"},
{Name: "siteKeywords", Value: `网盘,网盘`, Type: "basic"},
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
{Name: "siteTitle", Value: `平步云端`, Type: "basic"},
{Name: "siteScript", Value: ``, Type: "basic"},
{Name: "siteID", Value: siteID.String(), Type: "basic"},
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},
{Name: "smtpHost", Value: `smtp.mxhichina.com`, Type: "mail"},
{Name: "smtpPort", Value: `25`, Type: "mail"},
{Name: "replyTo", Value: `abslant@126.com`, Type: "mail"},
{Name: "smtpUser", Value: `no-reply@acg.blue`, Type: "mail"},
{Name: "smtpPass", Value: ``, Type: "mail"},
{Name: "smtpEncryption", Value: `0`, Type: "mail"},
{Name: "maxEditSize", Value: `4194304`, Type: "file_edit"},
{Name: "archive_timeout", Value: `60`, Type: "timeout"},
{Name: "download_timeout", Value: `60`, Type: "timeout"},
{Name: "preview_timeout", Value: `60`, Type: "timeout"},
{Name: "doc_preview_timeout", Value: `60`, Type: "timeout"},
{Name: "upload_credential_timeout", Value: `1800`, Type: "timeout"},
{Name: "upload_session_timeout", Value: `86400`, Type: "timeout"},
{Name: "slave_api_timeout", Value: `60`, Type: "timeout"},
{Name: "slave_node_retry", Value: `3`, Type: "slave"},
{Name: "slave_ping_interval", Value: `60`, Type: "slave"},
{Name: "slave_recover_interval", Value: `120`, Type: "slave"},
{Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"},
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
{Name: "onedrive_chunk_retries", Value: `1`, Type: "retry"},
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
{Name: "login_captcha", Value: `0`, Type: "login"},
{Name: "reg_captcha", Value: `0`, Type: "login"},
{Name: "email_active", Value: `0`, Type: "register"},
{Name: "mail_activation_template", Value: `<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>激活您的账户</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #009688; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">激活{siteTitle}账户</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您注册{siteTitle},请点击下方按钮完成账户激活。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{activationUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #009688; margin: 0; border-color: #009688; border-style: solid; border-width: 10px 20px;">激活账户</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>`, Type: "mail_template"},
{Name: "forget_captcha", Value: `0`, Type: "login"},
{Name: "mail_reset_pwd_template", Value: `<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>重设密码</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #2196F3; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">重设{siteTitle}密码</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{resetUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #2196F3; margin: 0; border-color: #2196F3; border-style: solid; border-width: 10px 20px;">重设密码</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>`, Type: "mail_template"},
{Name: "db_version_" + conf.RequiredDBVersion, Value: `installed`, Type: "version"},
{Name: "hot_share_num", Value: `10`, Type: "share"},
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
{Name: "max_worker_num", Value: `10`, Type: "task"},
{Name: "max_parallel_transfer", Value: `4`, Type: "task"},
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
{Name: "temp_path", Value: "temp", Type: "path"},
{Name: "avatar_path", Value: "avatar", Type: "path"},
{Name: "avatar_size", Value: "2097152", Type: "avatar"},
{Name: "avatar_size_l", Value: "200", Type: "avatar"},
{Name: "avatar_size_m", Value: "130", Type: "avatar"},
{Name: "avatar_size_s", Value: "50", Type: "avatar"},
{Name: "home_view_method", Value: "icon", Type: "view"},
{Name: "share_view_method", Value: "list", Type: "view"},
{Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"},
{Name: "authn_enabled", Value: "0", Type: "authn"},
{Name: "captcha_type", Value: "normal", Type: "captcha"},
{Name: "captcha_height", Value: "60", Type: "captcha"},
{Name: "captcha_width", Value: "240", Type: "captcha"},
{Name: "captcha_mode", Value: "3", Type: "captcha"},
{Name: "captcha_ComplexOfNoiseText", Value: "0", Type: "captcha"},
{Name: "captcha_ComplexOfNoiseDot", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowHollowLine", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowNoiseDot", Value: "1", Type: "captcha"},
{Name: "captcha_IsShowNoiseText", Value: "0", Type: "captcha"},
{Name: "captcha_IsShowSlimeLine", Value: "1", Type: "captcha"},
{Name: "captcha_IsShowSineLine", Value: "0", Type: "captcha"},
{Name: "captcha_CaptchaLen", Value: "6", Type: "captcha"},
{Name: "captcha_ReCaptchaKey", Value: "defaultKey", Type: "captcha"},
{Name: "captcha_ReCaptchaSecret", Value: "defaultSecret", Type: "captcha"},
{Name: "captcha_TCaptcha_CaptchaAppId", Value: "", Type: "captcha"},
{Name: "captcha_TCaptcha_AppSecretKey", Value: "", Type: "captcha"},
{Name: "captcha_TCaptcha_SecretId", Value: "", Type: "captcha"},
{Name: "captcha_TCaptcha_SecretKey", Value: "", Type: "captcha"},
{Name: "thumb_width", Value: "400", Type: "thumb"},
{Name: "thumb_height", Value: "300", Type: "thumb"},
{Name: "pwa_small_icon", Value: "/static/img/favicon.ico", Type: "pwa"},
{Name: "pwa_medium_icon", Value: "/static/img/logo192.png", Type: "pwa"},
{Name: "pwa_large_icon", Value: "/static/img/logo512.png", Type: "pwa"},
{Name: "pwa_display", Value: "standalone", Type: "pwa"},
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
}
for _, value := range defaultSettings {
DB.Where(Setting{Name: value.Name}).Create(&value)
}
@@ -98,24 +202,20 @@ func addDefaultGroups() {
// 未找到初始管理组时,则创建
if gorm.IsRecordNotFoundError(err) {
defaultAdminGroup := Group{
Name: "Admin",
Name: "管理员",
PolicyList: []uint{1},
MaxStorage: 1 * 1024 * 1024 * 1024,
ShareEnabled: true,
WebDAVEnabled: true,
OptionsSerialized: GroupOption{
ArchiveDownload: true,
ArchiveTask: true,
ShareDownload: true,
Aria2: true,
SourceBatchSize: 1000,
Aria2BatchSize: 50,
RedirectedSource: true,
AdvanceDelete: true,
ArchiveDownload: true,
ArchiveTask: true,
ShareDownload: true,
Aria2: true,
},
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
util.Log().Panic("Failed to create admin user group: %s", err)
util.Log().Panic("无法创建管理用户组, %s", err)
}
}
@@ -124,20 +224,17 @@ func addDefaultGroups() {
// 未找到初始注册会员时,则创建
if gorm.IsRecordNotFoundError(err) {
defaultAdminGroup := Group{
Name: "User",
Name: "注册会员",
PolicyList: []uint{1},
MaxStorage: 1 * 1024 * 1024 * 1024,
ShareEnabled: true,
WebDAVEnabled: true,
OptionsSerialized: GroupOption{
ShareDownload: true,
SourceBatchSize: 10,
Aria2BatchSize: 1,
RedirectedSource: true,
ShareDownload: true,
},
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
util.Log().Panic("Failed to create initial user group: %s", err)
util.Log().Panic("无法创建初始注册会员用户组, %s", err)
}
}
@@ -146,7 +243,7 @@ func addDefaultGroups() {
// 未找到初始游客用户组时,则创建
if gorm.IsRecordNotFoundError(err) {
defaultAdminGroup := Group{
Name: "Anonymous",
Name: "游客",
PolicyList: []uint{},
Policies: "[]",
OptionsSerialized: GroupOption{
@@ -154,7 +251,7 @@ func addDefaultGroups() {
},
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
util.Log().Panic("Failed to create anonymous user group: %s", err)
util.Log().Panic("无法创建初始游客用户组, %s", err)
}
}
}
@@ -172,15 +269,15 @@ func addDefaultUser() {
defaultUser.GroupID = 1
err := defaultUser.SetPassword(password)
if err != nil {
util.Log().Panic("Failed to create password: %s", err)
util.Log().Panic("无法创建密码, %s", err)
}
if err := DB.Create(&defaultUser).Error; err != nil {
util.Log().Panic("Failed to create initial root user: %s", err)
util.Log().Panic("无法创建初始用户, %s", err)
}
c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold)
util.Log().Info("Admin user name: " + c.Sprint("admin@cloudreve.org"))
util.Log().Info("Admin password: " + c.Sprint(password))
util.Log().Info("初始管理员账号:" + c.Sprint("admin@cloudreve.org"))
util.Log().Info("初始管理员密码:" + c.Sprint(password))
}
}
@@ -189,7 +286,7 @@ func addDefaultNode() {
if gorm.IsRecordNotFoundError(err) {
defaultAdminGroup := Node{
Name: "Master (Local machine)",
Name: "主机(本机)",
Status: NodeActive,
Type: MasterNodeType,
Aria2OptionsSerialized: Aria2Option{
@@ -198,7 +295,7 @@ func addDefaultNode() {
},
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
util.Log().Panic("Failed to create initial node: %s", err)
util.Log().Panic("无法创建初始节点记录, %s", err)
}
}
}

View File

@@ -10,8 +10,8 @@ import (
func TestMigration(t *testing.T) {
asserts := assert.New(t)
conf.DatabaseConfig.Type = "sqlite"
DB, _ = gorm.Open("sqlite", ":memory:")
conf.DatabaseConfig.Type = "sqlite3"
DB, _ = gorm.Open("sqlite3", ":memory:")
asserts.NotPanics(func() {
migration()

View File

@@ -3,7 +3,8 @@ package model
import (
"encoding/gob"
"encoding/json"
"github.com/gofrs/uuid"
"fmt"
"net/url"
"path"
"path/filepath"
"strconv"
@@ -57,20 +58,8 @@ type PolicyOption struct {
Region string `json:"region,omitempty"`
// ServerSideEndpoint 服务端请求使用的 Endpoint为空时使用 Policy.Server 字段
ServerSideEndpoint string `json:"server_side_endpoint,omitempty"`
// 分片上传的分片大小
ChunkSize uint64 `json:"chunk_size,omitempty"`
// 分片上传时是否需要预留空间
PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"`
// 每秒对存储端的 API 请求上限
TPSLimit float64 `json:"tps_limit,omitempty"`
// 每秒 API 请求爆发上限
TPSLimitBurst int `json:"tps_limit_burst,omitempty"`
// Set this to `true` to force the request to use path-style addressing,
// i.e., `http://s3.amazonaws.com/BUCKET/KEY `
S3ForcePathStyle bool `json:"s3_path_style"`
}
// thumbSuffix 支持缩略图处理的文件扩展名
var thumbSuffix = map[string][]string{
"local": {},
"qiniu": {".psd", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
@@ -125,7 +114,7 @@ func (policy *Policy) BeforeSave() (err error) {
return err
}
// SerializeOptions 将序列后的Option写入到数据库字段
//SerializeOptions 将序列后的Option写入到数据库字段
func (policy *Policy) SerializeOptions() (err error) {
optionsValue, err := json.Marshal(&policy.OptionsSerialized)
policy.Options = string(optionsValue)
@@ -159,7 +148,7 @@ func (policy *Policy) GeneratePath(uid uint, origin string) string {
func (policy *Policy) GenerateFileName(uid uint, origin string) string {
// 未开启自动重命名时,直接返回原始文件名
if !policy.AutoRename {
return origin
return policy.getOriginNameRule(origin)
}
fileRule := policy.FileNameRule
@@ -178,15 +167,35 @@ func (policy *Policy) GenerateFileName(uid uint, origin string) string {
"{hour}": time.Now().Format("15"),
"{minute}": time.Now().Format("04"),
"{second}": time.Now().Format("05"),
"{originname}": origin,
"{ext}": filepath.Ext(origin),
"{uuid}": uuid.Must(uuid.NewV4()).String(),
}
replaceTable["{originname}"] = policy.getOriginNameRule(origin)
fileRule = util.Replace(replaceTable, fileRule)
return fileRule
}
func (policy Policy) getOriginNameRule(origin string) string {
// 部分存储策略可以使用{origin}代表原始文件名
if origin == "" {
// 如果上游未传回原始文件名,则使用占位符,让云存储端替换
switch policy.Type {
case "qiniu":
// 七牛会将$(fname)自动替换为原始文件名
return "$(fname)"
case "local", "remote":
return origin
case "oss", "cos":
// OSS会将${filename}自动替换为原始文件名
return "${filename}"
case "upyun":
// Upyun会将{filename}{.suffix}自动替换为原始文件名
return "{filename}{.suffix}"
}
}
return origin
}
// IsDirectlyPreview 返回此策略下文件是否可以直接预览(不需要重定向)
func (policy *Policy) IsDirectlyPreview() bool {
return policy.Type == "local"
@@ -205,7 +214,18 @@ func (policy *Policy) IsThumbExist(name string) bool {
// IsTransitUpload 返回此策略上传给定size文件时是否需要服务端中转
func (policy *Policy) IsTransitUpload(size uint64) bool {
return policy.Type == "local"
if policy.Type == "local" {
return true
}
if policy.Type == "onedrive" && size < 4*1024*1024 {
return true
}
return false
}
// IsPathGenerateNeeded 返回此策略是否需要在生成上传凭证时生成存储路径
func (policy *Policy) IsPathGenerateNeeded() bool {
return policy.Type != "remote"
}
// IsThumbGenerateNeeded 返回此策略是否需要在上传后生成缩略图
@@ -213,24 +233,44 @@ func (policy *Policy) IsThumbGenerateNeeded() bool {
return policy.Type == "local"
}
// IsUploadPlaceholderWithSize 返回此策略创建上传会话时是否需要预留空间
func (policy *Policy) IsUploadPlaceholderWithSize() bool {
if policy.Type == "remote" {
return true
}
if util.ContainsString([]string{"onedrive", "oss", "qiniu", "cos", "s3"}, policy.Type) {
return policy.OptionsSerialized.PlaceholderWithSize
}
return false
}
// CanStructureBeListed 返回存储策略是否能被前台列物理目录
func (policy *Policy) CanStructureBeListed() bool {
return policy.Type != "local" && policy.Type != "remote"
}
// GetUploadURL 获取文件上传服务API地址
func (policy *Policy) GetUploadURL() string {
server, err := url.Parse(policy.Server)
if err != nil {
return policy.Server
}
controller, _ := url.Parse("")
switch policy.Type {
case "local", "onedrive":
return "/api/v3/file/upload"
case "remote":
controller, _ = url.Parse("/api/v3/slave/upload")
case "oss":
return "https://" + policy.BucketName + "." + policy.Server
case "cos":
return policy.Server
case "upyun":
return "https://v0.api.upyun.com/" + policy.BucketName
case "s3":
if policy.Server == "" {
return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/", policy.BucketName,
policy.OptionsSerialized.Region)
}
if !strings.Contains(policy.Server, policy.BucketName) {
controller, _ = url.Parse("/" + policy.BucketName)
}
}
return server.ResolveReference(controller).String()
}
// SaveAndClearCache 更新并清理缓存
func (policy *Policy) SaveAndClearCache() error {
err := DB.Save(policy).Error

View File

@@ -104,7 +104,7 @@ func TestPolicy_GenerateFileName(t *testing.T) {
asserts.Equal("123.txt", testPolicy.GenerateFileName(1, "123.txt"))
testPolicy.Type = "oss"
asserts.Equal("origin", testPolicy.GenerateFileName(1, "origin"))
asserts.Equal("${filename}", testPolicy.GenerateFileName(1, ""))
}
// 重命名开启
@@ -145,23 +145,19 @@ func TestPolicy_GenerateFileName(t *testing.T) {
testPolicy.Type = "oss"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
asserts.Equal("1123${filename}", testPolicy.GenerateFileName(1, ""))
testPolicy.Type = "upyun"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
asserts.Equal("1123{filename}{.suffix}", testPolicy.GenerateFileName(1, ""))
testPolicy.Type = "qiniu"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
asserts.Equal("1123$(fname)", testPolicy.GenerateFileName(1, ""))
testPolicy.Type = "local"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123", testPolicy.GenerateFileName(1, ""))
testPolicy.Type = "local"
testPolicy.FileNameRule = "{ext}123{uuid}"
asserts.Contains(testPolicy.GenerateFileName(1, "123.txt"), ".txt123")
}
}
@@ -174,6 +170,78 @@ func TestPolicy_IsDirectlyPreview(t *testing.T) {
asserts.False(policy.IsDirectlyPreview())
}
func TestPolicy_GetUploadURL(t *testing.T) {
asserts := assert.New(t)
// 本地
{
cache.Set("setting_siteURL", "http://127.0.0.1", 0)
policy := Policy{Type: "local", Server: "http://127.0.0.1"}
asserts.Equal("/api/v3/file/upload", policy.GetUploadURL())
}
// 远程
{
policy := Policy{Type: "remote", Server: "http://127.0.0.1"}
asserts.Equal("http://127.0.0.1/api/v3/slave/upload", policy.GetUploadURL())
}
// OSS
{
policy := Policy{Type: "oss", BucketName: "base", Server: "127.0.0.1"}
asserts.Equal("https://base.127.0.0.1", policy.GetUploadURL())
}
// cos
{
policy := Policy{Type: "cos", BaseURL: "base", Server: "http://127.0.0.1"}
asserts.Equal("http://127.0.0.1", policy.GetUploadURL())
}
// upyun
{
policy := Policy{Type: "upyun", BucketName: "base", Server: "http://127.0.0.1"}
asserts.Equal("https://v0.api.upyun.com/base", policy.GetUploadURL())
}
// 未知
{
policy := Policy{Type: "unknown", Server: "http://127.0.0.1"}
asserts.Equal("http://127.0.0.1", policy.GetUploadURL())
}
// S3 未填写自动生成
{
policy := Policy{
Type: "s3",
Server: "",
BucketName: "bucket",
OptionsSerialized: PolicyOption{Region: "us-east"},
}
asserts.Equal("https://bucket.s3.us-east.amazonaws.com/", policy.GetUploadURL())
}
// s3 自己指定
{
policy := Policy{
Type: "s3",
Server: "https://s3.us-east.amazonaws.com/",
BucketName: "bucket",
OptionsSerialized: PolicyOption{Region: "us-east"},
}
asserts.Equal("https://s3.us-east.amazonaws.com/bucket", policy.GetUploadURL())
}
}
func TestPolicy_IsPathGenerateNeeded(t *testing.T) {
asserts := assert.New(t)
policy := Policy{Type: "qiniu"}
asserts.True(policy.IsPathGenerateNeeded())
policy.Type = "remote"
asserts.False(policy.IsPathGenerateNeeded())
}
func TestPolicy_ClearCache(t *testing.T) {
asserts := assert.New(t)
cache.Set("policy_202", 1, 0)
@@ -198,18 +266,15 @@ func TestPolicy_UpdateAccessKey(t *testing.T) {
func TestPolicy_Props(t *testing.T) {
asserts := assert.New(t)
policy := Policy{Type: "onedrive"}
policy.OptionsSerialized.PlaceholderWithSize = true
asserts.False(policy.IsThumbGenerateNeeded())
asserts.False(policy.IsTransitUpload(4))
asserts.True(policy.IsPathGenerateNeeded())
asserts.True(policy.IsTransitUpload(4))
asserts.False(policy.IsTransitUpload(5 * 1024 * 1024))
asserts.True(policy.CanStructureBeListed())
asserts.True(policy.IsUploadPlaceholderWithSize())
policy.Type = "local"
asserts.True(policy.IsThumbGenerateNeeded())
asserts.True(policy.IsPathGenerateNeeded())
asserts.False(policy.CanStructureBeListed())
asserts.False(policy.IsUploadPlaceholderWithSize())
policy.Type = "remote"
asserts.True(policy.IsUploadPlaceholderWithSize())
}
func TestPolicy_IsThumbExist(t *testing.T) {

View File

@@ -15,12 +15,12 @@ var availableScripts = make(map[string]DBScript)
func RunDBScript(name string, ctx context.Context) error {
if script, ok := availableScripts[name]; ok {
util.Log().Info("Start executing database script %q.", name)
util.Log().Info("开始执行数据库脚本 [%s]", name)
script.Run(ctx)
return nil
}
return fmt.Errorf("Database script %q not exist.", name)
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
}
func Register(name string, script DBScript) {

View File

@@ -14,7 +14,7 @@ func (script ResetAdminPassword) Run(ctx context.Context) {
// 查找用户
user, err := model.GetUserByID(1)
if err != nil {
util.Log().Panic("Initial admin user not exist: %s", err)
util.Log().Panic("初始管理员用户不存在, %s", err)
}
// 生成密码
@@ -23,9 +23,9 @@ func (script ResetAdminPassword) Run(ctx context.Context) {
// 更改为新密码
user.SetPassword(password)
if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil {
util.Log().Panic("Failed to update password: %s", err)
util.Log().Panic("密码更改失败, %s", err)
}
c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold)
util.Log().Info("Initial admin user password changed to:" + c.Sprint(password))
util.Log().Info("初始管理员密码已更改为:" + c.Sprint(password))
}

View File

@@ -25,9 +25,9 @@ func (script UserStorageCalibration) Run(ctx context.Context) {
model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total)
// 更新用户的容量
if user.Storage != total.Total {
util.Log().Info("Calibrate used storage for user %q, from %d to %d.", user.Email,
util.Log().Info("将用户 [%s] 的容量由 %d 校准为 %d", user.Email,
user.Storage, total.Total)
model.DB.Model(&user).Update("storage", total.Total)
}
model.DB.Model(&user).Update("storage", total.Total)
}
}

View File

@@ -52,9 +52,6 @@ func TestUserStorageCalibration_Run(t *testing.T) {
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background())
asserts.NoError(mock.ExpectationsWereMet())
}

View File

@@ -23,11 +23,6 @@ func IsTrueVal(val string) bool {
// GetSettingByName 用 Name 获取设置值
func GetSettingByName(name string) string {
return GetSettingByNameFromTx(DB, name)
}
// GetSettingByNameFromTx 用 Name 获取设置值,使用事务
func GetSettingByNameFromTx(tx *gorm.DB, name string) string {
var setting Setting
// 优先从缓存中查找
@@ -37,31 +32,17 @@ func GetSettingByNameFromTx(tx *gorm.DB, name string) string {
}
// 尝试数据库中查找
if tx == nil {
tx = DB
if tx == nil {
return ""
if DB != nil {
result := DB.Where("name = ?", name).First(&setting)
if result.Error == nil {
_ = cache.Set(cacheKey, setting.Value, -1)
return setting.Value
}
}
result := tx.Where("name = ?", name).First(&setting)
if result.Error == nil {
_ = cache.Set(cacheKey, setting.Value, -1)
return setting.Value
}
return ""
}
// GetSettingByNameWithDefault 用 Name 获取设置值, 取不到时使用缺省值
func GetSettingByNameWithDefault(name, fallback string) string {
res := GetSettingByName(name)
if res == "" {
return fallback
}
return res
}
// GetSettingByNames 用多个 Name 获取设置值
func GetSettingByNames(names ...string) map[string]string {
var queryRes []Setting

View File

@@ -59,15 +59,6 @@ func TestGetSettingByType(t *testing.T) {
asserts.Equal(map[string]string{}, settings)
}
func TestGetSettingByNameWithDefault(t *testing.T) {
a := assert.New(t)
rows := sqlmock.NewRows([]string{"name", "value", "type"})
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings := GetSettingByNameWithDefault("123", "123321")
a.Equal("123321", settings)
}
func TestGetSettingByNames(t *testing.T) {
cache.Store = cache.NewMemoStore()
asserts := assert.New(t)

View File

@@ -36,7 +36,7 @@ type Share struct {
// Create 创建分享
func (share *Share) Create() (uint, error) {
if err := DB.Create(share).Error; err != nil {
util.Log().Warning("Failed to insert share record: %s", err)
util.Log().Warning("无法插入数据库记录, %s", err)
return 0, err
}
return share.ID, nil
@@ -131,9 +131,9 @@ func (share *Share) CanBeDownloadBy(user *User) error {
// 用户组权限
if !user.Group.OptionsSerialized.ShareDownload {
if user.IsAnonymous() {
return errors.New("you must login to download")
return errors.New("未登录用户无法下载")
}
return errors.New("your group has no permission to download")
return errors.New("您当前的用户组无权下载")
}
return nil
}

View File

@@ -1,47 +0,0 @@
package model
import (
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/jinzhu/gorm"
"net/url"
)
// SourceLink represent a shared file source link
type SourceLink struct {
gorm.Model
FileID uint // corresponding file ID
Name string // name of the file while creating the source link, for annotation
Downloads int // 下载数
// 关联模型
File File `gorm:"save_associations:false:false"`
}
// Link gets the URL of a SourceLink
func (s *SourceLink) Link() (string, error) {
baseURL := GetSiteURL()
linkPath, err := url.Parse(fmt.Sprintf("/f/%s/%s", hashid.HashID(s.ID, hashid.SourceLinkID), s.File.Name))
if err != nil {
return "", err
}
return baseURL.ResolveReference(linkPath).String(), nil
}
// GetTasksByID queries source link based on ID
func GetSourceLinkByID(id interface{}) (*SourceLink, error) {
link := &SourceLink{}
result := DB.Where("id = ?", id).First(link)
files, _ := GetFilesByIDs([]uint{link.FileID}, 0)
if len(files) > 0 {
link.File = files[0]
}
return link, result.Error
}
// Viewed 增加访问次数
func (s *SourceLink) Downloaded() {
s.Downloads++
DB.Model(s).UpdateColumn("downloads", gorm.Expr("downloads + ?", 1))
}

View File

@@ -1,52 +0,0 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestSourceLink_Link(t *testing.T) {
a := assert.New(t)
s := &SourceLink{}
s.ID = 1
// 失败
{
s.File.Name = string([]byte{0x7f})
res, err := s.Link()
a.Error(err)
a.Empty(res)
}
// 成功
{
s.File.Name = "filename"
res, err := s.Link()
a.NoError(err)
a.Contains(res, s.Name)
}
}
func TestGetSourceLinkByID(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
res, err := GetSourceLinkByID(1)
a.NoError(err)
a.NotNil(res)
a.EqualValues(2, res.File.ID)
a.NoError(mock.ExpectationsWereMet())
}
func TestSourceLink_Downloaded(t *testing.T) {
a := assert.New(t)
s := &SourceLink{}
s.ID = 1
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
s.Downloaded()
a.NoError(mock.ExpectationsWereMet())
}

View File

@@ -26,7 +26,7 @@ const (
// Create 创建标签记录
func (tag *Tag) Create() (uint, error) {
if err := DB.Create(tag).Error; err != nil {
util.Log().Warning("Failed to insert tag record: %s", err)
util.Log().Warning("无法插入离线下载记录, %s", err)
return 0, err
}
return tag.ID, nil

View File

@@ -56,8 +56,8 @@ func TestGetTagsByUID(t *testing.T) {
func TestGetTagsByID(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("tag"))
res, err := GetTagsByID(1, 1)
res, err := GetTagsByUID(1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues("tag", res.Name)
asserts.EqualValues("tag", res[0].Name)
}

View File

@@ -19,7 +19,7 @@ type Task struct {
// Create 创建任务记录
func (task *Task) Create() (uint, error) {
if err := DB.Create(task).Error; err != nil {
util.Log().Warning("Failed to insert task record: %s", err)
util.Log().Warning("无法插入任务记录, %s", err)
return 0, err
}
return task.ID, nil
@@ -64,7 +64,7 @@ func ListTasks(uid uint, page, pageSize int, order string) ([]Task, int) {
dbChain = dbChain.Where("user_id = ?", uid)
// 计算总数用于分页
dbChain.Model(&Task{}).Count(&total)
dbChain.Model(&Share{}).Count(&total)
// 查询记录
dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&tasks)

View File

@@ -3,7 +3,6 @@ package model
import (
"crypto/md5"
"crypto/sha1"
"encoding/gob"
"encoding/hex"
"encoding/json"
"strings"
@@ -36,8 +35,8 @@ type User struct {
Storage uint64
TwoFactor string
Avatar string
Options string `json:"-" gorm:"size:4294967295"`
Authn string `gorm:"size:4294967295"`
Options string `json:"-" gorm:"type:text"`
Authn string `gorm:"type:text"`
// 关联模型
Group Group `gorm:"save_associations:false:false"`
@@ -47,10 +46,6 @@ type User struct {
OptionsSerialized UserOption `gorm:"-"`
}
func init() {
gob.Register(User{})
}
// UserOption 用户个性化配置字段
type UserOption struct {
ProfileOff bool `json:"profile_off,omitempty"`
@@ -94,11 +89,6 @@ func (user *User) IncreaseStorage(size uint64) bool {
return false
}
// ChangeStorage 更新用户容量
func (user *User) ChangeStorage(tx *gorm.DB, operator string, size uint64) error {
return tx.Model(user).Update("storage", gorm.Expr("storage "+operator+" ?", size)).Error
}
// IncreaseStorageWithoutCheck 忽略可用容量,增加用户已用容量
func (user *User) IncreaseStorageWithoutCheck(size uint64) {
if size == 0 {

View File

@@ -11,7 +11,6 @@ type Webdav struct {
Password string `gorm:"unique_index:password_only_on"` // 应用密码
UserID uint `gorm:"unique_index:password_only_on"` // 用户ID
Root string `gorm:"type:text"` // 根目录
Readonly bool `gorm:"type:bool"` // 是否只读
}
// Create 创建账户
@@ -40,8 +39,3 @@ func ListWebDAVAccounts(uid uint) []Webdav {
func DeleteWebDAVAccountByID(id, uid uint) {
DB.Where("user_id = ? and id = ?", uid, id).Delete(&Webdav{})
}
// UpdateWebDAVAccountReadonlyByID 根据账户ID和UID更新账户的只读性
func UpdateWebDAVAccountReadonlyByID(id, uid uint, readonly bool) {
DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).UpdateColumn("readonly", readonly)
}

View File

@@ -3,6 +3,8 @@ package aria2
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"net/url"
"sync"
"time"
@@ -12,8 +14,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
)
// Instance 默认使用的Aria2处理实例
@@ -40,7 +40,7 @@ func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
if !isReload {
// 从数据库中读取未完成任务,创建监控
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding)
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading)
for i := 0; i < len(unfinished); i++ {
// 创建任务监控

View File

@@ -46,15 +46,13 @@ const (
Canceled
// Unknown 未知状态
Unknown
// Seeding 做种中
Seeding
)
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil)
ErrNotEnabled = serializer.NewError(serializer.CodeNoPermissionErr, "离线下载功能未开启", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil)
ErrUserNotFound = serializer.NewError(serializer.CodeNotFound, "无法找到任务创建者", nil)
)
// DummyAria2 未开启Aria2功能时使用的默认处理器
@@ -96,14 +94,11 @@ func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
}
// GetStatus 将给定的状态字符串转换为状态标识数字
func GetStatus(status rpc.StatusInfo) int {
switch status.Status {
func GetStatus(status string) int {
switch status {
case "complete":
return Complete
case "active":
if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength {
return Seeding
}
return Downloading
case "waiting":
return Ready

View File

@@ -1,11 +1,9 @@
package common
import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDummyAria2(t *testing.T) {
@@ -37,18 +35,11 @@ func TestDummyAria2(t *testing.T) {
func TestGetStatus(t *testing.T) {
a := assert.New(t)
a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: "single"},
TotalLength: "100", CompletedLength: "50"}), Downloading)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: "multi"},
TotalLength: "100", CompletedLength: "100"}), Seeding)
a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready)
a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused)
a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error)
a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled)
a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown)
a.Equal(GetStatus("complete"), Complete)
a.Equal(GetStatus("active"), Downloading)
a.Equal(GetStatus("waiting"), Ready)
a.Equal(GetStatus("paused"), Paused)
a.Equal(GetStatus("error"), Error)
a.Equal(GetStatus("removed"), Canceled)
a.Equal(GetStatus("unknown"), Unknown)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
@@ -45,7 +46,7 @@ func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
} else {
monitor.setErrorStatus(errors.New("node not avaliable"))
monitor.setErrorStatus(errors.New("节点不可用"))
}
}
@@ -77,12 +78,11 @@ func (monitor *Monitor) Update() bool {
if err != nil {
monitor.retried++
util.Log().Warning("Cannot get status of download task %q: %s", monitor.Task.GID, err)
util.Log().Warning("无法获取下载任务[%s]的状态,%s", monitor.Task.GID, err)
// 十次重试后认定为任务失败
if monitor.retried > MAX_RETRY {
util.Log().Warning("Cannot get status of download task %qexceed maximum retry threshold: %s",
monitor.Task.GID, err)
util.Log().Warning("无法获取下载任务[%s]的状态,超过最大重试次数限制,%s", monitor.Task.GID, err)
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
@@ -94,7 +94,7 @@ func (monitor *Monitor) Update() bool {
// 磁力链下载需要跟随
if len(status.FollowedBy) > 0 {
util.Log().Debug("Redirected download task from %q to %q.", monitor.Task.GID, status.FollowedBy[0])
util.Log().Debug("离线下载[%s]重定向至[%s]", monitor.Task.GID, status.FollowedBy[0])
monitor.Task.GID = status.FollowedBy[0]
monitor.Task.Save()
return false
@@ -102,28 +102,28 @@ func (monitor *Monitor) Update() bool {
// 更新任务信息
if err := monitor.UpdateTaskInfo(status); err != nil {
util.Log().Warning("Failed to update status of download task %q: %s", monitor.Task.GID, err)
util.Log().Warning("无法更新下载任务[%s]的任务信息[%s]", monitor.Task.GID, err)
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
util.Log().Debug("Remote download %q status updated to %q.", status.Gid, status.Status)
util.Log().Debug("离线下载[%s]更新状态[%s]", status.Gid, status.Status)
switch common.GetStatus(status) {
case common.Complete, common.Seeding:
switch status.Status {
case "complete":
return monitor.Complete(task.TaskPoll)
case common.Error:
case "error":
return monitor.Error(status)
case common.Downloading, common.Ready, common.Paused:
case "active", "waiting", "paused":
return false
case common.Canceled:
case "removed":
monitor.Task.Status = common.Canceled
monitor.Task.Save()
monitor.RemoveTempFolder()
return true
default:
util.Log().Warning("Download task %q returns unknown status %q.", monitor.Task.GID, status.Status)
util.Log().Warning("下载任务[%s]返回未知状态信息[%s]", monitor.Task.GID, status.Status)
return true
}
}
@@ -133,7 +133,7 @@ func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize
monitor.Task.GID = status.Gid
monitor.Task.Status = common.GetStatus(status)
monitor.Task.Status = common.GetStatus(status.Status)
// 文件大小、已下载大小
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
@@ -191,12 +191,12 @@ func (monitor *Monitor) ValidateFile() error {
defer fs.Recycle()
// 创建上下文环境
file := &fsctx.FileStream{
ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{
Size: monitor.Task.TotalSize,
}
})
// 验证用户容量
if err := filesystem.HookValidateCapacity(context.Background(), fs, file); err != nil {
if err := filesystem.HookValidateCapacityWithoutIncrease(ctx, fs); err != nil {
return err
}
@@ -205,11 +205,11 @@ func (monitor *Monitor) ValidateFile() error {
if fileInfo.Selected == "true" {
// 创建上下文环境
fileSize, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
file := &fsctx.FileStream{
ctx := context.WithValue(context.Background(), fsctx.FileHeaderCtx, local.FileStream{
Size: fileSize,
Name: filepath.Base(fileInfo.Path),
}
if err := filesystem.HookValidateFile(context.Background(), fs, file); err != nil {
})
if err := filesystem.HookValidateFile(ctx, fs); err != nil {
return err
}
}
@@ -236,40 +236,6 @@ func (monitor *Monitor) RemoveTempFolder() {
// Complete 完成下载,返回是否中断监控
func (monitor *Monitor) Complete(pool task.Pool) bool {
// 未开始转存,提交转存任务
if monitor.Task.TaskID == 0 {
return monitor.transfer(pool)
}
// 做种完成
if common.GetStatus(monitor.Task.StatusInfo) == common.Complete {
transferTask, err := model.GetTasksByID(monitor.Task.TaskID)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// 转存完成,回收下载目录
if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error {
job, err := task.NewRecycleTask(monitor.Task)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// 提交回收任务
pool.Submit(job)
return true
}
}
return false
}
func (monitor *Monitor) transfer(pool task.Pool) bool {
// 创建中转任务
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
@@ -304,7 +270,7 @@ func (monitor *Monitor) transfer(pool task.Pool) bool {
monitor.Task.TaskID = job.Model().ID
monitor.Task.Save()
return false
return true
}
func (monitor *Monitor) setErrorStatus(err error) {

View File

@@ -3,8 +3,6 @@ package monitor
import (
"database/sql"
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
@@ -15,6 +13,7 @@ import (
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"testing"
)
var mock sqlmock.Sqlmock
@@ -432,14 +431,6 @@ func TestMonitor_Complete(t *testing.T) {
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4))
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
a.False(m.Complete(mockPool))
m.Task.StatusInfo.Status = "complete"
a.True(m.Complete(mockPool))
a.NoError(mock.ExpectationsWereMet())
mockNode.AssertExpectations(t)

View File

@@ -4,27 +4,35 @@ package rpc
// StatusInfo represents response of aria2.tellStatus
type StatusInfo struct {
Gid string `json:"gid"` // GID of the download.
Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user.
TotalLength string `json:"totalLength"` // Total length of the download in bytes.
CompletedLength string `json:"completedLength"` // Completed length of the download in bytes.
UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes.
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response.
DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec.
UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec.
InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only.
NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only.
Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only.
PieceLength string `json:"pieceLength"` // Piece length in bytes.
NumPieces string `json:"numPieces"` // The number of pieces.
Connections string `json:"connections"` // The number of peers/servers aria2 has connected to.
ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads.
ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode.
FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response.
BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response.
Dir string `json:"dir"` // Directory to save files.
Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method.
BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
Gid string `json:"gid"` // GID of the download.
Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user.
TotalLength string `json:"totalLength"` // Total length of the download in bytes.
CompletedLength string `json:"completedLength"` // Completed length of the download in bytes.
UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes.
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response.
DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec.
UploadSpeed string `json:"uploadSpeed"` // Upload speed of this download measured in bytes/sec.
InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only.
NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only.
Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only.
PieceLength string `json:"pieceLength"` // Piece length in bytes.
NumPieces string `json:"numPieces"` // The number of pieces.
Connections string `json:"connections"` // The number of peers/servers aria2 has connected to.
ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads.
ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode.
FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response.
BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response.
Dir string `json:"dir"` // Directory to save files.
Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method.
BitTorrent struct {
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
Info struct {
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
} `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
}
// URIInfo represents an element of response of aria2.getUris
@@ -52,7 +60,7 @@ type PeerInfo struct {
AmChoking string `json:"amChoking"` // true if aria2 is choking the peer. Otherwise false.
PeerChoking string `json:"peerChoking"` // true if the peer is choking aria2. Otherwise false.
DownloadSpeed string `json:"downloadSpeed"` // Download speed (byte/sec) that this client obtains from the peer.
UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed(byte/sec) that this client uploads to the peer.
UploadSpeed string `json:"uploadSpeed"` // Upload speed(byte/sec) that this client uploads to the peer.
Seeder string `json:"seeder"` // true if this peer is a seeder. Otherwise false.
}
@@ -92,13 +100,3 @@ type Method struct {
Name string `json:"methodName"` // Method name to call
Params []interface{} `json:"params"` // Array containing parameters to the method call
}
type BitTorrentInfo struct {
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
Info struct {
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
}

View File

@@ -17,14 +17,12 @@ import (
)
var (
ErrAuthFailed = serializer.NewError(serializer.CodeInvalidSign, "invalid sign", nil)
ErrAuthFailed = serializer.NewError(serializer.CodeNoPermissionErr, "鉴权失败", nil)
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "signature expired", nil)
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
)
const CrHeaderPrefix = "X-Cr-"
// General 通用的认证接口
var General Auth
@@ -66,12 +64,12 @@ func CheckRequest(instance Auth, r *http.Request) error {
return instance.Check(getSignContent(r), sign[0])
}
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果请求 path 为从机上传 API
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果 Header 中包含 `X-Policy`
// 则不对正文签名。返回待签名/验证的字符串
func getSignContent(r *http.Request) (rawSignString string) {
// 读取所有body正文
var body = []byte{}
if !strings.Contains(r.URL.Path, "/api/v3/slave/upload/") {
if _, ok := r.Header["X-Cr-Policy"]; !ok {
if r.Body != nil {
body, _ = ioutil.ReadAll(r.Body)
_ = r.Body.Close()
@@ -82,7 +80,7 @@ func getSignContent(r *http.Request) (rawSignString string) {
// 决定要签名的header
var signedHeader []string
for k, _ := range r.Header {
if strings.HasPrefix(k, CrHeaderPrefix) && k != CrHeaderPrefix+"Filename" {
if strings.HasPrefix(k, "X-Cr-") && k != "X-Cr-Filename" {
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
}
}
@@ -136,7 +134,7 @@ func Init() {
} else {
secretKey = conf.SlaveConfig.Secret
if secretKey == "" {
util.Log().Panic("SlaveSecret is not set, please specify it in config file.")
util.Log().Panic("未指定 SlaveSecret,请前往配置文件中指定")
}
}
General = HMACAuth{

12
pkg/cache/driver.go vendored
View File

@@ -2,7 +2,6 @@ package cache
import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
)
@@ -10,7 +9,9 @@ import (
var Store Driver = NewMemoStore()
// Init 初始化缓存
func Init(isSlave bool) {
func Init() {
//Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0")
//return
if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode {
Store = NewRedisStore(
10,
@@ -20,13 +21,6 @@ func Init(isSlave bool) {
conf.RedisConfig.DB,
)
}
if isSlave {
err := Store.Sets(conf.OptionOverwrite, "setting_")
if err != nil {
util.Log().Warning("Failed to overwrite database setting: %s", err)
}
}
}
// Driver 键值缓存存储容器

View File

@@ -56,10 +56,6 @@ func TestInit(t *testing.T) {
asserts := assert.New(t)
asserts.NotPanics(func() {
Init(false)
})
asserts.NotPanics(func() {
Init(true)
Init()
})
}

2
pkg/cache/memo.go vendored
View File

@@ -53,7 +53,7 @@ func (store *MemoStore) GarbageCollect() {
store.Store.Range(func(key, value interface{}) bool {
if item, ok := value.(itemWithTTL); ok {
if item.expires > 0 && item.expires < time.Now().Unix() {
util.Log().Debug("Cache %q is garbage collected.", key.(string))
util.Log().Debug("回收垃圾[%s]", key.(string))
store.Store.Delete(key)
}
}

2
pkg/cache/redis.go vendored
View File

@@ -66,7 +66,7 @@ func NewRedisStore(size int, network, address, password, database string) *Redis
redis.DialPassword(password),
)
if err != nil {
util.Log().Warning("Failed to create Redis connection: %s", err)
util.Log().Warning("无法创建Redis连接%s", err)
return nil, err
}
return c, nil

View File

@@ -8,5 +8,5 @@ import (
var (
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "Unknown master node id", nil)
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
)

View File

@@ -161,7 +161,7 @@ func (r *rpcService) Init() error {
// 解析RPC服务地址
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
if err != nil {
util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err)
util.Log().Warning("无法解析主机 Aria2 RPC 服务地址,%s", err)
return err
}
server.Path = "/jsonrpc"
@@ -171,7 +171,7 @@ func (r *rpcService) Init() error {
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
if err != nil {
util.Log().Warning("Failed to parse aria2 options: %s", err)
util.Log().Warning("无法解析主机 Aria2 配置,%s", err)
return err
}
}
@@ -221,7 +221,7 @@ func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := r.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("Failed to get download task status, please retry later: %s", err)
util.Log().Debug("无法获取离线下载状态,%s稍后重试", err)
time.Sleep(r.retryDuration)
res, err = r.Caller.TellStatus(task.GID)
}
@@ -233,7 +233,7 @@ func (r *rpcService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := r.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("Failed to cancel task %q: %s", task.GID, err)
util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err)
}
return err
@@ -264,7 +264,7 @@ func (s *rpcService) DeleteTempFile(task *model.Download) error {
time.Sleep(d)
err := os.RemoveAll(src)
if err != nil {
util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err)
util.Log().Warning("无法删除离线下载临时目录[%s], %s", src, err)
}
}(s.deletePaddingDuration, task.Parent)

View File

@@ -42,7 +42,7 @@ func Init() {
Default = &NodePool{}
Default.Init()
if err := Default.initFromDB(); err != nil {
util.Log().Warning("Failed to initialize node pool: %s", err)
util.Log().Warning("节点池初始化失败, %s", err)
}
}
@@ -83,7 +83,7 @@ func (pool *NodePool) GetNodeByID(id uint) Node {
}
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
util.Log().Debug("Slave node [ID=%d] status changed to [Active=%t].", id, isActive)
util.Log().Debug("从机节点 [ID=%d] 状态变更 [Active=%t]", id, isActive)
var node Node
pool.lock.Lock()
if n, ok := pool.inactive[id]; ok {

View File

@@ -1,15 +1,11 @@
package cluster
import (
"bytes"
"encoding/json"
"errors"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
@@ -44,7 +40,7 @@ func (node *SlaveNode) Init(nodeModel *model.Node) {
var endpoint *url.URL
if serverURL, err := url.Parse(node.Model.Server); err == nil {
var controller *url.URL
controller, _ = url.Parse("/api/v3/slave/")
controller, _ = url.Parse("/api/v3/slave")
endpoint = serverURL.ResolveReference(controller)
}
@@ -172,7 +168,7 @@ func (node *SlaveNode) StartPingLoop() {
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
pingTicker := time.Duration(0)
util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name)
util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name)
retry := 0
recoverMode := false
isFirstLoop := true
@@ -185,39 +181,39 @@ loop:
pingTicker = tickDuration
}
util.Log().Debug("Slave node %q send ping.", node.Model.Name)
util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name)
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
isFirstLoop = false
if err != nil {
util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err)
util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err)
retry++
if retry >= model.GetIntSetting("slave_node_retry", 3) {
util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name)
util.Log().Debug("从机节点 [%s] Ping 重试已达到最大限制,将从机节点标记为不可用", node.Model.Name)
node.changeStatus(false)
if !recoverMode {
// 启动恢复监控循环
util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name)
util.Log().Debug("从机节点 [%s] 进入恢复模式", node.Model.Name)
pingTicker = recoverDuration
recoverMode = true
}
}
} else {
if recoverMode {
util.Log().Debug("Slave node %q recovered.", node.Model.Name)
util.Log().Debug("从机节点 [%s] 复活", node.Model.Name)
pingTicker = tickDuration
recoverMode = false
isFirstLoop = true
}
util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res)
util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res)
node.changeStatus(true)
retry = 0
}
case <-node.close:
util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name)
util.Log().Debug("从机节点 [%s] 收到关闭信号", node.Model.Name)
break loop
}
}
@@ -412,40 +408,3 @@ func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
return strings.NewReader(string(reqBodyEncoded)), nil
}
// RemoteCallback 发送远程存储策略上传回调请求
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err)
}
resp := request.GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
}
// 解析回调服务端响应
response, err := resp.DecodeResponse()
if err != nil {
msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
return serializer.NewError(serializer.CodeCallbackError, msg, err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

View File

@@ -1,12 +1,8 @@
package cluster
import (
"bytes"
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
@@ -445,115 +441,3 @@ func TestSlaveCaller_DeleteTempFile(t *testing.T) {
a.NoError(err)
}
}
func TestRemoteCallback(t *testing.T) {
asserts := assert.New(t)
// 回调成功
{
clientMock := requestmock.RequestMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.NoError(resp)
clientMock.AssertExpectations(t)
}
// 服务端返回业务错误
{
clientMock := requestmock.RequestMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.EqualValues(401, resp.(serializer.AppError).Code)
clientMock.AssertExpectations(t)
}
// 无法解析回调响应
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// HTTP状态码非200
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 404,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// 无法发起回调
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error"),
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
}

View File

@@ -3,7 +3,7 @@ package conf
import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/go-ini/ini"
"github.com/go-playground/validator/v10"
"gopkg.in/go-playground/validator.v9"
)
// database 数据库
@@ -17,7 +17,6 @@ type database struct {
DBFile string
Port int
Charset string
UnixSocket bool
}
// system 系统通用配置
@@ -27,8 +26,6 @@ type system struct {
Debug bool
SessionSecret string
HashIDSalt string
GracePeriod int `validate:"gte=0"`
ProxyHeader string `validate:"required_with=Listen"`
}
type ssl struct {
@@ -39,7 +36,6 @@ type ssl struct {
type unix struct {
Listen string
Perm uint32
}
// slave 作为slave存储端配置
@@ -49,6 +45,21 @@ type slave struct {
SignatureTTL int `validate:"omitempty,gte=1"`
}
// captcha 验证码配置
type captcha struct {
Height int `validate:"gte=0"`
Width int `validate:"gte=0"`
Mode int `validate:"gte=0,lte=3"`
ComplexOfNoiseText int `validate:"gte=0,lte=2"`
ComplexOfNoiseDot int `validate:"gte=0,lte=2"`
IsShowHollowLine bool
IsShowNoiseDot bool
IsShowNoiseText bool
IsShowSlimeLine bool
IsShowSineLine bool
CaptchaLen int `validate:"gt=0"`
}
// redis 配置
type redis struct {
Network string
@@ -57,6 +68,17 @@ type redis struct {
DB string
}
// 缩略图 配置
type thumb struct {
MaxWidth uint
MaxHeight uint
FileSuffix string `validate:"min=1"`
MaxTaskCount int
EncodeMethod string `validate:"eq=jpg|eq=png"`
EncodeQuality int `validate:"gte=1,lte=100"`
GCAfterGen bool
}
// 跨域配置
type cors struct {
AllowOrigins []string
@@ -64,14 +86,11 @@ type cors struct {
AllowHeaders []string
AllowCredentials bool
ExposeHeaders []string
SameSite string
Secure bool
}
var cfg *ini.File
const defaultConf = `[System]
Debug = false
Mode = master
Listen = :5212
SessionSecret = {SessionSecret}
@@ -90,13 +109,13 @@ func Init(path string) {
}, defaultConf)
f, err := util.CreatNestedFile(path)
if err != nil {
util.Log().Panic("Failed to create config file: %s", err)
util.Log().Panic("无法创建配置文件, %s", err)
}
// 写入配置文件
_, err = f.WriteString(confContent)
if err != nil {
util.Log().Panic("Failed to write config file: %s", err)
util.Log().Panic("无法写入配置文件, %s", err)
}
f.Close()
@@ -104,7 +123,7 @@ func Init(path string) {
cfg, err = ini.Load(path)
if err != nil {
util.Log().Panic("Failed to parse config file %q: %s", path, err)
util.Log().Panic("无法解析配置文件 '%s': %s", path, err)
}
sections := map[string]interface{}{
@@ -112,22 +131,19 @@ func Init(path string) {
"System": SystemConfig,
"SSL": SSLConfig,
"UnixSocket": UnixConfig,
"Captcha": CaptchaConfig,
"Redis": RedisConfig,
"Thumbnail": ThumbConfig,
"CORS": CORSConfig,
"Slave": SlaveConfig,
}
for sectionName, sectionStruct := range sections {
err = mapSection(sectionName, sectionStruct)
if err != nil {
util.Log().Panic("Failed to parse config section %q: %s", sectionName, err)
util.Log().Panic("配置文件 %s 分区解析失败: %s", sectionName, err)
}
}
// 映射数据库配置覆盖
for _, key := range cfg.Section("OptionOverwrite").Keys() {
OptionOverwrite[key.Name()] = key.Value()
}
// 重设log等级
if !SystemConfig.Debug {
util.Level = util.LevelInformational

View File

@@ -56,11 +56,7 @@ User = root
Password = root
Host = 127.0.0.1:3306
Name = v3
TablePrefix = v3_
[OptionOverwrite]
key=value
`
TablePrefix = v3_`
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
defer func() { err = os.Remove("testConf.ini") }()
if err != nil {
@@ -69,7 +65,6 @@ key=value
asserts.NotPanics(func() {
Init("testConf.ini")
})
asserts.Equal(OptionOverwrite["key"], "value")
}
func TestMapSection(t *testing.T) {

View File

@@ -1,5 +1,7 @@
package conf
import "github.com/mojocn/base64Captcha"
// RedisConfig Redis服务器配置
var RedisConfig = &redis{
Network: "tcp",
@@ -10,19 +12,32 @@ var RedisConfig = &redis{
// DatabaseConfig 数据库配置
var DatabaseConfig = &database{
Type: "UNSET",
Charset: "utf8",
DBFile: "cloudreve.db",
Port: 3306,
UnixSocket: false,
Type: "UNSET",
Charset: "utf8",
DBFile: "cloudreve.db",
Port: 3306,
}
// SystemConfig 系统公用配置
var SystemConfig = &system{
Debug: false,
Mode: "master",
Listen: ":5212",
ProxyHeader: "X-Forwarded-For",
Debug: false,
Mode: "master",
Listen: ":5212",
}
// CaptchaConfig 验证码配置
var CaptchaConfig = &captcha{
Height: 60,
Width: 240,
Mode: 3,
ComplexOfNoiseText: base64Captcha.CaptchaComplexLower,
ComplexOfNoiseDot: base64Captcha.CaptchaComplexLower,
IsShowHollowLine: false,
IsShowNoiseDot: false,
IsShowNoiseText: false,
IsShowSlimeLine: false,
IsShowSineLine: false,
CaptchaLen: 6,
}
// CORSConfig 跨域配置
@@ -32,8 +47,17 @@ var CORSConfig = &cors{
AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"},
AllowCredentials: false,
ExposeHeaders: nil,
SameSite: "Default",
Secure: false,
}
// ThumbConfig 缩略图配置
var ThumbConfig = &thumb{
MaxWidth: 400,
MaxHeight: 300,
FileSuffix: "._thumb",
MaxTaskCount: -1,
EncodeMethod: "jpg",
GCAfterGen: false,
EncodeQuality: 85,
}
// SlaveConfig 从机配置
@@ -51,5 +75,3 @@ var SSLConfig = &ssl{
var UnixConfig = &unix{
Listen: "",
}
var OptionOverwrite = map[string]interface{}{}

View File

@@ -1,13 +1,13 @@
package conf
// BackendVersion 当前后端版本号
var BackendVersion = "3.7.1"
var BackendVersion = "3.4.3"
// RequiredDBVersion 与当前版本匹配的数据库版本
var RequiredDBVersion = "3.7.1"
var RequiredDBVersion = "3.4.0"
// RequiredStaticVersion 与当前版本匹配的静态资源版本
var RequiredStaticVersion = "3.7.1"
var RequiredStaticVersion = "3.4.3"
// IsPro 是否为Pro版本
var IsPro = "false"

View File

@@ -1,7 +1,6 @@
package crontab
import (
"context"
"os"
"path/filepath"
"strings"
@@ -9,7 +8,6 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
@@ -22,7 +20,7 @@ func garbageCollect() {
collectCache(store)
}
util.Log().Info("Crontab job \"cron_garbage_collect\" complete.")
util.Log().Info("定时任务 [cron_garbage_collect] 执行完毕")
}
func collectArchiveFile() {
@@ -36,64 +34,22 @@ func collectArchiveFile() {
if err == nil && !info.IsDir() &&
strings.HasPrefix(filepath.Base(path), "archive_") &&
time.Now().Sub(info.ModTime()).Seconds() > float64(expires) {
util.Log().Debug("Delete expired batch download temp file %q.", path)
util.Log().Debug("删除过期打包下载临时文件 [%s]", path)
// 删除符合条件的文件
if err := os.Remove(path); err != nil {
util.Log().Debug("Failed to delete temp file %q: %s", path, err)
util.Log().Debug("临时文件 [%s] 删除失败 , %s", path, err)
}
}
return nil
})
if err != nil {
util.Log().Debug("Crontab job cannot list temp batch download folder: %s", err)
util.Log().Debug("[定时任务] 无法列取临时打包目录")
}
}
func collectCache(store *cache.MemoStore) {
util.Log().Debug("Cleanup memory cache.")
util.Log().Debug("清理内存缓存")
store.GarbageCollect()
}
func uploadSessionCollect() {
placeholders := model.GetUploadPlaceholderFiles(0)
// 将过期的上传会话按照用户分组
userToFiles := make(map[uint][]uint)
for _, file := range placeholders {
_, sessionExist := cache.Get(filesystem.UploadSessionCachePrefix + *file.UploadSessionID)
if sessionExist {
continue
}
if _, ok := userToFiles[file.UserID]; !ok {
userToFiles[file.UserID] = make([]uint, 0)
}
userToFiles[file.UserID] = append(userToFiles[file.UserID], file.ID)
}
// 删除过期的会话
for uid, filesIDs := range userToFiles {
user, err := model.GetUserByID(uid)
if err != nil {
util.Log().Warning("Owner of the upload session cannot be found: %s", err)
continue
}
fs, err := filesystem.NewFileSystem(&user)
if err != nil {
util.Log().Warning("Failed to initialize filesystem: %s", err)
continue
}
if err = fs.Delete(context.Background(), []uint{}, filesIDs, false, false); err != nil {
util.Log().Warning("Failed to delete upload session: %s", err)
}
fs.Recycle()
}
util.Log().Info("Crontab job \"cron_recycle_upload_session\" complete.")
}

View File

@@ -19,27 +19,22 @@ func Reload() {
// Init 初始化定时任务
func Init() {
util.Log().Info("Initialize crontab jobs...")
util.Log().Info("初始化定时任务...")
// 读取cron日程设置
options := model.GetSettingByNames(
"cron_garbage_collect",
"cron_recycle_upload_session",
)
options := model.GetSettingByNames("cron_garbage_collect")
Cron := cron.New()
for k, v := range options {
var handler func()
switch k {
case "cron_garbage_collect":
handler = garbageCollect
case "cron_recycle_upload_session":
handler = uploadSessionCollect
default:
util.Log().Warning("Unknown crontab job type %q, skipping...", k)
util.Log().Warning("未知定时任务类型 [%s],跳过", k)
continue
}
if _, err := Cron.AddFunc(v, handler); err != nil {
util.Log().Warning("Failed to start crontab job %q: %s", k, err)
util.Log().Warning("无法启动定时任务 [%s] , %s", k, err)
}
}

View File

@@ -15,7 +15,7 @@ var Lock sync.RWMutex
// Init 初始化
func Init() {
util.Log().Debug("Initializing email sending queue...")
util.Log().Debug("邮件队列初始化")
Lock.Lock()
defer Lock.Unlock()

View File

@@ -15,9 +15,9 @@ type Driver interface {
var (
// ErrChanNotOpen 邮件队列未开启
ErrChanNotOpen = errors.New("email queue is not started")
ErrChanNotOpen = errors.New("邮件队列未开启")
// ErrNoActiveDriver 无可用邮件发送服务
ErrNoActiveDriver = errors.New("no avaliable email provider")
ErrNoActiveDriver = errors.New("无可用邮件发送服务")
)
// Send 发送邮件

View File

@@ -68,7 +68,7 @@ func (client *SMTP) Init() {
defer func() {
if err := recover(); err != nil {
client.chOpen = false
util.Log().Error("Exception while sending email: %s, queue will be reset in 10 seconds.", err)
util.Log().Error("邮件发送队列出现异常, %s ,10 秒后重置", err)
time.Sleep(time.Duration(10) * time.Second)
client.Init()
}
@@ -91,7 +91,7 @@ func (client *SMTP) Init() {
select {
case m, ok := <-client.ch:
if !ok {
util.Log().Debug("Email queue closing...")
util.Log().Debug("邮件队列关闭")
client.chOpen = false
return
}
@@ -102,15 +102,15 @@ func (client *SMTP) Init() {
open = true
}
if err := mail.Send(s, m); err != nil {
util.Log().Warning("Failed to send email: %s", err)
util.Log().Warning("邮件发送失败, %s", err)
} else {
util.Log().Debug("Email sent.")
util.Log().Debug("邮件已发送")
}
// 长时间没有新邮件则关闭SMTP连接
case <-time.After(time.Duration(client.Config.Keepalive) * time.Second):
if open {
if err := s.Close(); err != nil {
util.Log().Warning("Failed to close SMTP connection: %s", err)
util.Log().Warning("无法关闭 SMTP 连接 %s", err)
}
open = false
}

View File

@@ -2,9 +2,11 @@ package filesystem
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
@@ -16,7 +18,8 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/mholt/archiver/v4"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
/* ===============
@@ -25,17 +28,17 @@ import (
*/
// Compress 创建给定目录和文件的压缩文件
func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs, fileIDs []uint, isArchive bool) error {
func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint, isArchive bool) (string, error) {
// 查找待压缩目录
folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID)
if err != nil && len(folderIDs) != 0 {
return ErrDBListObjects
return "", ErrDBListObjects
}
// 查找待压缩文件
files, err := model.GetFilesByIDs(fileIDs, fs.User.ID)
if err != nil && len(fileIDs) != 0 {
return ErrDBListObjects
return "", ErrDBListObjects
}
// 如果上下文限制了父目录,则进行检查
@@ -43,14 +46,14 @@ func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs,
// 检查目录
for _, folder := range folders {
if *folder.ParentID != parent.ID {
return ErrObjectNotExist
return "", ErrObjectNotExist
}
}
// 检查文件
for _, file := range files {
if file.FolderID != parent.ID {
return ErrObjectNotExist
return "", ErrObjectNotExist
}
}
}
@@ -70,8 +73,25 @@ func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs,
files[i].Position = ""
}
// 创建临时压缩文件
saveFolder := "archive"
if !isArchive {
saveFolder = "compress"
}
zipFilePath := filepath.Join(
util.RelativePath(model.GetSettingByName("temp_path")),
saveFolder,
fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()),
)
zipFile, err := util.CreatNestedFile(zipFilePath)
if err != nil {
util.Log().Warning("%s", err)
return "", err
}
defer zipFile.Close()
// 创建压缩文件Writer
zipWriter := zip.NewWriter(writer)
zipWriter := zip.NewWriter(zipFile)
defer zipWriter.Close()
ctx = reqContext
@@ -81,9 +101,10 @@ func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs,
select {
case <-reqContext.Done():
// 取消压缩请求
return ErrClientCanceled
fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath)
return "", ErrClientCanceled
default:
fs.doCompress(reqContext, nil, &folders[i], zipWriter, isArchive)
fs.doCompress(ctx, nil, &folders[i], zipWriter, isArchive)
}
}
@@ -91,13 +112,22 @@ func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs,
select {
case <-reqContext.Done():
// 取消压缩请求
return ErrClientCanceled
fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath)
return "", ErrClientCanceled
default:
fs.doCompress(reqContext, &files[i], nil, zipWriter, isArchive)
fs.doCompress(ctx, &files[i], nil, zipWriter, isArchive)
}
}
return nil
return zipFilePath, nil
}
// cancelCompress 取消压缩进程
func (fs *FileSystem) cancelCompress(ctx context.Context, zipWriter *zip.Writer, file *os.File, path string) {
util.Log().Debug("客户端取消压缩请求")
zipWriter.Close()
file.Close()
_ = os.Remove(path)
}
func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) {
@@ -107,7 +137,7 @@ func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *
fs.Policy = file.GetPolicy()
err := fs.DispatchHandler()
if err != nil {
util.Log().Warning("Failed to compress file %q: %s", file.Name, err)
util.Log().Warning("无法压缩文件%s%s", file.Name, err)
return
}
@@ -117,7 +147,7 @@ func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *
file.SourceName,
)
if err != nil {
util.Log().Debug("Failed to open %q: %s", file.Name, err)
util.Log().Debug("Open%s%s", file.Name, err)
return
}
if closer, ok := fileToZip.(io.Closer); ok {
@@ -165,7 +195,7 @@ func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *
}
// Decompress 解压缩给定压缩文件到dst目录
func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string) error {
func (fs *FileSystem) Decompress(ctx context.Context, src, dst string) error {
err := fs.ResetFileIfNotExist(ctx, src)
if err != nil {
return err
@@ -176,7 +206,7 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string)
// 结束时删除临时压缩文件
if tempZipFilePath != "" {
if err := os.Remove(tempZipFilePath); err != nil {
util.Log().Warning("Failed to delete temp archive file %q: %s", tempZipFilePath, err)
util.Log().Warning("无法删除临时压缩文件 %s , %s", tempZipFilePath, err)
}
}
}()
@@ -187,8 +217,6 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string)
return err
}
defer fileStream.Close()
tempZipFilePath = filepath.Join(
util.RelativePath(model.GetSettingByName("temp_path")),
"decompress",
@@ -197,47 +225,26 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string)
zipFile, err := util.CreatNestedFile(tempZipFilePath)
if err != nil {
util.Log().Warning("Failed to create temp archive file %q: %s", tempZipFilePath, err)
util.Log().Warning("无法创建临时压缩文件 %s , %s", tempZipFilePath, err)
tempZipFilePath = ""
return err
}
defer zipFile.Close()
// 下载前先判断是否是可解压的格式
format, readStream, err := archiver.Identify(fs.FileTarget[0].SourceName, fileStream)
_, err = io.Copy(zipFile, fileStream)
if err != nil {
util.Log().Warning("Failed to detect compressed format of file %q: %s", fs.FileTarget[0].SourceName, err)
util.Log().Warning("无法写入临时压缩文件 %s , %s", tempZipFilePath, err)
return err
}
extractor, ok := format.(archiver.Extractor)
if !ok {
return fmt.Errorf("file not an extractor %s", fs.FileTarget[0].SourceName)
}
// 只有zip格式可以多个文件同时上传
var isZip bool
switch extractor.(type) {
case archiver.Zip:
extractor = archiver.Zip{TextEncoding: encoding}
isZip = true
}
// 除了zip必须下载到本地其余的可以边下载边解压
reader := readStream
if isZip {
_, err = io.Copy(zipFile, readStream)
if err != nil {
util.Log().Warning("Failed to write temp archive file %q: %s", tempZipFilePath, err)
return err
}
fileStream.Close()
// 设置文件偏移量
zipFile.Seek(0, io.SeekStart)
reader = zipFile
zipFile.Close()
// 解压缩文件
r, err := zip.OpenReader(tempZipFilePath)
if err != nil {
return err
}
defer r.Close()
// 重设存储策略
fs.Policy = &fs.User.Policy
@@ -253,64 +260,59 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string)
worker <- i
}
// 上传文件函数
uploadFunc := func(fileStream io.ReadCloser, size int64, savePath, rawPath string) {
defer func() {
if isZip {
worker <- 1
wg.Done()
}
if err := recover(); err != nil {
util.Log().Warning("Error while uploading files inside of archive file.")
fmt.Println(err)
}
}()
err := fs.UploadFromStream(ctx, &fsctx.FileStream{
File: fileStream,
Size: uint64(size),
Name: path.Base(savePath),
VirtualPath: path.Dir(savePath),
}, true)
fileStream.Close()
if err != nil {
util.Log().Debug("Failed to upload file %q in archive file: %s, skipping...", rawPath, err)
for _, f := range r.File {
fileName := f.Name
// 处理非UTF-8编码
if f.NonUTF8 {
i := bytes.NewReader([]byte(fileName))
decoder := transform.NewReader(i, simplifiedchinese.GB18030.NewDecoder())
content, _ := ioutil.ReadAll(decoder)
fileName = string(content)
}
}
// 解压缩文件回调函数如果出错会停止解压的下一步进行全部return nil
err = extractor.Extract(ctx, reader, nil, func(ctx context.Context, f archiver.File) error {
rawPath := util.FormSlash(f.NameInArchive)
rawPath := util.FormSlash(fileName)
savePath := path.Join(dst, rawPath)
// 路径是否合法
if !strings.HasPrefix(savePath, util.FillSlash(path.Clean(dst))) {
util.Log().Warning("%s: illegal file path", f.NameInArchive)
return nil
return fmt.Errorf("%s: illegal file path", f.Name)
}
// 如果是目录
if f.FileInfo.IsDir() {
if f.FileInfo().IsDir() {
fs.CreateDirectory(ctx, savePath)
return nil
continue
}
// 上传文件
fileStream, err := f.Open()
if err != nil {
util.Log().Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err)
return nil
util.Log().Warning("无法打开压缩包内文件%s , %s , 跳过", rawPath, err)
continue
}
if !isZip {
uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
} else {
<-worker
select {
case <-worker:
wg.Add(1)
go uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
go func(fileStream io.ReadCloser, size int64) {
defer func() {
worker <- 1
wg.Done()
if err := recover(); err != nil {
util.Log().Warning("上传压缩包内文件时出错")
fmt.Println(err)
}
}()
err = fs.UploadFromStream(ctx, fileStream, savePath, uint64(size))
fileStream.Close()
if err != nil {
util.Log().Debug("无法上传压缩包内的文件%s , %s , 跳过", rawPath, err)
}
}(fileStream, f.FileInfo().Size())
}
return nil
})
}
wg.Wait()
return err
return nil
}

View File

@@ -1,15 +1,10 @@
package filesystem
import (
"bytes"
"context"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
testMock "github.com/stretchr/testify/mock"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
@@ -17,8 +12,11 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
func TestFileSystem_Compress(t *testing.T) {
@@ -60,11 +58,12 @@ func TestFileSystem_Compress(t *testing.T) {
)
// 查找上传策略
asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1))
w := &bytes.Buffer{}
err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true)
asserts.NoError(err)
asserts.NotEmpty(w.Len())
asserts.NotEmpty(zipFile)
asserts.Contains(zipFile, "archive_")
asserts.Contains(zipFile, "tests")
}
// 上下文取消
@@ -85,10 +84,9 @@ func TestFileSystem_Compress(t *testing.T) {
)
asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
w := &bytes.Buffer{}
err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true)
asserts.Error(err)
asserts.NotEmpty(w.Len())
asserts.Empty(zipFile)
}
// 限制父目录
@@ -110,11 +108,10 @@ func TestFileSystem_Compress(t *testing.T) {
)
asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
w := &bytes.Buffer{}
err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
zipFile, err := fs.Compress(ctx, []uint{1}, []uint{1}, true)
asserts.Error(err)
asserts.Equal(ErrObjectNotExist, err)
asserts.Empty(w.Len())
asserts.Empty(zipFile)
}
}
@@ -149,24 +146,12 @@ func (m MockRSC) Close() error {
return nil
}
var basepath string
func init() {
_, currentFile, _, _ := runtime.Caller(0)
basepath = filepath.Dir(currentFile)
}
func Path(rel string) string {
return filepath.Join(basepath, rel)
}
func TestFileSystem_Decompress(t *testing.T) {
asserts := assert.New(t)
ctx := context.Background()
fs := FileSystem{
User: &model.User{Model: gorm.Model{ID: 1}},
}
os.RemoveAll(util.RelativePath("tests/decompress"))
// 压缩文件不存在
{
@@ -176,7 +161,7 @@ func TestFileSystem_Decompress(t *testing.T) {
// 查找压缩文件,未找到
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
err := fs.Decompress(ctx, "/1.zip", "/", "")
err := fs.Decompress(ctx, "/1.zip", "/")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
@@ -186,9 +171,9 @@ func TestFileSystem_Decompress(t *testing.T) {
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, errors.New("error"))
testHandler.On("Get", testMock.Anything, "1.zip").Return(request.NopRSCloser{}, errors.New("error"))
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
err := fs.Decompress(ctx, "/1.zip", "/")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualError(err, "error")
@@ -200,9 +185,9 @@ func TestFileSystem_Decompress(t *testing.T) {
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, nil)
testHandler.On("Get", testMock.Anything, "1.zip").Return(request.NopRSCloser{}, nil)
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
err := fs.Decompress(ctx, "/1.zip", "/")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
@@ -215,13 +200,13 @@ func TestFileSystem_Decompress(t *testing.T) {
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockNopRSC("1"), nil)
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
err := fs.Decompress(ctx, "/1.zip", "/")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Contains(err.Error(), "read error")
asserts.EqualError(err, "read error")
}
// 无法重设上传策略
// 无效zip文件
{
cache.Set("setting_temp_path", "tests", 0)
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
@@ -229,7 +214,22 @@ func TestFileSystem_Decompress(t *testing.T) {
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{rs: strings.NewReader("read")}, nil)
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
err := fs.Decompress(ctx, "/1.zip", "/")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualError(err, "zip: not a valid zip file")
}
// 无法重设上传策略
{
zipFile, _ := os.Open(util.RelativePath("filesystem/tests/test.zip"))
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil)
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/")
zipFile.Close()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.True(util.IsEmpty(util.RelativePath("tests/decompress")))
@@ -238,7 +238,7 @@ func TestFileSystem_Decompress(t *testing.T) {
// 无法上传,容量不足
{
cache.Set("setting_max_parallel_transfer", "1", 0)
zipFile, _ := os.Open(Path("tests/test.zip"))
zipFile, _ := os.Open(util.RelativePath("filesystem/tests/test.zip"))
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
fs.User.Policy.Type = "mock"
@@ -246,7 +246,7 @@ func TestFileSystem_Decompress(t *testing.T) {
testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil)
fs.Handler = testHandler
fs.Decompress(ctx, "/1.zip", "/", "")
fs.Decompress(ctx, "/1.zip", "/")
zipFile.Close()

View File

@@ -1,74 +0,0 @@
package backoff
import (
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/http"
"strconv"
"time"
)
// Backoff used for retry sleep backoff
type Backoff interface {
Next(err error) bool
Reset()
}
// ConstantBackoff implements Backoff interface with constant sleep time. If the error
// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration.
type ConstantBackoff struct {
Sleep time.Duration
Max int
tried int
}
func (c *ConstantBackoff) Next(err error) bool {
c.tried++
if c.tried > c.Max {
return false
}
var e *RetryableError
if errors.As(err, &e) && e.RetryAfter > 0 {
util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter)
time.Sleep(e.RetryAfter)
} else {
time.Sleep(c.Sleep)
}
return true
}
func (c *ConstantBackoff) Reset() {
c.tried = 0
}
type RetryableError struct {
Err error
RetryAfter time.Duration
}
// NewRetryableErrorFromHeader constructs a new RetryableError from http response header
// and existing error.
func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError {
retryAfter := header.Get("retry-after")
if retryAfter == "" {
retryAfter = "0"
}
res := &RetryableError{
Err: err,
}
if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil {
res.RetryAfter = time.Duration(retryAfterSecond) * time.Second
}
return res
}
func (e *RetryableError) Error() string {
return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err)
}

View File

@@ -1,61 +0,0 @@
package backoff
import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"time"
)
func TestConstantBackoff_Next(t *testing.T) {
a := assert.New(t)
// General error
{
err := errors.New("error")
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
b.Reset()
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
}
// Retryable error
{
err := &RetryableError{RetryAfter: time.Duration(1)}
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
b.Reset()
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
}
}
func TestNewRetryableErrorFromHeader(t *testing.T) {
a := assert.New(t)
// no retry-after header
{
err := NewRetryableErrorFromHeader(nil, http.Header{})
a.Empty(err.RetryAfter)
}
// with retry-after header
{
header := http.Header{}
header.Add("retry-after", "120")
err := NewRetryableErrorFromHeader(nil, header)
a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter)
}
}

View File

@@ -1,167 +0,0 @@
package chunk
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"os"
)
const bufferTempPattern = "cdChunk.*.tmp"
// ChunkProcessFunc callback function for processing a chunk
type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error
// ChunkGroup manage groups of chunks
type ChunkGroup struct {
file fsctx.FileHeader
chunkSize uint64
backoff backoff.Backoff
enableRetryBuffer bool
fileInfo *fsctx.UploadTaskInfo
currentIndex int
chunkNum uint64
bufferTemp *os.File
}
func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff, useBuffer bool) *ChunkGroup {
c := &ChunkGroup{
file: file,
chunkSize: chunkSize,
backoff: backoff,
fileInfo: file.Info(),
currentIndex: -1,
enableRetryBuffer: useBuffer,
}
if c.chunkSize == 0 {
c.chunkSize = c.fileInfo.Size
}
if c.fileInfo.Size == 0 {
c.chunkNum = 1
} else {
c.chunkNum = c.fileInfo.Size / c.chunkSize
if c.fileInfo.Size%c.chunkSize != 0 {
c.chunkNum++
}
}
return c
}
// TempAvailable returns if current chunk temp file is available to be read
func (c *ChunkGroup) TempAvailable() bool {
if c.bufferTemp != nil {
state, _ := c.bufferTemp.Stat()
return state != nil && state.Size() == c.Length()
}
return false
}
// Process a chunk with retry logic
func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
reader := io.LimitReader(c.file, c.Length())
// If useBuffer is enabled, tee the reader to a temp file
if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() {
c.bufferTemp, _ = os.CreateTemp("", bufferTempPattern)
reader = io.TeeReader(reader, c.bufferTemp)
}
if c.bufferTemp != nil {
defer func() {
if c.bufferTemp != nil {
c.bufferTemp.Close()
os.Remove(c.bufferTemp.Name())
c.bufferTemp = nil
}
}()
// if temp buffer file is available, use it
if c.TempAvailable() {
if _, err := c.bufferTemp.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek temp file back to chunk start: %w", err)
}
util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
reader = io.NopCloser(c.bufferTemp)
}
}
err := processor(c, reader)
if err != nil {
if c.enableRetryBuffer {
request.BlackHole(reader)
}
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) {
if c.file.Seekable() {
if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil {
return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err)
}
}
util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err)
return c.Process(processor)
}
return err
}
util.Log().Debug("Chunk %d processed", c.currentIndex)
return nil
}
// Start returns the byte index of current chunk
func (c *ChunkGroup) Start() int64 {
return int64(uint64(c.Index()) * c.chunkSize)
}
// Total returns the total length
func (c *ChunkGroup) Total() int64 {
return int64(c.fileInfo.Size)
}
// Num returns the total chunk number
func (c *ChunkGroup) Num() int {
return int(c.chunkNum)
}
// RangeHeader returns header value of Content-Range
func (c *ChunkGroup) RangeHeader() string {
return fmt.Sprintf("bytes %d-%d/%d", c.Start(), c.Start()+c.Length()-1, c.Total())
}
// Index returns current chunk index, starts from 0
func (c *ChunkGroup) Index() int {
return c.currentIndex
}
// Next switch to next chunk, returns whether all chunks are processed
func (c *ChunkGroup) Next() bool {
c.currentIndex++
c.backoff.Reset()
return c.currentIndex < int(c.chunkNum)
}
// Length returns the length of current chunk
func (c *ChunkGroup) Length() int64 {
contentLength := c.chunkSize
if c.Index() == int(c.chunkNum-1) {
contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1)
}
return int64(contentLength)
}
// IsLast returns if current chunk is the last one
func (c *ChunkGroup) IsLast() bool {
return c.Index() == int(c.chunkNum-1)
}

View File

@@ -1,250 +0,0 @@
package chunk
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/stretchr/testify/assert"
"io"
"os"
"strings"
"testing"
)
func TestNewChunkGroup(t *testing.T) {
a := assert.New(t)
testCases := []struct {
fileSize uint64
chunkSize uint64
expectedInnerChunkSize uint64
expectedChunkNum uint64
expectedInfo [][2]int //Start, Index,Length
}{
{10, 0, 10, 1, [][2]int{{0, 10}}},
{0, 0, 0, 1, [][2]int{{0, 0}}},
{0, 10, 10, 1, [][2]int{{0, 0}}},
{50, 10, 10, 5, [][2]int{
{0, 10},
{10, 10},
{20, 10},
{30, 10},
{40, 10},
}},
{50, 50, 50, 1, [][2]int{
{0, 50},
}},
{50, 15, 15, 4, [][2]int{
{0, 15},
{15, 15},
{30, 15},
{45, 5},
}},
}
for index, testCase := range testCases {
file := &fsctx.FileStream{Size: testCase.fileSize}
chunkGroup := NewChunkGroup(file, testCase.chunkSize, &backoff.ConstantBackoff{}, true)
a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(),
"TestCase:%d,ChunkNum()", index)
a.EqualValues(testCase.expectedInnerChunkSize, chunkGroup.chunkSize,
"TestCase:%d,InnerChunkSize()", index)
a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(),
"TestCase:%d,len(Chunks)", index)
a.EqualValues(testCase.fileSize, chunkGroup.Total())
for cIndex, info := range testCase.expectedInfo {
a.True(chunkGroup.Next())
a.EqualValues(info[1], chunkGroup.Length(),
"TestCase:%d,Chunks[%d].Length()", index, cIndex)
a.EqualValues(info[0], chunkGroup.Start(),
"TestCase:%d,Chunks[%d].Start()", index, cIndex)
a.Equal(cIndex == len(testCase.expectedInfo)-1, chunkGroup.IsLast(),
"TestCase:%d,Chunks[%d].IsLast()", index, cIndex)
a.NotEmpty(chunkGroup.RangeHeader())
}
a.False(chunkGroup.Next())
}
}
func TestChunkGroup_TempAvailablet(t *testing.T) {
a := assert.New(t)
file := &fsctx.FileStream{Size: 1}
c := NewChunkGroup(file, 0, &backoff.ConstantBackoff{}, true)
a.False(c.TempAvailable())
f, err := os.CreateTemp("", "TestChunkGroup_TempAvailablet.*")
defer func() {
f.Close()
os.Remove(f.Name())
}()
a.NoError(err)
c.bufferTemp = f
a.False(c.TempAvailable())
f.Write([]byte("1"))
a.True(c.TempAvailable())
}
func TestChunkGroup_Process(t *testing.T) {
a := assert.New(t)
file := &fsctx.FileStream{Size: 10}
// success
{
file.File = io.NopCloser(strings.NewReader("1234567890"))
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{}, true)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("67890", string(res))
return nil
}))
a.False(c.Next())
a.Equal(2, count)
}
// retry, read from buffer file
{
file.File = io.NopCloser(strings.NewReader("1234567890"))
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, true)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("67890", string(res))
if count == 2 {
return errors.New("error")
}
return nil
}))
a.False(c.Next())
a.Equal(3, count)
}
// retry, read from seeker
{
f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
f.Write([]byte("1234567890"))
f.Seek(0, 0)
defer func() {
f.Close()
os.Remove(f.Name())
}()
file.File = f
file.Seeker = f
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("67890", string(res))
if count == 2 {
return errors.New("error")
}
return nil
}))
a.False(c.Next())
a.Equal(3, count)
}
// retry, seek error
{
f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
f.Write([]byte("1234567890"))
f.Seek(0, 0)
defer func() {
f.Close()
os.Remove(f.Name())
}()
file.File = f
file.Seeker = f
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
f.Close()
a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
if count == 2 {
return errors.New("error")
}
return nil
}))
a.False(c.Next())
a.Equal(2, count)
}
// retry, finally error
{
f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
f.Write([]byte("1234567890"))
f.Seek(0, 0)
defer func() {
f.Close()
os.Remove(f.Name())
}()
file.File = f
file.Seeker = f
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
return errors.New("error")
}))
a.False(c.Next())
a.Equal(4, count)
}
}

View File

@@ -183,11 +183,9 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
opt := &cossdk.ObjectPutOptions{}
_, err := handler.Client.Object.Put(ctx, file.Info().SavePath, file, opt)
_, err := handler.Client.Object.Put(ctx, dst, file, opt)
return err
}
@@ -218,7 +216,7 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
return failed, nil
}
return failed, errors.New("delete failed")
return failed, errors.New("删除失败")
}
// Thumb 获取文件缩略图
@@ -326,16 +324,21 @@ func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64,
}
// Token 获取上传策略和认证Token
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
func (handler Driver) Token(ctx context.Context, TTL int64, key string) (serializer.UploadCredential, error) {
// 读取上下文中生成的存储路径
savePath, ok := ctx.Value(fsctx.SavePathCtx).(string)
if !ok {
return serializer.UploadCredential{}, errors.New("无法获取存储路径")
}
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/cos/" + uploadSession.Key)
apiBaseURI, _ := url.Parse("/api/v3/callback/cos/" + key)
apiURL := siteURL.ResolveReference(apiBaseURI).String()
// 上传策略
savePath := file.Info().SavePath
startTime := time.Now()
endTime := startTime.Add(time.Duration(ttl) * time.Second)
endTime := startTime.Add(time.Duration(TTL) * time.Second)
keyTime := fmt.Sprintf("%d;%d", startTime.Unix(), endTime.Unix())
postPolicy := UploadPolicy{
Expiration: endTime.UTC().Format(time.RFC3339),
@@ -343,7 +346,7 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria
map[string]string{"bucket": handler.Policy.BucketName},
map[string]string{"$key": savePath},
map[string]string{"x-cos-meta-callback": apiURL},
map[string]string{"x-cos-meta-key": uploadSession.Key},
map[string]string{"x-cos-meta-key": key},
map[string]string{"q-sign-algorithm": "sha1"},
map[string]string{"q-ak": handler.Policy.AccessKey},
map[string]string{"q-sign-time": keyTime},
@@ -355,22 +358,16 @@ func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *seria
[]interface{}{"content-length-range", 0, handler.Policy.MaxSize})
}
res, err := handler.getUploadCredential(ctx, postPolicy, keyTime, savePath)
res, err := handler.getUploadCredential(ctx, postPolicy, keyTime)
if err == nil {
res.SessionID = uploadSession.Key
res.Callback = apiURL
res.UploadURLs = []string{handler.Policy.Server}
res.Key = key
}
return res, err
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}
// Meta 获取文件信息
func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
res, err := handler.Client.Object.Head(ctx, path, &cossdk.ObjectHeadOptions{})
@@ -384,11 +381,17 @@ func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error)
}, nil
}
func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, keyTime string, savePath string) (*serializer.UploadCredential, error) {
func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, keyTime string) (serializer.UploadCredential, error) {
// 读取上下文中生成的存储路径
savePath, ok := ctx.Value(fsctx.SavePathCtx).(string)
if !ok {
return serializer.UploadCredential{}, errors.New("无法获取存储路径")
}
// 编码上传策略
policyJSON, err := json.Marshal(policy)
if err != nil {
return nil, err
return serializer.UploadCredential{}, err
}
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
@@ -396,14 +399,14 @@ func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPoli
hmacSign := hmac.New(sha1.New, []byte(handler.Policy.SecretKey))
_, err = io.WriteString(hmacSign, keyTime)
if err != nil {
return nil, err
return serializer.UploadCredential{}, err
}
signKey := fmt.Sprintf("%x", hmacSign.Sum(nil))
sha1Sign := sha1.New()
_, err = sha1Sign.Write(policyJSON)
if err != nil {
return nil, err
return serializer.UploadCredential{}, err
}
stringToSign := fmt.Sprintf("%x", sha1Sign.Sum(nil))
@@ -411,15 +414,15 @@ func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPoli
hmacFinalSign := hmac.New(sha1.New, []byte(signKey))
_, err = hmacFinalSign.Write([]byte(stringToSign))
if err != nil {
return nil, err
return serializer.UploadCredential{}, err
}
signature := hmacFinalSign.Sum(nil)
return &serializer.UploadCredential{
Policy: policyEncoded,
Path: savePath,
AccessKey: handler.Policy.AccessKey,
Credential: fmt.Sprintf("%x", signature),
KeyTime: keyTime,
return serializer.UploadCredential{
Policy: policyEncoded,
Path: savePath,
AccessKey: handler.Policy.AccessKey,
Token: fmt.Sprintf("%x", signature),
KeyTime: keyTime,
}, nil
}

View File

@@ -2,9 +2,9 @@ package driver
import (
"context"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"io"
"net/url"
)
@@ -12,7 +12,7 @@ import (
type Handler interface {
// 上传文件, dst为文件存储路径size 为文件大小。上下文关闭
// 时,应取消上传并清理临时文件
Put(ctx context.Context, file fsctx.FileHeader) error
Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error
// 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误
Delete(ctx context.Context, files []string) ([]string, error)
@@ -29,11 +29,8 @@ type Handler interface {
// isDownload - 是否直接下载
Source(ctx context.Context, path string, url url.URL, ttl int64, isDownload bool, speed int) (string, error)
// Token 获取有效期为ttl的上传凭证和签名
Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error)
// CancelToken 取消已经创建的有状态上传凭证
CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error
// Token 获取有效期为ttl的上传凭证和签名同时回调会话有效期为sessionTTL
Token(ctx context.Context, ttl int64, callbackKey string) (serializer.UploadCredential, error)
// List 递归列取远程端path路径下文件、目录不包含path本身
// 返回的对象路径以path作为起始根目录.

View File

@@ -0,0 +1,38 @@
package local
import (
"io"
)
// FileStream 用户传来的文件
type FileStream struct {
File io.ReadCloser
Size uint64
VirtualPath string
Name string
MIMEType string
}
func (file FileStream) Read(p []byte) (n int, err error) {
return file.File.Read(p)
}
func (file FileStream) GetMIMEType() string {
return file.MIMEType
}
func (file FileStream) GetSize() uint64 {
return file.Size
}
func (file FileStream) Close() error {
return file.File.Close()
}
func (file FileStream) GetFileName() string {
return file.Name
}
func (file FileStream) GetVirtualPath() string {
return file.VirtualPath
}

View File

@@ -0,0 +1,48 @@
package local
import (
"github.com/stretchr/testify/assert"
"io/ioutil"
"strings"
"testing"
)
func TestFileStream_GetFileName(t *testing.T) {
asserts := assert.New(t)
file := FileStream{Name: "123"}
asserts.Equal("123", file.GetFileName())
}
func TestFileStream_GetMIMEType(t *testing.T) {
asserts := assert.New(t)
file := FileStream{MIMEType: "123"}
asserts.Equal("123", file.GetMIMEType())
}
func TestFileStream_GetSize(t *testing.T) {
asserts := assert.New(t)
file := FileStream{Size: 123}
asserts.Equal(uint64(123), file.GetSize())
}
func TestFileStream_Read(t *testing.T) {
asserts := assert.New(t)
file := FileStream{
File: ioutil.NopCloser(strings.NewReader("123")),
}
var p = make([]byte, 3)
{
n, err := file.Read(p)
asserts.Equal(3, n)
asserts.NoError(err)
}
}
func TestFileStream_Close(t *testing.T) {
asserts := assert.New(t)
file := FileStream{
File: ioutil.NopCloser(strings.NewReader("123")),
}
err := file.Close()
asserts.NoError(err)
}

Some files were not shown because too many files have changed in this diff Show More