Files
siyuan/kernel/plugin/plugin.go
Yingyi / 颖逸 c59b8ec7f7 🎨 Harden RPC handling and improve error reporting in kernel plugin (#17655)
* perf(kernel-plugin): strengthen RPC, sandbox, and form parsing

Validate and harden plugin RPC and request handling: ensure RPC API call first argument is a string; treat missing method using HasValue(); return InvalidParams for malformed params; bail out early when kernel is incompatible or missing. Fix sandbox promise invocation to return after reporting errors to avoid continuing on nil/invalid values. Change RequestForm files to []*RequestFile, allocate pointer entries, properly open/read/close uploaded files, and clone request headers before modifying them. These changes prevent nil derefs, resource leaks, and improve error reporting.

* perf(kernel-plugin): Skip empty Content-Type; use safe type assertions

Avoid setting an empty Content-Type header in the proxy when gin.Context.ContentType() is empty. Replace unsafe type assertions with comma-ok checks when converting request and file body bytes to Data objects to prevent panics on unexpected types or nil pointers. Also comment out assignments of c.Request.Context().Err() in plugin request handlers to avoid overwriting other error state on context cancellation. Affected files: kernel/api/network.go, kernel/plugin/plugin.go, kernel/plugin/sandbox.go.
2026-05-09 19:38:54 +08:00

1235 lines
34 KiB
Go

// SiYuan - Refactor your thinking
// Copyright (c) 2020-present, b3log.org
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package plugin
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"sync/atomic"
"github.com/asaskevich/EventBus"
"github.com/dop251/goja"
"github.com/dop251/goja_nodejs/eventloop"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/lxzan/gws"
"github.com/samber/lo"
"github.com/siyuan-note/logging"
"github.com/siyuan-note/siyuan/kernel/model"
"github.com/siyuan-note/siyuan/kernel/util"
"github.com/smallnest/chanx"
)
// PluginState identifies the lifecycle phase of a kernel plugin.
// It is stored in KernelPlugin.state as an int64.
type PluginState int64
// RpcMethod is a JSON-RPC method registered by a plugin together with the JS callable that implements it.
type RpcMethod struct {
Name string // name the method is registered under
Descriptions []string // descriptions supplied when the method was bound
Method goja.Callable // JS function invoked when the method is called
}
// RpcMethodInfo is the JSON-serializable view of an RpcMethod (name and descriptions, without the callable).
type RpcMethodInfo struct {
Name string `json:"name"`
Descriptions []string `json:"descriptions"`
}
// R is a generic string-keyed map; in this file it is used as the payload of events published on the plugin event bus.
type R map[string]any
// Event bus topics used by KernelPlugin.bus.
const (
EventBusTopicPlugin = "plugin" // Topic to kernel plugin
EventBusTopicRuntime = "runtime" // Topic to javascript runtime
)
// Plugin lifecycle states, numbered from 0 in declaration order.
const (
PluginStateReady PluginState = iota // initial state set by NewKernelPlugin
PluginStateLoading // set at the beginning of start()
PluginStateLoaded // set after the onload hook has been invoked
PluginStateRunning // set after the onloaded hook has been invoked
PluginStateStopping // set by stop() before the onunload hook
PluginStateStopped // set at the end of stop()
PluginStateError // set by error()
)
// String returns the lowercase name of the state, or "unknown" for any value
// that is not one of the declared PluginState constants.
func (s PluginState) String() string {
	// Indexed by state value; the constants are contiguous starting at 0.
	names := [...]string{
		PluginStateReady:    "ready",
		PluginStateLoading:  "loading",
		PluginStateLoaded:   "loaded",
		PluginStateRunning:  "running",
		PluginStateStopping: "stopping",
		PluginStateStopped:  "stopped",
		PluginStateError:    "error",
	}
	if s < 0 || s >= PluginState(len(names)) {
		return "unknown"
	}
	return names[s]
}
// KernelPlugin represents a single kernel-side plugin instance.
// It embeds the plugin's *model.Petal and owns the goja runtime, the worker
// that serializes access to it, the event bus, the registered RPC methods and
// the tracked WebSocket connections.
type KernelPlugin struct {
*model.Petal
token string // JWT for this plugin
file string // kernel.js file path named in js runtime (e.g. "plugin-name/kernel.js")
worker Worker // Worker for serializing plugin js-call-go (e.g. logger) and go-call-js (e.g. RPC calls) tasks on a single goroutine
runtime *eventloop.EventLoop // goja event loop runtime for this plugin
state atomic.Int64 // PluginState
context context.Context // Context for managing plugin lifecycle and cancellation
cancel context.CancelFunc // Cancel function for managing plugin lifecycle and cancellation
bus EventBus.Bus // Event bus for plugin events and RPC request/response dispatch
rpcMethods sync.Map // string -> *RpcMethod, registered JSON-RPC methods
socketsMu sync.RWMutex // guards the sockets map
sockets map[*gws.Conn]bool // tracked gws WebSocket connections (true: RPC server, false: regular)
}
// NewKernelPlugin builds a KernelPlugin for petal in the PluginStateReady state.
// A JWT is requested for the plugin; if that fails the error is only logged and
// the plugin is still returned, carrying whatever token value came back.
func NewKernelPlugin(petal *model.Petal) *KernelPlugin {
	jwt, jwtErr := model.CreatePluginJWT(petal.Name)
	if jwtErr != nil {
		logging.LogErrorf("Failed to create plugin JWT for [%s]: %v", petal.Name, jwtErr)
	}

	kp := new(KernelPlugin)
	kp.Petal = petal
	kp.token = jwt
	kp.file = fmt.Sprintf("%s/kernel.js", petal.Name)
	kp.bus = EventBus.New()
	kp.sockets = make(map[*gws.Conn]bool)
	kp.state.Store(int64(PluginStateReady))
	return kp
}
// State atomically loads and returns the plugin's current lifecycle state,
// so it may be called from any goroutine.
func (p *KernelPlugin) State() PluginState {
	current := p.state.Load()
	return PluginState(current)
}
// InitRuntime creates the goja event loop for the plugin, starts the worker on
// it, installs the extension and siyuan modules, and evaluates kernel.js.
//
// The setup runs inside a blocking eventloop Run call; afterwards the loop is
// started in the background regardless of the outcome. The returned error is
// the first setup failure (wrapped with %w so callers can inspect it with
// errors.Is / errors.As) or a recovered panic from the setup callback.
func (p *KernelPlugin) InitRuntime() (err error) {
	p.runtime = eventloop.NewEventLoop(eventloop.EnableConsole(true))
	p.worker.Start(p.runtime)

	p.runtime.Run(func(rt *goja.Runtime) {
		defer func() {
			if r := recover(); r != nil {
				err = fmt.Errorf("goja panic during event loop run: %v", r)
			}
		}()

		// Use JSON struct tags for field name mapping, with fallback to original names if "json" tag is absent.
		rt.SetFieldNameMapper(goja.TagFieldNameMapper("json", true))

		if enableErr := EnableExtendModules(p, rt); enableErr != nil {
			err = fmt.Errorf("EnableExtendModules: %w", enableErr)
			return
		}
		if enableErr := EnableSiyuanModule(p, rt); enableErr != nil {
			err = fmt.Errorf("EnableSiyuanModule: %w", enableErr)
			return
		}
		if _, runErr := rt.RunScript(p.file, p.Kernel.JS); runErr != nil {
			err = fmt.Errorf("RunScript: %w", runErr)
			return
		}
	})

	p.runtime.Start()
	return err
}
// Eval runs code as a script in rt, using the plugin's kernel.js path as the
// script name, and returns the script's completion value or the error reported by goja.
func (p *KernelPlugin) Eval(rt *goja.Runtime, code string) (goja.Value, error) {
	value, err := rt.RunScript(p.file, code)
	return value, err
}
// close stops the plugin's event loop if one exists. The runtime pointer itself
// is left in place. A panic raised while stopping is recovered and returned as an error.
func (p *KernelPlugin) close() (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		err = fmt.Errorf("goja panic during close runtime: %v", r)
	}()

	if p.runtime == nil {
		return nil
	}
	// Stop stops the event loop and waits for it to finish.
	// (Terminate would instead interrupt the runtime and make executing code throw.)
	p.runtime.Stop()
	return nil
}
// error shuts the runtime down via close (logging, not returning, a close
// failure) and then marks the plugin as PluginStateError.
func (p *KernelPlugin) error() {
	closeErr := p.close()
	if closeErr != nil {
		logging.LogErrorf("[plugin:%s] failed to close runtime during error handling: %v", p.Name, closeErr)
	}
	p.state.Store(int64(PluginStateError))
}
// start brings the plugin from loading to running: it creates the lifecycle
// context and the plugin storage directory, initializes the goja runtime,
// subscribes the event handlers, runs the onload/onloaded/onrunning hooks and
// finally publishes a "start" event on the runtime topic.
//
// Every failure path (including a recovered panic) cancels the lifecycle
// context, closes the runtime and leaves the plugin in PluginStateError, so a
// failed start never stays in the loading state or leaks its context.
func (p *KernelPlugin) start() (err error) {
	// abort releases what a partially completed start has acquired.
	abort := func() {
		if p.cancel != nil {
			p.cancel()
		}
		p.error()
	}

	defer func() {
		if r := recover(); r != nil {
			abort()
			err = fmt.Errorf("goja panic during start: %v", r)
		}
	}()

	p.state.Store(int64(PluginStateLoading))
	p.context, p.cancel = context.WithCancel(context.Background())

	baseDir := filepath.Join(util.DataDir, "storage", "petal", p.Name)
	if mkdirErr := os.MkdirAll(baseDir, 0755); mkdirErr != nil {
		abort()
		return fmt.Errorf("create plugin dir [%s] failed: %w", baseDir, mkdirErr)
	}

	if runtimeErr := p.InitRuntime(); runtimeErr != nil {
		abort()
		return fmt.Errorf("start runtime: %w", runtimeErr)
	}

	if subscribeErr := p.subscribeEventHandlers(); subscribeErr != nil {
		abort()
		return fmt.Errorf("subscribe plugin events: %w", subscribeErr)
	}

	// Each hook only fires when the state matches, so the state is advanced between hooks.
	p.onLoad()
	p.state.Store(int64(PluginStateLoaded))
	p.onLoaded()
	p.state.Store(int64(PluginStateRunning))
	p.onRunning()

	p.bus.Publish(EventBusTopicRuntime, R{
		"id":   uuid.NewString(),
		"type": "start",
	})

	logging.LogDebugf("[plugin:%s] started", p.Name)
	return nil
}
// stop shuts a running plugin down: it publishes a "stop" event, runs the
// onunload hook, drops all registered RPC methods, cancels the lifecycle
// context, forgets every tracked socket, unsubscribes the event handlers and
// stops the goja runtime.
//
// It returns ok=false without doing anything when the plugin is not running.
// Failures while unsubscribing or closing the runtime are logged and do not
// prevent the plugin from reaching PluginStateStopped. A panic is recovered,
// the plugin is put into the error state and the panic is returned as err.
func (p *KernelPlugin) stop() (ok bool, err error) {
	defer func() {
		if r := recover(); r != nil {
			p.error()
			ok = false
			err = fmt.Errorf("panic during plugin stop: %v", r)
		}
	}()

	if p.State() != PluginStateRunning {
		return false, nil
	}

	p.bus.Publish(EventBusTopicRuntime, R{
		"id":   uuid.NewString(),
		"type": "stop",
	})

	p.state.Store(int64(PluginStateStopping))
	p.onUnload()

	p.rpcMethods.Clear()
	p.cancel()

	// Only the bookkeeping is dropped here; the connections themselves are not closed by this method.
	p.socketsMu.Lock()
	clear(p.sockets)
	p.socketsMu.Unlock()

	if unsubscribeErr := p.unsubscribeEventHandlers(); unsubscribeErr != nil {
		logging.LogWarnf("[plugin:%s] unsubscribe event handlers during stop: %v", p.Name, unsubscribeErr)
	}
	if closeErr := p.close(); closeErr != nil {
		logging.LogWarnf("[plugin:%s] close runtime during stop: %v", p.Name, closeErr)
	}

	p.state.Store(int64(PluginStateStopped))
	logging.LogDebugf("[plugin:%s] stopped", p.Name)
	return true, nil
}
// onLoad invokes the "onload" lifecycle hook, but only while the plugin is in the loading state.
func (p *KernelPlugin) onLoad() {
	if p.State() != PluginStateLoading {
		return
	}
	p.invokeHook("onload")
}
// onLoaded invokes the "onloaded" lifecycle hook, but only while the plugin is in the loaded state.
func (p *KernelPlugin) onLoaded() {
	if p.State() != PluginStateLoaded {
		return
	}
	p.invokeHook("onloaded")
}
// onRunning invokes the "onrunning" lifecycle hook, but only while the plugin is in the running state.
func (p *KernelPlugin) onRunning() {
	if p.State() != PluginStateRunning {
		return
	}
	p.invokeHook("onrunning")
}
// onUnload invokes the "onunload" lifecycle hook, but only while the plugin is in the stopping state.
func (p *KernelPlugin) onUnload() {
	if p.State() != PluginStateStopping {
		return
	}
	p.invokeHook("onunload")
}
// bindRpcMethod registers method under name, replacing any method previously
// stored under the same name. It always returns nil.
func (p *KernelPlugin) bindRpcMethod(name string, method goja.Callable, descriptions ...string) error {
	entry := &RpcMethod{Name: name, Method: method}
	entry.Descriptions = descriptions
	p.rpcMethods.Store(name, entry)
	return nil
}
// unbindRpcMethod removes the RPC method registered under name.
// Removing a name that is not registered is not an error; nil is always returned.
func (p *KernelPlugin) unbindRpcMethod(name string) error {
	p.rpcMethods.Delete(name)
	return nil
}
// runtimeEventHandler is the bus subscriber for the runtime topic: it queues a
// worker task that hands event to dispatchEvent. No completion callback is
// registered and the result of queuing the task is not inspected.
func (p *KernelPlugin) runtimeEventHandler(event any) {
	dispatch := func(rt *goja.Runtime) (any, error) {
		return dispatchEvent(p, rt, event)
	}
	p.worker.Run(dispatch, nil)
}
// pluginEventHandler is the bus subscriber for the plugin topic; it only logs
// the received event at debug level.
func (p *KernelPlugin) pluginEventHandler(event any) {
	pluginName := p.Name
	logging.LogDebugf("[plugin:%s] receive event: %#v", pluginName, event)
}
// subscribeEventHandlers subscribes the plugin's handlers to the runtime and
// plugin topics of its event bus.
//
// lo.Must0 panics when a subscription fails; the panic is recovered and
// returned as err. The recovered value is converted safely: if it is not an
// error (for example a plain string) it is formatted into one instead of
// triggering a second panic through a failed type assertion.
func (p *KernelPlugin) subscribeEventHandlers() (err error) {
	defer func() {
		if r := recover(); r != nil {
			if recovered, ok := r.(error); ok {
				err = recovered
			} else {
				err = fmt.Errorf("%v", r)
			}
		}
	}()

	lo.Must0(p.bus.Subscribe(EventBusTopicRuntime, p.runtimeEventHandler))
	lo.Must0(p.bus.Subscribe(EventBusTopicPlugin, p.pluginEventHandler))
	return nil
}
// unsubscribeEventHandlers removes the plugin's handlers from the runtime and
// plugin topics of its event bus.
//
// lo.Must0 panics when an unsubscription fails; the panic is recovered and
// returned as err. The recovered value is converted safely: if it is not an
// error (for example a plain string) it is formatted into one instead of
// triggering a second panic through a failed type assertion.
func (p *KernelPlugin) unsubscribeEventHandlers() (err error) {
	defer func() {
		if r := recover(); r != nil {
			if recovered, ok := r.(error); ok {
				err = recovered
			} else {
				err = fmt.Errorf("%v", r)
			}
		}
	}()

	lo.Must0(p.bus.Unsubscribe(EventBusTopicRuntime, p.runtimeEventHandler))
	lo.Must0(p.bus.Unsubscribe(EventBusTopicPlugin, p.pluginEventHandler))
	return nil
}
// GetRpcMethodsInfo walks the registered RPC methods and returns the name and
// descriptions of each. Entries that are not *RpcMethod are skipped; the result
// is nil when nothing is registered. Order follows sync.Map iteration and is not stable.
func (p *KernelPlugin) GetRpcMethodsInfo() (methods []*RpcMethodInfo) {
	collect := func(_ any, value any) bool {
		registered, isMethod := value.(*RpcMethod)
		if !isMethod {
			return true
		}
		info := &RpcMethodInfo{
			Name:         registered.Name,
			Descriptions: registered.Descriptions,
		}
		methods = append(methods, info)
		return true
	}
	p.rpcMethods.Range(collect)
	return methods
}
// BroadcastNotification marshals a JSON-RPC notification (a request without an
// ID) for method/params and writes it as a text frame to every tracked socket
// flagged as an RPC connection. It returns once every asynchronous write has
// reported back; write failures are logged as warnings.
func (p *KernelPlugin) BroadcastNotification(method string, params util.Optional[any]) {
	data, err := json.Marshal(JsonRpcRequest{
		JsonRpc: JsonRpcVersion,
		Method:  method,
		Params:  params,
	})
	if err != nil {
		logging.LogWarnf("[plugin:%s] broadcast marshal: %s", p.Name, err)
		return
	}

	// Snapshot the RPC connections so the lock is not held while writing.
	p.socketsMu.RLock()
	targets := make([]*gws.Conn, 0, len(p.sockets))
	for conn, isRpcConnection := range p.sockets {
		if !isRpcConnection {
			continue
		}
		targets = append(targets, conn)
	}
	p.socketsMu.RUnlock()

	var pending sync.WaitGroup
	for _, target := range targets {
		pending.Add(1)
		// WriteAsync is asynchronous, so every connection gets a private copy of the payload.
		payload := make([]byte, len(data))
		copy(payload, data)
		onWritten := func(writeErr error) {
			defer pending.Done()
			if writeErr != nil {
				logging.LogWarnf("[plugin:%s] RPC WebSocket notification write failed: %s", p.Name, writeErr)
			}
		}
		target.WriteAsync(gws.OpcodeText, payload, onWritten)
	}
	pending.Wait()
}
// dispatchRpcRequests dispatches a batch of JSON-RPC requests concurrently and
// returns the responses in request order.
//
// For each entry:
//   - a nil entry, or an entry without a parsed request and without an error, yields a nil response;
//   - an entry that already carries a parsing/validation error yields that error without dispatching;
//   - a notification is dispatched in the background and yields a nil response;
//   - any other request is dispatched in its own goroutine and its response is collected.
//
// The nil checks run before any method is called on the request, so an entry
// without a parsed request can never cause a nil dereference.
func (p *KernelPlugin) dispatchRpcRequests(requests []*JsonRpcProcessingRequest) []*JsonRpcProcessingResponse {
	responses := make([]*JsonRpcProcessingResponse, len(requests))
	var wg sync.WaitGroup

	for i, request := range requests {
		if request == nil {
			continue
		}

		// For requests that failed JSON parsing or validation, return the error immediately without dispatching.
		if request.Error != nil {
			responses[i] = &JsonRpcProcessingResponse{Error: request.Error}
			continue
		}

		// Nothing to dispatch.
		if request.Request == nil {
			continue
		}

		// For notifications, dispatch without waiting for a response.
		if request.Request.IsNotification() {
			go p.dispatchRpcRequest(request.Request)
			continue
		}

		// For normal requests, dispatch concurrently and collect responses.
		wg.Add(1)
		go func(index int, rpcRequest *JsonRpcRequest) {
			defer wg.Done()
			responses[index] = p.dispatchRpcRequest(rpcRequest)
		}(i, request.Request)
	}

	wg.Wait()
	return responses
}
// dispatchRpcRequest validates a single JSON-RPC request and forwards it to
// the matching registered method.
//
// Notifications never produce a response: an invalid one is dropped, a valid
// one is invoked in the background. For ordinary requests the returned value
// carries either the method result or an error object, both tagged with the request ID.
func (p *KernelPlugin) dispatchRpcRequest(request *JsonRpcRequest) *JsonRpcProcessingResponse {
	if validationError := request.Validate(); validationError != nil {
		if request.IsNotification() {
			// Notifications get no response, not even for invalid requests.
			return nil
		}
		invalid := &JsonRpcErrorResponse{
			JsonRpc: JsonRpcVersion,
			Error:   validationError,
			ID:      request.ID,
		}
		return &JsonRpcProcessingResponse{Error: invalid}
	}

	if request.IsNotification() {
		// Fire and forget.
		go p.callRpcMethod(request.Method, request.Params.Value)
		return nil
	}

	callResult, callError := p.callRpcMethod(request.Method, request.Params.Value)
	switch {
	case callError != nil:
		failure := &JsonRpcErrorResponse{
			JsonRpc: JsonRpcVersion,
			Error:   callError,
			ID:      request.ID,
		}
		return &JsonRpcProcessingResponse{Error: failure}
	default:
		success := &JsonRpcRequestResponse{
			JsonRpc: JsonRpcVersion,
			Result:  callResult,
			ID:      request.ID,
		}
		return &JsonRpcProcessingResponse{Response: success}
	}
}
// callRpcMethod invokes the registered JS RPC method on the plugin worker and
// waits for its result.
//
// params is converted to a JS value: an array is spread into positional
// arguments, any other defined non-null value is passed as the single
// argument, and undefined/null results in a call without arguments.
//
// It returns an internal error when the plugin is not running, when the task
// cannot be queued or fails on the worker, when the method reports an error,
// when the plugin's lifecycle context is cancelled while the call is pending,
// or when a panic is recovered; it returns a method-not-found error when no
// method is registered under the given name. The call therefore never blocks
// forever on a worker that cannot deliver a result.
func (p *KernelPlugin) callRpcMethod(method string, params any) (rpcResult any, rpcError *JsonRpcError) {
	internalError := func(format string, args ...any) *JsonRpcError {
		return &JsonRpcError{
			Code:    JsonRpcErrorCodeInternalError,
			Message: JsonRpcErrorInternalError.Message,
			Data:    fmt.Sprintf(format, args...),
		}
	}

	defer func() {
		if r := recover(); r != nil {
			logging.LogDebugf("[plugin:%s] panic in RPC method [%s]: %v", p.Name, method, r)
			rpcResult = nil
			rpcError = internalError("goja panic in RPC method [%s]: %v", method, r)
		}
	}()

	if state := p.State(); state != PluginStateRunning {
		return nil, internalError("plugin [%s] not running (state: [%s])", p.Name, state)
	}

	value, ok := p.rpcMethods.Load(method)
	if !ok {
		return nil, &JsonRpcError{
			Code:    JsonRpcErrorCodeMethodNotFound,
			Message: JsonRpcErrorMethodNotFound.Message,
			Data:    fmt.Sprintf("method [%s] not found in plugin [%s]", method, p.Name),
		}
	}

	rpcMethod, ok := value.(*RpcMethod)
	if !ok {
		return nil, internalError("invalid method type for [%s]", method)
	}

	// A nil channel blocks forever in select, which keeps the old behavior
	// should the lifecycle context be unset.
	var stopped <-chan struct{}
	if ctx := p.context; ctx != nil {
		stopped = ctx.Done()
	}

	done := make(chan *TaskResult, 1)
	// deliver never blocks: the first outcome wins and later ones are dropped,
	// so a late callback cannot stall the worker after the caller has given up.
	deliver := func(result *TaskResult) {
		select {
		case done <- result:
		default:
		}
	}

	runErr := p.worker.Run(func(rt *goja.Runtime) (any, error) {
		var rpcParams []goja.Value
		jsParams := rt.ToValue(params)
		switch {
		case isJsArray(rt, jsParams):
			// If params is an array, convert to []goja.Value for variadic JS function calls.
			rt.ForOf(jsParams, func(cur goja.Value) bool {
				rpcParams = append(rpcParams, cur)
				return true
			})
		case !goja.IsUndefined(jsParams) && !goja.IsNull(jsParams):
			// If params is not an array but is defined, pass as single argument.
			rpcParams = append(rpcParams, jsParams)
		}
		// Otherwise params is undefined or null: pass no arguments.

		invokeFunction(func(_ *goja.Runtime, result *CallResult) {
			deliver(result.TaskResult())
		}, rt, true, rpcMethod.Method, rt.GlobalObject(), rpcParams...)
		return nil, nil
	}, func(_ *goja.Runtime, _ any, err error) {
		// Surface a failure of the task itself; success is reported by invokeFunction's callback.
		if err != nil {
			deliver(&TaskResult{err: err})
		}
	})
	if runErr != nil {
		return nil, internalError("error invoking method %q: %v", method, runErr)
	}

	select {
	case result := <-done:
		if result == nil {
			return nil, internalError("error invoking method %q: %v", method, "no result")
		}
		if result.err != nil {
			return nil, internalError("error invoking method %q: %v", method, result.err)
		}
		return result.value, nil
	case <-stopped:
		return nil, internalError("plugin [%s] stopped while invoking method %q", p.Name, method)
	}
}
// TrackRpcSocket records conn in the plugin's socket table, flagged as an RPC
// connection. A nil conn is ignored.
func (p *KernelPlugin) TrackRpcSocket(conn *gws.Conn) {
	if conn != nil {
		p.socketsMu.Lock()
		defer p.socketsMu.Unlock()

		const isRpcConnection = true
		p.sockets[conn] = isRpcConnection
	}
}
// UntrackRpcSocket drops conn from the plugin's socket table. A nil conn is
// ignored; removing a connection that is not tracked is a no-op.
func (p *KernelPlugin) UntrackRpcSocket(conn *gws.Conn) {
	if conn != nil {
		p.socketsMu.Lock()
		defer p.socketsMu.Unlock()

		delete(p.sockets, conn)
	}
}
// invokeHook runs the lifecycle hook globalThis.siyuan.plugin.lifecycle.<name>
// on the plugin worker and blocks until its outcome is reported.
//
// Nothing is returned: a missing lifecycle object, a hook that is not a
// function, a failure on the worker, an error from the hook itself or a
// recovered panic are all logged as errors.
func (p *KernelPlugin) invokeHook(name string) {
	var hookErr error
	defer func() {
		if r := recover(); r != nil {
			hookErr = fmt.Errorf("panic during lifecycle hook invocation: %v", r)
		}
		if hookErr != nil {
			logging.LogErrorf("[plugin:%s] lifecycle hook [%q] error: %v", p.Name, name, hookErr)
		}
	}()

	outcome := make(chan TaskResult, 1)

	// task resolves the hook and calls it; its own error covers only the lookup phase.
	task := func(rt *goja.Runtime) (any, error) {
		lifecycle, lookupErr := getJsContextValue(rt, []any{"siyuan", "plugin", "lifecycle"})
		if lookupErr != nil {
			return nil, lookupErr
		}
		if lifecycle == nil {
			return nil, fmt.Errorf("globalThis.siyuan.plugin.lifecycle not found")
		}

		lifecycleObj := lifecycle.ToObject(rt)
		if lifecycleObj == nil {
			return nil, fmt.Errorf("globalThis.siyuan.plugin.lifecycle is not an object")
		}

		hook, isFunction := goja.AssertFunction(lifecycleObj.Get(name))
		if !isFunction {
			return nil, fmt.Errorf("globalThis.siyuan.plugin.lifecycle.%s not bound to a function", name)
		}

		// The hook's outcome is reported through this callback.
		invokeFunction(func(_ *goja.Runtime, result *CallResult) {
			outcome <- *result.TaskResult()
		}, rt, true, hook, lifecycle)
		return nil, nil
	}

	// onTaskDone reports a lookup failure; on success the hook callback above reports instead.
	onTaskDone := func(_ *goja.Runtime, _ any, taskErr error) {
		if taskErr != nil {
			outcome <- TaskResult{err: taskErr}
		}
	}

	if runErr := p.worker.Run(task, onTaskDone); runErr != nil {
		outcome <- TaskResult{err: runErr}
	}

	hookErr = (<-outcome).err
}
// handleHttpRequest dispatches an HTTP request to the plugin's JS handler for
// the given scope and converts the handler's return value into an HttpResponse.
//
// The handler's result object is serialized to JSON and decoded into
// HttpResponse. Binary payloads at response.body.raw.data are extracted to
// bytes beforehand, nulled out on the JS object so they are not serialized,
// and re-attached to the decoded response afterwards.
//
// An error is returned when the handler cannot be resolved, the request cannot
// be converted, the worker rejects the task, the handler fails, or the handler
// returns undefined/null or something that cannot be decoded. If the client's
// request context ends first, both response and err are nil.
func (p *KernelPlugin) handleHttpRequest(c *gin.Context, request *Request, scope AccessScope) (response *HttpResponse, err error) {
	type handleResult FunctionResult[*HttpResponse]
	done := make(chan *handleResult, 1)

	runErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
		handler, handlerObj, getHandlerErr := getRequestHandler(rt, scope, RequestTypeHTTP)
		if getHandlerErr != nil {
			err = getHandlerErr
			return
		}

		jsRequest, convertErr := requestGoToJs(p, rt, request)
		if convertErr != nil {
			err = convertErr
			return
		}

		invokeFunction(func(rt *goja.Runtime, result *CallResult) {
			if result.Error != nil {
				done <- &handleResult{Error: result.Error}
				return
			}

			// Converting undefined/null to an object throws in goja, so reject
			// a handler that returned nothing before calling ToObject.
			if !isJsValueNotNull(result.Value) {
				done <- &handleResult{Error: fmt.Errorf("handler did not return an object")}
				return
			}
			responseObj := result.Value.ToObject(rt)
			if responseObj == nil {
				done <- &handleResult{Error: fmt.Errorf("handler did not return an object")}
				return
			}

			// convert response.body?.raw?.data from (string | Buffer | ArrayBuffer) to []byte
			var raw *[]byte
			if bodyValue := responseObj.Get("body"); isJsValueNotNull(bodyValue) {
				// response.body
				if bodyObj := bodyValue.ToObject(rt); bodyObj != nil {
					if rawValue := bodyObj.Get("raw"); isJsValueNotNull(rawValue) {
						// response.body.raw
						if rawObj := rawValue.ToObject(rt); rawObj != nil {
							if dataValue := rawObj.Get("data"); isJsValueNotNull(dataValue) {
								// response.body.raw.data
								dataBytes, convertErr := jsValueToBytes(rt, dataValue)
								if convertErr == nil {
									raw = &dataBytes
									// Keep the binary payload out of the JSON round trip below.
									rawObj.Set("data", goja.Null())
								}
							}
						}
					}
				}
			}

			// The response is decoded through JSON rather than rt.ExportTo,
			// which was observed to panic with a nil pointer dereference here.
			resultJson, marshalErr := responseObj.MarshalJSON()
			if marshalErr != nil {
				done <- &handleResult{Error: marshalErr}
				return
			}

			response := HttpResponse{}
			if unmarshalErr := json.Unmarshal(resultJson, &response); unmarshalErr != nil {
				done <- &handleResult{Error: fmt.Errorf("invalid response format: %v", unmarshalErr)}
				return
			}

			if raw != nil && response.Body != nil && response.Body.Raw != nil {
				response.Body.Raw.Data = *raw
			}

			done <- &handleResult{Value: &response}
		}, rt, true, handler, handlerObj, jsRequest)
		return
	}, func(_ *goja.Runtime, _ any, err error) {
		if err != nil {
			done <- &handleResult{Error: err}
		}
	})
	if runErr != nil {
		done <- &handleResult{Error: runErr}
	}

	select {
	case result := <-done:
		if result.Error != nil {
			err = result.Error
		} else {
			response = result.Value
		}
	case <-c.Request.Context().Done():
		// The context error is intentionally not reported.
		// err = c.Request.Context().Err()
	}
	return
}
// handleWebSocketRequest upgrades the HTTP request to a WebSocket connection and hands it to the plugin's JS handler for the given scope.
//
// A sealed JS `port` object is attached to the request as `request.port`. It exposes the state properties
// (binaryType, bufferedAmount, protocol, readyState), the on* hook slots (onopen, onmessage, onping, onpong,
// onclose, onerror) and the methods open/send/ping/pong/close, each of which returns a Promise.
//
// The call blocks until an error is delivered on done (worker failure or handler error) or until ctx, which is
// derived from the plugin's lifecycle context, is cancelled. On return the deferred doClose closes the underlying
// network connection and cancels ctx. A failed upgrade is returned immediately.
func (p *KernelPlugin) handleWebSocketRequest(c *gin.Context, request *Request, scope AccessScope) (err error) {
done := make(chan error, 1)
h := &WsEventHandler{p: p}
upgrader := gws.NewUpgrader(h, &gws.ServerOption{})
socket, upgradeErr := upgrader.Upgrade(c.Writer, c.Request)
if upgradeErr != nil {
err = upgradeErr
return
}
ctx, cancel := context.WithCancel(p.context)
var openOnce sync.Once
var closeOnce sync.Once
// doOpen starts the socket read loop in a goroutine (at most once) and cancels ctx when the loop returns.
doOpen := func() {
go openOnce.Do(func() {
socket.ReadLoop()
cancel()
})
}
// doClose closes the underlying network connection and cancels ctx (at most once).
doClose := func() {
closeOnce.Do(func() {
socket.NetConn().Close()
cancel()
})
}
defer doClose()
// readyState and bufferedAmount back the port properties of the same name.
var readyState atomic.Int64
var bufferedAmount atomic.Int64
readyState.Store(int64(WebSocketReadyStateConnecting))
runErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
handler, handlerObj, getHandlerErr := getRequestHandler(rt, scope, RequestTypeWS)
if getHandlerErr != nil {
err = getHandlerErr
return
}
jsRequest, convertErr := requestGoToJs(p, rt, request)
if convertErr != nil {
err = convertErr
return
}
jsRequestObj := jsRequest.ToObject(rt)
if jsRequestObj == nil {
err = fmt.Errorf("failed to convert request value to object")
return
}
port := rt.NewObject()
// invokePortHook calls port.<name> if it is a function; a failing hook is only logged.
invokePortHook := func(_ *goja.Runtime, name string, args ...goja.Value) {
hook := port.Get(name)
if fn, ok := goja.AssertFunction(hook); ok {
if _, callErr := fn(port, args...); callErr != nil {
logging.LogErrorf("[plugin:%s] ws server port hook %q: %v", p.Name, name, callErr)
}
}
}
setProtocol := func(rt *goja.Runtime, protocol string) {
port.Set("protocol", rt.ToValue(protocol))
}
// setPortReadyState updates both the Go-side atomic and the JS-visible port.readyState.
setPortReadyState := func(rt *goja.Runtime, state WebSocketState) {
readyState.Store(int64(state))
port.Set("readyState", rt.ToValue(state))
}
// updatePortBufferedAmount adjusts the counter by delta and mirrors the new value to port.bufferedAmount.
updatePortBufferedAmount := func(rt *goja.Runtime, delta int) {
bufferedAmount.Add(int64(delta))
port.Set("bufferedAmount", rt.ToValue(bufferedAmount.Load()))
}
// The manager gives the gws event handler access to the port's hooks and state setters.
manager := &WsManager{
BufferedAmount: &bufferedAmount,
InvokeHook: invokePortHook,
SetProtocol: setProtocol,
SetReadyState: setPortReadyState,
}
h.BindOnOpen(manager)
h.BindOnClose(manager)
h.BindOnPing(manager)
h.BindOnPong(manager)
h.BindOnMessage(manager)
// port.open(): starts the read loop via the worker and settles the returned promise with the task outcome.
port_open := rt.ToValue(func(openCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
openPromise, openResolve, openReject := rt.NewPromise()
openRunErr := p.worker.Run(func(rt *goja.Runtime) (_ any, _ error) {
doOpen()
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := openResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.open resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := openReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.open reject: %v", p.Name, rejectErr)
}
}
})
if openRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.open worker run: %v", p.Name, openRunErr)
}
return rt.ToValue(openPromise)
})
// port.send(data): an ArrayBuffer is sent as a binary frame, anything else as text (its string form).
// The promise is rejected when the socket is closing/closed or the write fails, and resolved once the asynchronous write succeeds.
port_send := rt.ToValue(func(sendCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
sendPromise, sendResolve, sendReject := rt.NewPromise()
sendRunErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
var messageData []byte
var opcode gws.Opcode
if len(sendCall.Arguments) >= 1 {
data := sendCall.Argument(0)
if arrayBuffer, ok := data.Export().(goja.ArrayBuffer); ok {
opcode = gws.OpcodeBinary
b := arrayBuffer.Bytes()
messageData = make([]byte, len(b))
copy(messageData, b) // ArrayBuffer.Bytes() points into JS engine memory; copy before async send
} else {
opcode = gws.OpcodeText
messageData = []byte(data.String())
}
}
state := WebSocketState(readyState.Load())
if state == WebSocketReadyStateClosing || state == WebSocketReadyStateClosed {
err = fmt.Errorf("WebSocket is not open (state: %d)", state)
return
}
// Count the bytes as buffered until the write callback reports success.
updatePortBufferedAmount(rt, len(messageData))
socket.WriteAsync(opcode, messageData, func(writeErr error) {
// Hop back onto the worker before touching the port or settling the promise.
p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
if writeErr == nil {
updatePortBufferedAmount(rt, -len(messageData))
} else {
err = writeErr
}
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := sendResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := sendReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send reject: %v", p.Name, rejectErr)
}
}
})
})
return
}, func(rt *goja.Runtime, _ any, err error) {
// Only failures are handled here; success is settled from the write callback above.
if !lo.IsNil(err) {
if rejectErr := sendReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send reject: %v", p.Name, rejectErr)
}
}
})
if sendRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send worker run: %v", p.Name, sendRunErr)
}
return rt.ToValue(sendPromise)
})
// port.ping(data?): writes a ping frame carrying the string form of the optional first argument.
port_ping := rt.ToValue(func(pingCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
pingPromise, pingResolve, pingReject := rt.NewPromise()
pingRunErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
var pingData string
if len(pingCall.Arguments) > 0 && !goja.IsUndefined(pingCall.Argument(0)) {
pingData = pingCall.Argument(0).String()
}
err = socket.WritePing([]byte(pingData))
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := pingResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.ping resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := pingReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.ping reject: %v", p.Name, rejectErr)
}
}
})
if pingRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.ping worker run: %v", p.Name, pingRunErr)
}
return rt.ToValue(pingPromise)
})
// port.pong(data?): writes a pong frame carrying the string form of the optional first argument.
port_pong := rt.ToValue(func(pongCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
pongPromise, pongResolve, pongReject := rt.NewPromise()
pongRunErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
var pongData string
if len(pongCall.Arguments) > 0 && !goja.IsUndefined(pongCall.Argument(0)) {
pongData = pongCall.Argument(0).String()
}
err = socket.WritePong([]byte(pongData))
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := pongResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.pong resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := pongReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.pong reject: %v", p.Name, rejectErr)
}
}
})
if pongRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.pong worker run: %v", p.Name, pongRunErr)
}
return rt.ToValue(pongPromise)
})
// port.close(code?, reason?): marks the port as closing and writes a close frame (code defaults to 1000).
port_close := rt.ToValue(func(closeCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
closePromise, closeResolve, closeReject := rt.NewPromise()
closeRunErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
code := uint16(1000)
var reason []byte
if isJsValueNotNull(closeCall.Argument(0)) {
code = uint16(closeCall.Argument(0).ToInteger())
}
if isJsValueNotNull(closeCall.Argument(1)) {
reason = []byte(closeCall.Argument(1).String())
}
setPortReadyState(rt, WebSocketReadyStateClosing)
err = socket.WriteClose(code, reason)
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := closeResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.close resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := closeReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.close reject: %v", p.Name, rejectErr)
}
}
})
if closeRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.close worker run: %v", p.Name, closeRunErr)
}
return rt.ToValue(closePromise)
})
// Populate the port, then seal it and expose it as request.port.
// lo.Must0 panics if any of these steps returns an error.
lo.Must0(port.Set("binaryType", rt.ToValue("arraybuffer")))
lo.Must0(port.Set("bufferedAmount", rt.ToValue(bufferedAmount.Load())))
lo.Must0(port.Set("protocol", rt.ToValue("")))
lo.Must0(port.Set("readyState", rt.ToValue(readyState.Load())))
lo.Must0(port.Set("onopen", goja.Null()))
lo.Must0(port.Set("onmessage", goja.Null()))
lo.Must0(port.Set("onping", goja.Null()))
lo.Must0(port.Set("onpong", goja.Null()))
lo.Must0(port.Set("onclose", goja.Null()))
lo.Must0(port.Set("onerror", goja.Null()))
lo.Must0(port.Set("open", port_open))
lo.Must0(port.Set("send", port_send))
lo.Must0(port.Set("ping", port_ping))
lo.Must0(port.Set("pong", port_pong))
lo.Must0(port.Set("close", port_close))
lo.Must0(ObjectSeal(rt, port))
lo.Must0(jsRequestObj.Set("port", port))
invokeFunction(func(_ *goja.Runtime, result *CallResult) {
if result.Error != nil {
// If the handler throws an error, close the connection and return the error.
select {
case done <- result.Error:
case <-ctx.Done():
}
return
}
// Auto-open if the handler did not call port.open() explicitly.
doOpen()
}, rt, true, handler, handlerObj, jsRequest)
return
}, func(_ *goja.Runtime, _ any, err error) {
if err != nil {
// If there is an error during the worker run (e.g. runtime panic), close the connection and return the error.
select {
case done <- err:
case <-ctx.Done():
}
}
})
if runErr != nil {
select {
case done <- runErr:
case <-ctx.Done():
}
}
// Block until an error is reported or the connection/plugin context ends.
select {
case err = <-done:
case <-ctx.Done():
}
return
}
// handleServerSentEventRequest dispatches an SSE request to the plugin's JS
// handler for the given scope and streams the events it emits.
//
// A sealed JS `port` object is attached to the request as `request.port`,
// exposing the hook slots onopen/onclose and the methods send(name, message)
// and close(). Events passed to port.send are queued and written to the client
// in order.
//
// The call returns when the handler reports an error (returned as err), when
// port.close() is called, when the plugin's lifecycle context is cancelled, or
// when the client disconnects (the last three return nil).
//
// Once the stream context is cancelled, port.send drops the event instead of
// blocking the JS worker on a queue nobody drains, and a closed event queue
// ends the stream instead of producing empty events.
func (p *KernelPlugin) handleServerSentEventRequest(c *gin.Context, request *Request, scope AccessScope) (err error) {
	type sseEvent struct {
		name    string
		message any
	}

	ctx, cancel := context.WithCancel(p.context)
	var closeOnce sync.Once
	doClose := func() {
		closeOnce.Do(func() {
			cancel()
		})
	}
	defer doClose()

	events := chanx.NewUnboundedChan[sseEvent](ctx, 16)
	done := make(chan error, 1) // using to receive handler error or close signal

	runErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
		handler, handlerObj, getHandlerErr := getRequestHandler(rt, scope, RequestTypeSSE)
		if getHandlerErr != nil {
			err = getHandlerErr
			return
		}

		jsRequest, convertErr := requestGoToJs(p, rt, request)
		if convertErr != nil {
			err = convertErr
			return
		}

		jsRequestObj := jsRequest.ToObject(rt)
		if jsRequestObj == nil {
			err = fmt.Errorf("failed to convert request value to object")
			return
		}

		port := rt.NewObject()

		// invokePortHook calls port.<name> if it is a function; a failing hook is only logged.
		invokePortHook := func(name string, args ...goja.Value) {
			hook := port.Get(name)
			if fn, ok := goja.AssertFunction(hook); ok {
				if _, callErr := fn(port, args...); callErr != nil {
					logging.LogErrorf("[plugin:%s] sse server port hook %q: %v", p.Name, name, callErr)
				}
			}
		}

		portSend := rt.ToValue(func(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
			name := call.Argument(0).String()
			message := call.Argument(1).Export()
			select {
			case events.In <- sseEvent{name, message}:
			case <-ctx.Done():
				// Stream already closed: drop the event rather than block the JS worker.
			}
			return goja.Undefined()
		})

		portClose := rt.ToValue(func(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
			doClose()
			return goja.Undefined()
		})

		// lo.Must0 panics if any of these steps returns an error.
		lo.Must0(port.Set("onopen", goja.Null()))
		lo.Must0(port.Set("onclose", goja.Null()))
		lo.Must0(port.Set("send", portSend))
		lo.Must0(port.Set("close", portClose))
		lo.Must0(ObjectSeal(rt, port))
		lo.Must0(jsRequestObj.Set("port", port))

		invokeFunction(func(_ *goja.Runtime, result *CallResult) {
			if result.Error != nil {
				select {
				case done <- result.Error:
				case <-ctx.Done():
				}
				return
			}

			// Only start the onclose goroutine when the handler succeeds, so that
			// onclose is never dispatched without a prior onopen.
			go func() {
				<-ctx.Done()
				p.worker.Run(func(rt *goja.Runtime) (_ any, _ error) {
					event := rt.NewObject()
					event.Set("type", rt.ToValue("close"))
					invokePortHook("onclose", event)
					return
				}, nil)
			}()

			// Fire onopen first; signal done only after onopen has executed so that
			// any port.send() calls inside onopen are enqueued before streaming begins.
			p.worker.Run(func(rt *goja.Runtime) (_ any, _ error) {
				event := rt.NewObject()
				event.Set("type", rt.ToValue("open"))
				invokePortHook("onopen", event)
				return
			}, func(_ *goja.Runtime, _ any, err error) {
				select {
				case done <- err:
				case <-ctx.Done():
				}
			})
		}, rt, true, handler, handlerObj, jsRequest)
		return
	}, func(_ *goja.Runtime, _ any, err error) {
		if err != nil {
			select {
			case done <- err:
			case <-ctx.Done():
			}
		}
	})
	if runErr != nil {
		select {
		case done <- runErr:
		case <-ctx.Done():
		}
	}

	for {
		select {
		case e, ok := <-events.Out:
			if !ok {
				// The queue was closed; there is nothing left to stream.
				return
			}
			c.SSEvent(e.name, e.message)
			c.Writer.Flush()
		case <-ctx.Done():
			return
		case <-c.Request.Context().Done():
			// The context error is intentionally not reported.
			// err = c.Request.Context().Err()
			return
		case handlerErr := <-done:
			if handlerErr != nil {
				err = handlerErr
				return
			}
			// Handler completed successfully; keep streaming until port.close() or client disconnect.
		}
	}
}