Files
siyuan/kernel/api/network.go
mm-o dda0fb9136 🎨 Add Generic Streamed Proxy Responses for Kernel Plugins (#17748)
* Add streaming proxy response for kernel plugins

* Harden kernel plugin streaming proxy
2026-05-29 09:43:21 +08:00

608 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// SiYuan - Refactor your thinking
// Copyright (c) 2020-present, b3log.org
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package api
import (
"encoding/base32"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/textproto"
"net/url"
"strings"
"time"
"github.com/88250/gulu"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/imroc/req/v3"
"github.com/siyuan-note/logging"
"github.com/siyuan-note/siyuan/kernel/util"
)
type File struct {
Filename string
Header textproto.MIMEHeader
Size int64
Content string
}
type MultipartForm struct {
Value map[string][]string
File map[string][]File
}
func echo(c *gin.Context) {
ret := gulu.Ret.NewResult()
defer c.JSON(http.StatusOK, ret)
var (
password string
multipartForm *MultipartForm
rawData any
)
if form, err := c.MultipartForm(); err != nil || nil == form {
multipartForm = nil
} else {
multipartForm = &MultipartForm{
Value: form.Value,
File: map[string][]File{},
}
for k, handlers := range form.File {
files := make([]File, len(handlers))
multipartForm.File[k] = files
for i, handler := range handlers {
files[i].Filename = handler.Filename
files[i].Header = handler.Header
files[i].Size = handler.Size
if file, err := handler.Open(); err != nil {
logging.LogWarnf("echo open form [%s] file [%s] error: %s", k, handler.Filename, err.Error())
} else {
content := make([]byte, handler.Size)
if n, err := file.Read(content); err != nil {
logging.LogWarnf("echo read form [%s] file [%s] error: %s", k, handler.Filename, err.Error())
} else {
files[i].Content = base64.StdEncoding.EncodeToString(content[:n])
}
}
}
}
}
if data, err := c.GetRawData(); err == nil {
rawData = base64.StdEncoding.EncodeToString(data)
} else {
logging.LogWarnf("echo get raw data error: %s", err.Error())
rawData = nil
}
username, password, ok := c.Request.BasicAuth()
ret.Data = map[string]any{
"Context": map[string]any{
"Params": c.Params,
"HandlerNames": c.HandlerNames(),
"FullPath": c.FullPath(),
"ClientIP": c.ClientIP(),
"RemoteIP": c.RemoteIP(),
"ContentType": c.ContentType(),
"IsWebsocket": c.IsWebsocket(),
"RawData": rawData,
},
"Request": map[string]any{
"Method": c.Request.Method,
"URL": c.Request.URL,
"Proto": c.Request.Proto,
"ProtoMajor": c.Request.ProtoMajor,
"ProtoMinor": c.Request.ProtoMinor,
"Header": c.Request.Header,
"ContentLength": c.Request.ContentLength,
"TransferEncoding": c.Request.TransferEncoding,
"Close": c.Request.Close,
"Host": c.Request.Host,
"Form": c.Request.Form,
"PostForm": c.Request.PostForm,
"MultipartForm": multipartForm,
"Trailer": c.Request.Trailer,
"RemoteAddr": c.Request.RemoteAddr,
"TLS": c.Request.TLS,
"UserAgent": c.Request.UserAgent(),
"Cookies": c.Request.Cookies(),
"Referer": c.Request.Referer(),
},
"URL": map[string]any{
"EscapedPath": c.Request.URL.EscapedPath(),
"EscapedFragment": c.Request.URL.EscapedFragment(),
"String": c.Request.URL.String(),
"Redacted": c.Request.URL.Redacted(),
"IsAbs": c.Request.URL.IsAbs(),
"Query": c.Request.URL.Query(),
"RequestURI": c.Request.URL.RequestURI(),
"Hostname": c.Request.URL.Hostname(),
"Port": c.Request.URL.Port(),
},
"User": map[string]any{
"Exists": ok,
"Username": username,
"Password": password,
},
}
}
func forwardProxy(c *gin.Context) {
ret := gulu.Ret.NewResult()
defer c.JSON(http.StatusOK, ret)
arg, ok := util.JsonArg(c, ret)
if !ok {
return
}
var destURL string
if !util.ParseJsonArgs(arg, ret, util.BindJsonArg("url", &destURL, true, true)) {
return
}
u, e := url.ParseRequestURI(destURL)
if nil != e {
ret.Code = 1
ret.Msg = "invalid [url]"
return
}
if u.Scheme != "http" && u.Scheme != "https" {
ret.Code = 2
ret.Msg = "only http/https is allowed"
return
}
method := "POST"
if methodArg := arg["method"]; nil != methodArg {
method = strings.ToUpper(methodArg.(string))
}
timeout := 7000
if timeoutArg := arg["timeout"]; nil != timeoutArg {
timeout = int(timeoutArg.(float64))
if 1 > timeout {
timeout = 7000
}
}
client := getSafeClient(time.Duration(timeout) * time.Millisecond)
if redirectArg, ok := arg["redirect"].(bool); ok && !redirectArg {
client.SetRedirectPolicy(req.NoRedirectPolicy())
}
request := client.R()
if headers, ok := arg["headers"].([]any); ok {
for _, pair := range headers {
if m, ok := pair.(map[string]any); ok {
for k, v := range m {
request.SetHeader(k, fmt.Sprintf("%v", v))
}
}
}
}
contentType := "application/json"
if contentTypeArg := arg["contentType"]; nil != contentTypeArg {
contentType = contentTypeArg.(string)
}
request.SetHeader("Content-Type", contentType)
payloadEncoding := "json"
if payloadEncodingArg := arg["payloadEncoding"]; nil != payloadEncodingArg {
payloadEncoding = payloadEncodingArg.(string)
}
switch payloadEncoding {
case "base64":
fallthrough
case "base64-std":
if payload, err := base64.StdEncoding.DecodeString(arg["payload"].(string)); err != nil {
ret.Code = 3
ret.Msg = "decode base64-std payload failed: " + err.Error()
return
} else {
request.SetBody(payload)
}
case "base64-url":
if payload, err := base64.URLEncoding.DecodeString(arg["payload"].(string)); err != nil {
ret.Code = 4
ret.Msg = "decode base64-url payload failed: " + err.Error()
return
} else {
request.SetBody(payload)
}
case "base32":
fallthrough
case "base32-std":
if payload, err := base32.StdEncoding.DecodeString(arg["payload"].(string)); err != nil {
ret.Code = 5
ret.Msg = "decode base32-std payload failed: " + err.Error()
return
} else {
request.SetBody(payload)
}
case "base32-hex":
if payload, err := base32.HexEncoding.DecodeString(arg["payload"].(string)); err != nil {
ret.Code = 6
ret.Msg = "decode base32-hex payload failed: " + err.Error()
return
} else {
request.SetBody(payload)
}
case "hex":
if payload, err := hex.DecodeString(arg["payload"].(string)); err != nil {
ret.Code = 7
ret.Msg = "decode hex payload failed: " + err.Error()
return
} else {
request.SetBody(payload)
}
case "text":
default:
request.SetBody(arg["payload"])
}
started := time.Now()
resp, err := request.Send(method, destURL)
if err != nil {
ret.Code = 8
ret.Msg = "forward request failed: " + err.Error()
return
}
bodyData, err := io.ReadAll(resp.Body)
if err != nil {
ret.Code = 9
ret.Msg = "read response body failed: " + err.Error()
return
}
elapsed := time.Since(started)
responseEncoding := "text"
if responseEncodingArg := arg["responseEncoding"]; nil != responseEncodingArg {
responseEncoding = responseEncodingArg.(string)
}
body := ""
switch responseEncoding {
case "base64":
fallthrough
case "base64-std":
body = base64.StdEncoding.EncodeToString(bodyData)
case "base64-url":
body = base64.URLEncoding.EncodeToString(bodyData)
case "base32":
fallthrough
case "base32-std":
body = base32.StdEncoding.EncodeToString(bodyData)
case "base32-hex":
body = base32.HexEncoding.EncodeToString(bodyData)
case "hex":
body = hex.EncodeToString(bodyData)
case "text":
fallthrough
default:
responseEncoding = "text"
body = string(bodyData)
}
data := map[string]any{
"url": destURL,
"status": resp.StatusCode,
"contentType": resp.GetHeader("content-type"),
"body": body,
"bodyEncoding": responseEncoding,
"headers": resp.Header,
"elapsed": elapsed.Milliseconds(),
}
ret.Data = data
//shortBody := ""
//if 64 > len(body) {
// shortBody = body
//} else {
// shortBody = body[:64]
//}
//
//logging.LogInfof("elapsed [%.1fs], length [%d], request [url=%s, headers=%s, content-type=%s, body=%s], status [%d], body [%s]",
// elapsed.Seconds(), len(bodyData), data["url"], headers, contentType, arg["payload"], data["status"], shortBody)
}
// 创建安全的 HTTP Client防止 SSRF 和 DNS 重绑定
func getSafeClient(timeout time.Duration) *req.Client {
dialer := util.SSRFSafeDialer(timeout)
client := req.C()
client.SetTimeout(timeout)
client.SetDial(dialer.DialContext)
client.SetRedirectPolicy(req.MaxRedirectPolicy(3))
return client
}
// parseForwardProxyParams decodes the `u` and `h` query parameters.
//
// Query params:
// - `u`: RawURLEncoding base64 of the target URL string.
// - `h`: RawURLEncoding base64 of a JSON object map[string][]string.
func parseForwardProxyParams(c *gin.Context) (parsedURL *url.URL, headers *http.Header, err error) {
uParam := c.Query("u")
if uParam == "" {
err = fmt.Errorf("missing query param [u]")
return
}
uBytes, decErr := base64.RawURLEncoding.DecodeString(uParam)
if decErr != nil {
err = fmt.Errorf("decode [u] failed: %s", decErr.Error())
return
}
parsedURL, err = url.ParseRequestURI(string(uBytes))
if err != nil {
err = fmt.Errorf("parse [u] failed: %s", err.Error())
return
}
h := http.Header{}
headers = &h
hParam := c.Query("h")
if hParam == "" {
return
}
hBytes, decErr := base64.RawURLEncoding.DecodeString(hParam)
if decErr != nil {
err = fmt.Errorf("decode [h] failed: %s", decErr.Error())
return
}
var record map[string][]string
if jsonErr := json.Unmarshal(hBytes, &record); jsonErr != nil {
err = fmt.Errorf("parse [h] failed: %s", jsonErr.Error())
return
}
for k, vs := range record {
for _, v := range vs {
h.Add(k, v)
}
}
return
}
// forwardResponseHeaders copies src headers into dst with a "Siyuan-Proxy-" prefix on each key.
func forwardResponseHeaders(dst http.Header, src http.Header) {
for k, vs := range src {
for _, v := range vs {
dst.Add("Siyuan-Proxy-"+k, v)
}
}
}
// httpProxy proxies an HTTP request to a remote HTTP endpoint.
//
// Query params:
// - u: RawURLEncoding base64 of the target http/https URL
// - h: RawURLEncoding base64 of JSON map[string][]string forwarded as request headers
//
// The request method and body are taken from the incoming request.
// Target response headers are forwarded with a "Siyuan-Proxy-" prefix.
func httpProxy(c *gin.Context) {
targetURL, targetHeaders, err := parseForwardProxyParams(c)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": err.Error()})
return
}
if targetURL.Scheme != "http" && targetURL.Scheme != "https" {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": "only http/https is allowed"})
return
}
transport := &http.Transport{
DialContext: util.SSRFSafeDialer(30 * time.Second).DialContext,
}
httpClient := &http.Client{Transport: transport}
proxyReq, reqErr := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), c.Request.Body)
if reqErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": "create request failed: " + reqErr.Error()})
return
}
proxyReq.ContentLength = c.Request.ContentLength
if c.ContentType() != "" {
proxyReq.Header.Set("Content-Type", c.ContentType())
}
for k, vs := range *targetHeaders {
for _, v := range vs {
proxyReq.Header.Add(k, v)
}
}
resp, respErr := httpClient.Do(proxyReq)
if respErr != nil {
c.JSON(http.StatusBadGateway, gin.H{"code": -1, "msg": "connect target failed: " + respErr.Error()})
return
}
defer resp.Body.Close()
forwardResponseHeaders(c.Writer.Header(), resp.Header)
c.Writer.WriteHeader(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
logging.LogWarnf("http proxy copy response failed: %s", err.Error())
}
}
// wsProxy proxies a WebSocket connection to a remote WebSocket endpoint.
//
// Query params:
// - u: RawURLEncoding base64 of the target ws/wss URL
// - h: RawURLEncoding base64 of JSON map[string][]string forwarded as handshake headers
//
// Target response headers are forwarded with a "Siyuan-Proxy-" prefix.
func wsProxy(c *gin.Context) {
targetURL, targetHeaders, err := parseForwardProxyParams(c)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": err.Error()})
return
}
if targetURL.Scheme != "ws" && targetURL.Scheme != "wss" {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": "only ws/wss is allowed"})
return
}
wsDialer := &websocket.Dialer{
NetDialContext: util.SSRFSafeDialer(30 * time.Second).DialContext,
HandshakeTimeout: 30 * time.Second,
}
targetConn, targetResp, dialErr := wsDialer.DialContext(c.Request.Context(), targetURL.String(), *targetHeaders)
if dialErr != nil {
c.JSON(http.StatusBadGateway, gin.H{"code": -1, "msg": "dial target failed: " + dialErr.Error()})
return
}
defer targetConn.Close()
upgradeHeaders := http.Header{}
if targetResp != nil {
forwardResponseHeaders(upgradeHeaders, targetResp.Header)
}
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
clientConn, upgradeErr := upgrader.Upgrade(c.Writer, c.Request, upgradeHeaders)
if upgradeErr != nil {
logging.LogErrorf("ws forward proxy upgrade failed: %s", upgradeErr.Error())
return
}
defer clientConn.Close()
errChan := make(chan error, 2)
go func() {
for {
msgType, msg, readErr := targetConn.ReadMessage()
if readErr != nil {
if closeError, ok := readErr.(*websocket.CloseError); ok {
clientConn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(
closeError.Code,
closeError.Text,
),
)
}
errChan <- readErr
return
}
if writeErr := clientConn.WriteMessage(msgType, msg); writeErr != nil {
errChan <- writeErr
return
}
}
}()
go func() {
for {
msgType, msg, readErr := clientConn.ReadMessage()
if readErr != nil {
if closeError, ok := readErr.(*websocket.CloseError); ok {
targetConn.WriteMessage(
websocket.CloseMessage,
websocket.FormatCloseMessage(
closeError.Code,
closeError.Text,
),
)
}
errChan <- readErr
return
}
if writeErr := targetConn.WriteMessage(msgType, msg); writeErr != nil {
errChan <- writeErr
return
}
}
}()
<-errChan
}
// esProxy proxies an EventSource (SSE) stream from a remote HTTP endpoint.
//
// Query params:
// - u: RawURLEncoding base64 of the target http/https URL
// - h: RawURLEncoding base64 of JSON map[string][]string forwarded as request headers
//
// Target response headers are forwarded with a "Siyuan-Proxy-" prefix.
func esProxy(c *gin.Context) {
targetURL, targetHeaders, err := parseForwardProxyParams(c)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": err.Error()})
return
}
if targetURL.Scheme != "http" && targetURL.Scheme != "https" {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": "only http/https is allowed"})
return
}
transport := &http.Transport{
DialContext: util.SSRFSafeDialer(30 * time.Second).DialContext,
}
httpClient := &http.Client{Transport: transport}
proxyReq, reqErr := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, targetURL.String(), nil)
if reqErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"code": -1, "msg": "create request failed: " + reqErr.Error()})
return
}
for k, vs := range *targetHeaders {
for _, v := range vs {
proxyReq.Header.Add(k, v)
}
}
if proxyReq.Header.Get("Accept") == "" {
proxyReq.Header.Set("Accept", "text/event-stream")
}
resp, respErr := httpClient.Do(proxyReq)
if respErr != nil {
c.JSON(http.StatusBadGateway, gin.H{"code": -1, "msg": "connect target failed: " + respErr.Error()})
return
}
defer resp.Body.Close()
forwardResponseHeaders(c.Writer.Header(), resp.Header)
c.Writer.WriteHeader(resp.StatusCode)
buf := make([]byte, 4096)
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
if _, writeErr := c.Writer.Write(buf[:n]); writeErr != nil {
return
}
c.Writer.Flush()
}
if readErr != nil {
return
}
}
}