Files
siyuan/kernel/sql/stmt_validate.go

196 lines
4.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SiYuan - Refactor your thinking
// Copyright (c) 2020-present, b3log.org
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sql
import (
"context"
"errors"
"fmt"
"strings"
"unicode"
"github.com/mattn/go-sqlite3"
)
// tailIsOnlyWhitespaceOrSQLComments reports whether s, the text following a
// semicolon, consists only of whitespace, line comments ("--" through the next
// newline or end of input) and block comments ("/* ... */"; an unterminated
// block comment runs to the end of the input). Such a tail holds no further
// SQL statement, so the semicolon in front of it still ends a single statement.
//
// Whitespace is judged with unicode.IsSpace.
// NOTE(review): unicode.IsSpace accepts more characters (e.g. \v, U+00A0,
// U+3000) than SQLite's tokenizer treats as whitespace — confirm this leniency
// is intended.
func tailIsOnlyWhitespaceOrSQLComments(s string) bool {
	rs := []rune(s)
	n := len(rs)
	// at returns the rune at index k, or 0 when k is past the end of the input.
	at := func(k int) rune {
		if k < n {
			return rs[k]
		}
		return 0
	}

	pos := 0
	for pos < n {
		c := rs[pos]
		switch {
		case unicode.IsSpace(c):
			pos++
		case c == '-' && at(pos+1) == '-':
			// Line comment: skip up to (not past) the newline; the newline is
			// consumed as whitespace on the next iteration.
			pos += 2
			for pos < n && rs[pos] != '\n' {
				pos++
			}
		case c == '/' && at(pos+1) == '*':
			// Block comment: scanning starts after "/*" so that "/*/" is not
			// mistaken for an opener immediately followed by a closer.
			pos += 2
			for pos < n && !(rs[pos] == '*' && at(pos+1) == '/') {
				pos++
			}
			// Step over "*/". If no closer was found, this moves past the end
			// and the outer loop terminates (the comment swallowed the rest).
			pos += 2
		default:
			return false
		}
	}
	return true
}
// containsMultipleStatements reports whether stmt holds more than one SQL
// statement. It looks for a semicolon outside of string literals, quoted
// identifiers and comments that is followed by anything other than whitespace
// or comments.
//
// Trailing semicolons (and the whitespace around them) are stripped first, so
// "SELECT 1;" and "SELECT 1; ;" both count as a single statement. The scanner
// recognises '...' literals, "..." / `...` / [...] identifiers, "--" line
// comments and "/* */" block comments. A doubled quote inside a literal
// (e.g. 'it''s') needs no special handling: it closes and immediately reopens
// the literal with nothing in between.
//
// NOTE(review): this is a purely lexical check. A statement that legitimately
// contains top-level semicolons (e.g. CREATE TRIGGER ... BEGIN ...; END) is
// reported as multiple statements.
func containsMultipleStatements(stmt string) bool {
	body := strings.TrimSpace(stmt)
	for strings.HasSuffix(body, ";") {
		body = strings.TrimSpace(strings.TrimRight(body, ";"))
	}

	// Lexical states of the scanner; only one can be active at a time.
	const (
		scanCode = iota
		scanSingleQuote
		scanDoubleQuote
		scanBacktick
		scanBracket
		scanLineComment
		scanBlockComment
	)
	state := scanCode

	src := []rune(body)
	for pos := 0; pos < len(src); pos++ {
		cur := src[pos]
		var peek rune
		if pos+1 < len(src) {
			peek = src[pos+1]
		}

		switch state {
		case scanLineComment:
			if cur == '\n' {
				state = scanCode
			}
		case scanBlockComment:
			if cur == '*' && peek == '/' {
				state = scanCode
				pos++ // consume the closing '/'
			}
		case scanSingleQuote:
			if cur == '\'' {
				state = scanCode
			}
		case scanDoubleQuote:
			if cur == '"' {
				state = scanCode
			}
		case scanBacktick:
			if cur == '`' {
				state = scanCode
			}
		case scanBracket:
			if cur == ']' {
				state = scanCode
			}
		default: // scanCode
			switch {
			case cur == '\'':
				state = scanSingleQuote
			case cur == '"':
				state = scanDoubleQuote
			case cur == '`':
				state = scanBacktick
			case cur == '[':
				state = scanBracket
			case cur == '-' && peek == '-':
				state = scanLineComment
				pos++ // consume the second '-'
			case cur == '/' && peek == '*':
				state = scanBlockComment
				pos++ // consume the '*'
			case cur == ';':
				// A semicolon followed only by whitespace and comments still
				// ends the same statement, so it is not a separator here.
				if !tailIsOnlyWhitespaceOrSQLComments(string(src[pos+1:])) {
					return true
				}
			}
		}
	}
	return false
}
// CheckSingleStatement returns an error when stmt contains more than one SQL
// statement and nil otherwise. The check is lexical. Semicolons inside string
// literals, quoted identifiers and comments are ignored. Trailing semicolons
// are also ignored, as is a semicolon followed only by whitespace or comments.
func CheckSingleStatement(stmt string) error {
	if !containsMultipleStatements(stmt) {
		return nil
	}
	return errors.New("SQL statement is not single")
}
// CheckReadonlyStatement prepares stmt (without executing it) on a connection
// taken from db. It returns an error unless sqlite3_stmt_readonly reports the
// prepared statement as read-only.
// See https://sqlite.org/c3ref/stmt_readonly.html
//
// It returns an error if stmt is empty or whitespace-only. It also propagates
// errors from acquiring the connection or from preparing the statement.
//
// Only the first statement in stmt is prepared and inspected; any statements
// after it are neither examined nor rejected. Combine this check with
// CheckSingleStatement to enforce a "single, read-only statement" policy.
//
// NOTE(review): per the SQLite documentation, sqlite3_stmt_readonly also
// returns true for transaction-control statements (BEGIN, COMMIT, ...) and for
// ATTACH/DETACH. Passing this check alone therefore does not mean the
// statement is free of side effects — confirm whether callers must reject those.
func CheckReadonlyStatement(stmt string) error {
	if len(strings.TrimSpace(stmt)) == 0 {
		return errors.New("SQL statement is empty")
	}

	conn, err := db.Conn(context.Background())
	if err != nil {
		return err
	}
	defer conn.Close()

	// inspect runs against the raw driver connection so the go-sqlite3
	// statement handle (and its Readonly method) is reachable.
	inspect := func(driverConn any) error {
		sc, isSQLite := driverConn.(*sqlite3.SQLiteConn)
		if !isSQLite {
			return fmt.Errorf("SQL driver connection type is unexpected: %T", driverConn)
		}

		prepared, prepErr := sc.Prepare(stmt)
		if prepErr != nil {
			return prepErr
		}
		defer prepared.Close()

		s, isStmt := prepared.(*sqlite3.SQLiteStmt)
		switch {
		case !isStmt:
			return fmt.Errorf("SQL driver statement type is unexpected: %T", prepared)
		case !s.Readonly():
			return errors.New("SQL statement is not read-only")
		default:
			return nil
		}
	}
	return conn.Raw(inspect)
}