Files
siyuan/kernel/plugin/plugin.go
Yingyi / 颖逸 0502614d59 🎨 Add MCP tool management API for kernel plugin (#17834)
* 🎨 Enhance MCP tool management with registration, unregistration, and sanitization features

* 🎨 Enhance MCP tool management with registration, unregistration, and sanitization features

* 🎨 Improve registerTool function with enhanced argument validation and output schema support

* 🎨 Enhance error handling and output schema in MCP tool registration

* 🎨 Add state check in invokeMcpTool to prevent invocation when plugin is not running
2026-06-09 20:21:25 +08:00

1428 lines
40 KiB
Go

// SiYuan - Refactor your thinking
// Copyright (c) 2020-present, b3log.org
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package plugin
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"sync/atomic"
"github.com/asaskevich/EventBus"
"github.com/dop251/goja"
"github.com/dop251/goja_nodejs/eventloop"
"github.com/fsnotify/fsnotify"
"github.com/gin-contrib/sse"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/lxzan/gws"
"github.com/samber/lo"
"github.com/siyuan-note/logging"
"github.com/siyuan-note/siyuan/kernel/mcp/tools"
"github.com/siyuan-note/siyuan/kernel/model"
"github.com/siyuan-note/siyuan/kernel/util"
"github.com/smallnest/chanx"
)
// PluginState is the lifecycle state of a KernelPlugin. It is stored atomically in
// KernelPlugin.state (as int64) and rendered for humans by String.
type PluginState int64
// RpcMethod is a JS function registered by the plugin (via bindRpcMethod) as a named JSON-RPC method.
type RpcMethod struct {
	Name         string        // method name used to look it up in KernelPlugin.rpcMethods
	Descriptions []string      // free-form descriptions exposed by GetRpcMethodsInfo
	Method       goja.Callable // JS function invoked by callRpcMethod
}
// RpcMethodInfo is the JSON-serializable summary of an RpcMethod returned by GetRpcMethodsInfo
// (the callable itself is omitted).
type RpcMethodInfo struct {
	Name         string   `json:"name"`
	Descriptions []string `json:"descriptions"`
}
// R is shorthand for a generic JSON-like object, e.g. the event messages built by createEventMessage.
type R map[string]any
// Topics used on KernelPlugin.bus.
const (
	EventBusTopicPlugin  = "plugin"  // Topic to kernel plugin (handled Go-side by pluginEventHandler)
	EventBusTopicRuntime = "runtime" // Topic to javascript runtime (forwarded to JS by runtimeEventHandler)
)
// Plugin lifecycle states. The numeric values (iota, starting at 0) are stored in
// KernelPlugin.state and pushed to the frontend as ints by updateState.
// NOTE(review): the frontend presumably mirrors these numbers — do not reorder without updating it.
const (
	PluginStateReady    PluginState = iota // set by NewKernelPlugin; not started yet
	PluginStateLoading                     // set at the beginning of start
	PluginStateRunning                     // set by start after the onload hook
	PluginStateStopping                    // set by stop before the onunload hook
	PluginStateStopped                     // set at the end of a successful stop
	PluginStateError                       // set by error after a failure
)
// String returns the lowercase name of the state, as used in log and error messages,
// or "unknown" for any value that is not one of the declared PluginState constants.
func (s PluginState) String() (name string) {
	name = "unknown"
	switch s {
	case PluginStateReady:
		name = "ready"
	case PluginStateLoading:
		name = "loading"
	case PluginStateRunning:
		name = "running"
	case PluginStateStopping:
		name = "stopping"
	case PluginStateStopped:
		name = "stopped"
	case PluginStateError:
		name = "error"
	}
	return name
}
// KernelPlugin represents a single kernel-side plugin instance.
//
// It owns the plugin's goja event loop and worker, its lifecycle state and context, the event bus
// used to talk to the JS runtime, the RPC methods and MCP tools the plugin registered, and the
// WebSocket connections it tracks. Construct it with NewKernelPlugin; the zero value is not usable
// (e.g. the sockets map would be nil).
type KernelPlugin struct {
	*model.Petal
	token      string               // JWT for this plugin
	file       string               // kernel.js file path named in js runtime (e.g. "plugin-name/kernel.js")
	pluginDir  string               // Base directory for this plugin (e.g. /path/to/workspace/data/plugins/plugin-name)
	storageDir string               // Base directory for this plugin's storage (e.g. /path/to/workspace/data/storage/petal/plugin-name)
	worker     Worker               // Worker for serializing plugin js-call-go (e.g. logger) and go-call-js (e.g. RPC calls) tasks on a single goroutine
	runtime    *eventloop.EventLoop // goja event loop runtime for this plugin
	watcher    *fsnotify.Watcher    // watcher for kernel plugin storage file changes; may be nil if creation failed
	state      atomic.Int64         // current PluginState, accessed only through State/updateState
	context    context.Context      // Context for managing plugin lifecycle and cancellation
	cancel     context.CancelFunc   // Cancel function for managing plugin lifecycle and cancellation
	bus        EventBus.Bus         // Event bus carrying EventBusTopicRuntime / EventBusTopicPlugin events
	rpcMethods sync.Map             // string -> *RpcMethod, registered JSON-RPC methods
	mcpTools   sync.Map             // name passed to registerMcpTool -> *tools.Tool published globally under tool.Name
	socketsMu  sync.RWMutex         // guards sockets
	sockets    map[*gws.Conn]bool   // tracked gws WebSocket connections (true: RPC server, false: regular)
}
// NewKernelPlugin creates a KernelPlugin for petal in PluginStateReady.
//
// It mints the plugin's JWT, derives a cancellable lifecycle context from ctx, and creates an
// fsnotify watcher for storage changes. Failing to create the JWT or the watcher is logged but not
// fatal: the plugin is still returned with whatever CreatePluginJWT / fsnotify.NewWatcher produced
// (a nil watcher disables storage watching).
func NewKernelPlugin(ctx context.Context, petal *model.Petal) *KernelPlugin {
	token, err := model.CreatePluginJWT(petal.Name)
	if err != nil {
		logging.LogErrorf("Failed to create plugin JWT for [%s]: %v", petal.Name, err)
	}

	// Named pluginCtx rather than "context" so the context package is not shadowed.
	pluginCtx, cancel := context.WithCancel(ctx)

	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		logging.LogErrorf("[plugin:%s] failed to create storage watcher: %s", petal.Name, err)
	}

	plugin := &KernelPlugin{
		Petal:      petal,
		token:      token,
		file:       fmt.Sprintf("%s/kernel.js", petal.Name),
		pluginDir:  filepath.Join(util.DataDir, "plugins", petal.Name),
		storageDir: filepath.Join(util.DataDir, "storage", "petal", petal.Name),
		watcher:    watcher,
		context:    pluginCtx,
		cancel:     cancel,
		bus:        EventBus.New(),
		sockets:    make(map[*gws.Conn]bool),
	}
	plugin.updateState(PluginStateReady)
	return plugin
}
// createEventMessage builds the envelope published on the event bus: a fresh UUID under "id",
// the event name under "type", and the payload under "detail".
func createEventMessage(eventType string, detail any) R {
	message := make(R, 3)
	message["id"] = uuid.NewString()
	message["type"] = eventType
	message["detail"] = detail
	return message
}
// State reports the plugin's current lifecycle state; the underlying atomic load makes it
// safe to call from any goroutine.
func (p *KernelPlugin) State() PluginState {
	raw := p.state.Load()
	return PluginState(raw)
}
// Clear removes all RPC methods bound by this plugin and unregisters all of its MCP tools from
// the global MCP registry. It is called when the plugin stops or errors, to prevent residue in
// global registries.
//
// MCP tools are removed entry by entry (LoadAndDelete) instead of a final mcpTools.Clear(): a tool
// registered concurrently while Clear runs is then either unregistered here or stays tracked for a
// later cleanup, rather than being dropped from the tracking map while still published globally.
func (p *KernelPlugin) Clear() {
	p.rpcMethods.Clear()
	p.mcpTools.Range(func(key, _ any) bool {
		if value, loaded := p.mcpTools.LoadAndDelete(key); loaded {
			if tool, ok := value.(*tools.Tool); ok && tool != nil {
				tools.RemoveTool(tool.Name)
			}
		}
		return true
	})
}
// updateState atomically records state as the plugin's current state and then notifies the
// frontend of the transition through util.PushKernelPluginState.
func (p *KernelPlugin) updateState(state PluginState) {
	code := int64(state)
	p.state.Store(code)
	util.PushKernelPluginState(p.Name, int(code))
}
// InitRuntime creates the plugin's goja event loop and starts the worker on it. Inside the loop it
// installs the extension and siyuan modules and evaluates kernel.js. Finally it starts the loop in
// the background.
//
// Errors from module setup or script evaluation are returned wrapped with %w, so callers can
// inspect the cause with errors.Is / errors.As. A panic raised while doing so is also returned.
// NOTE(review): the loop is started even when setup failed; the caller (start) tears it down via
// error()/close().
func (p *KernelPlugin) InitRuntime() (err error) {
	p.runtime = eventloop.NewEventLoop(eventloop.EnableConsole(true))
	p.worker.Start(p.runtime)
	p.runtime.Run(func(rt *goja.Runtime) {
		defer func() {
			if r := recover(); r != nil {
				err = fmt.Errorf("goja panic during event loop run: %v", r)
			}
		}()
		// Expose Go struct fields to JS under their `json` tag names. The `true` argument is goja's
		// uncapMethods flag, which exposes Go methods with a lower-case first letter.
		// NOTE(review): per goja's tag mapper, fields without a usable json tag are not exposed — confirm.
		rt.SetFieldNameMapper(goja.TagFieldNameMapper("json", true))
		if enableErr := EnableExtendModules(p, rt); enableErr != nil {
			err = fmt.Errorf("EnableExtendModules: %w", enableErr)
			return
		}
		if enableErr := EnableSiyuanModule(p, rt); enableErr != nil {
			err = fmt.Errorf("EnableSiyuanModule: %w", enableErr)
			return
		}
		if _, runErr := rt.RunScript(p.file, p.Kernel.JS); runErr != nil {
			err = fmt.Errorf("RunScript: %w", runErr)
			return
		}
	})
	p.runtime.Start()
	return err
}
// Eval runs code as a script in rt, labelled with the plugin's kernel.js file name, and returns
// the completion value or the evaluation error.
func (p *KernelPlugin) Eval(rt *goja.Runtime, code string) (goja.Value, error) {
	value, err := rt.RunScript(p.file, code)
	return value, err
}
// close stops the plugin's goja event loop if one was created, waiting for it to finish.
// The p.runtime pointer is left as is. A panic raised while stopping is recovered and
// returned as err.
func (p *KernelPlugin) close() (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("goja panic during close runtime: %v", r)
		}
	}()
	if loop := p.runtime; loop != nil {
		loop.Stop()
	}
	return nil
}
// error puts the plugin into PluginStateError. It first drops the plugin's RPC methods and MCP
// tools, then stops the goja event loop (a failure there is logged, not returned), and finally
// publishes the new state.
func (p *KernelPlugin) error() {
	p.Clear()
	closeErr := p.close()
	if closeErr != nil {
		logging.LogErrorf("[plugin:%s] failed to close runtime during error handling: %v", p.Name, closeErr)
	}
	p.updateState(PluginStateError)
}
// start brings the plugin up and moves it to PluginStateRunning. It performs these steps in order:
//   - state → loading;
//   - ensure the storage directory exists;
//   - create the goja runtime (injecting sandbox globals and evaluating kernel.js);
//   - subscribe the event handlers;
//   - launch the storage watcher goroutine;
//   - run the onload hook;
//   - state → running;
//   - run the onrunning hook;
//   - publish a "start" runtime event.
//
// If a setup step fails, or a panic occurs, the plugin is moved to PluginStateError via error()
// and the (wrapped) error is returned.
func (p *KernelPlugin) start() (err error) {
	defer func() {
		if r := recover(); r != nil {
			p.error()
			err = fmt.Errorf("goja panic during start: %v", r)
		}
	}()
	p.updateState(PluginStateLoading)
	if mkdirErr := os.MkdirAll(p.storageDir, 0755); mkdirErr != nil {
		// Leave the loading state explicitly; otherwise the plugin is reported as "loading" forever.
		p.error()
		return fmt.Errorf("create plugin dir [%s] failed: %w", p.storageDir, mkdirErr)
	}
	if runtimeErr := p.InitRuntime(); runtimeErr != nil {
		p.error()
		return fmt.Errorf("start runtime: %w", runtimeErr)
	}
	if subscribeErr := p.subscribeEventHandlers(); subscribeErr != nil {
		p.error()
		return fmt.Errorf("subscribe plugin events: %w", subscribeErr)
	}
	go p.startStorageWatch()
	p.onLoad()
	p.updateState(PluginStateRunning)
	p.onRunning()
	p.bus.Publish(EventBusTopicRuntime, createEventMessage("start", nil))
	logging.LogDebugf("[plugin:%s] started", p.Name)
	return nil
}
// stop shuts a running plugin down. It performs these steps in order:
//   - publish a "stop" runtime event;
//   - state → stopping;
//   - run the onunload hook;
//   - drop RPC methods and MCP tools;
//   - cancel the plugin context;
//   - forget (untrack) all tracked sockets — the connections themselves are not closed here;
//   - unsubscribe the event handlers;
//   - stop the event loop;
//   - state → stopped.
//
// ok is false and err nil when the plugin was not running (nothing is done). If a panic occurs,
// the plugin is put into PluginStateError, ok is false, and err describes the panic. Failures to
// unsubscribe or to stop the loop are logged without aborting the shutdown.
func (p *KernelPlugin) stop() (ok bool, err error) {
	defer func() {
		if r := recover(); r != nil {
			p.error()
			ok = false
			err = fmt.Errorf("panic during plugin stop: %v", r)
		}
	}()
	if p.State() != PluginStateRunning {
		return false, nil
	}
	p.bus.Publish(EventBusTopicRuntime, createEventMessage("stop", nil))
	p.updateState(PluginStateStopping)
	p.onUnload()
	p.Clear()
	p.cancel()

	p.socketsMu.Lock()
	clear(p.sockets)
	p.socketsMu.Unlock()

	if unsubscribeErr := p.unsubscribeEventHandlers(); unsubscribeErr != nil {
		logging.LogWarnf("[plugin:%s] failed to unsubscribe event handlers during stop: %v", p.Name, unsubscribeErr)
	}
	if closeErr := p.close(); closeErr != nil {
		logging.LogErrorf("[plugin:%s] failed to close runtime during stop: %v", p.Name, closeErr)
	}
	p.updateState(PluginStateStopped)
	logging.LogDebugf("[plugin:%s] stopped", p.Name)
	return true, nil
}
// onLoad runs the JS lifecycle hook "onload", but only while the plugin is still loading.
func (p *KernelPlugin) onLoad() {
	if p.State() != PluginStateLoading {
		return
	}
	p.invokeHook("onload")
}
// onRunning runs the JS lifecycle hook "onrunning", but only once the plugin is running.
func (p *KernelPlugin) onRunning() {
	if p.State() != PluginStateRunning {
		return
	}
	p.invokeHook("onrunning")
}
// onUnload runs the JS lifecycle hook "onunload", but only while the plugin is stopping.
func (p *KernelPlugin) onUnload() {
	if p.State() != PluginStateStopping {
		return
	}
	p.invokeHook("onunload")
}
// bindRpcMethod registers method as the JSON-RPC method called name, replacing any previous
// binding with the same name. It always returns nil.
func (p *KernelPlugin) bindRpcMethod(name string, method goja.Callable, descriptions ...string) error {
	entry := &RpcMethod{Name: name, Method: method, Descriptions: descriptions}
	p.rpcMethods.Store(name, entry)
	return nil
}
// unbindRpcMethod forgets the RPC method called name. Unknown names are ignored; it always
// returns nil.
func (p *KernelPlugin) unbindRpcMethod(name string) error {
	p.rpcMethods.Delete(name)
	return nil
}
// registerMcpTool publishes tool to the global MCP registry under tool.Name. It also tracks the
// tool in this plugin under name, so that unregisterMcpTool(name) and Clear can remove it later.
//
// No prefix is added here: any plugin-specific prefix must already be part of tool.Name.
// Re-registering under an existing name replaces the tracked tool; if the previous tool was
// published under a different tool.Name, that stale global entry is removed so it cannot outlive
// the plugin. A nil tool is rejected with an error.
func (p *KernelPlugin) registerMcpTool(name string, tool *tools.Tool) error {
	if tool == nil {
		return fmt.Errorf("register MCP tool [%s]: tool is nil", name)
	}
	if previous, loaded := p.mcpTools.Swap(name, tool); loaded {
		if old, ok := previous.(*tools.Tool); ok && old != nil && old.Name != tool.Name {
			tools.RemoveTool(old.Name)
		}
	}
	tools.SetTool(tool.Name, tool)
	return nil
}
// unregisterMcpTool stops tracking the tool registered under name and removes it from the global
// MCP registry. Unknown names are ignored; it always returns nil.
func (p *KernelPlugin) unregisterMcpTool(name string) error {
	value, loaded := p.mcpTools.LoadAndDelete(name)
	if !loaded {
		return nil
	}
	if tool, ok := value.(*tools.Tool); ok {
		tools.RemoveTool(tool.Name)
	}
	return nil
}
// invokeMcpTool runs a JS handler registered via siyuan.mcp.registerTool on the runtime worker.
// The handler receives args as its only argument, with this bound to globalThis, and its outcome is
// converted into a tools.CallToolResult.
//
// The returned error is always nil; failures are reported in-band as a result with IsError set:
//   - the plugin is not running;
//   - the worker could not schedule the call;
//   - the handler threw or rejected;
//   - the result cannot be JSON-encoded;
//   - the plugin context was cancelled while waiting.
//
// On success the result is a single text item holding the JSON encoding of the handler's return
// value, or "null" when it returned nothing.
func (p *KernelPlugin) invokeMcpTool(handler goja.Callable, args map[string]any) (tools.CallToolResult, error) {
	textResult := func(isError bool, text string) tools.CallToolResult {
		return tools.CallToolResult{
			IsError: isError,
			Content: []tools.ContentItem{{Type: "text", Text: text}},
		}
	}

	if p.State() != PluginStateRunning {
		return textResult(true, fmt.Sprintf("plugin [%s] is not running (state: %s)", p.Name, p.State().String())), nil
	}

	done := make(chan *TaskResult, 1)
	task := func(rt *goja.Runtime) (any, error) {
		jsArgs := rt.ToValue(args)
		invokeFunction(func(_ *goja.Runtime, result *CallResult) {
			done <- result.TaskResult()
		}, rt, true, handler, rt.GlobalObject(), jsArgs)
		return nil, nil
	}
	if runErr := p.worker.Run(task, nil); runErr != nil {
		return textResult(true, fmt.Sprintf("error running plugin runtime worker: %v", runErr)), nil
	}

	var outcome *TaskResult
	select {
	case outcome = <-done:
	case <-p.context.Done():
		return textResult(true, "plugin stopped while invoking MCP tool handler"), nil
	}

	switch {
	case outcome.err != nil:
		return textResult(true, fmt.Sprintf("error invoking MCP tool handler: %v", outcome.err)), nil
	case outcome.value == nil:
		return textResult(false, "null"), nil
	}

	encoded, marshalErr := json.Marshal(outcome.value)
	if marshalErr != nil {
		return textResult(true, fmt.Sprintf("error marshaling MCP tool result: %v", marshalErr)), nil
	}
	return textResult(false, string(encoded)), nil
}
// runtimeEventHandler is subscribed to EventBusTopicRuntime. It schedules dispatchEvent on the
// plugin's runtime worker to forward event to the JS runtime. If the worker refuses the task
// (e.g. because the runtime is gone), the failure is logged instead of being silently dropped.
func (p *KernelPlugin) runtimeEventHandler(event any) {
	runErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
		return dispatchEvent(p, rt, event)
	}, nil)
	if runErr != nil {
		logging.LogWarnf("[plugin:%s] failed to schedule runtime event dispatch: %v", p.Name, runErr)
	}
}
// pluginEventHandler is subscribed to EventBusTopicPlugin. For now it only logs the received
// event at debug level.
func (p *KernelPlugin) pluginEventHandler(event any) {
	pluginName := p.Name
	logging.LogDebugf("[plugin:%s] receive event: %#v", pluginName, event)
}
// subscribeEventHandlers subscribes runtimeEventHandler to EventBusTopicRuntime and
// pluginEventHandler to EventBusTopicPlugin.
//
// Subscription errors are returned directly rather than via lo.Must0. lo.Must0 panics with a
// string, so the previous `r.(error)` in recover would itself panic. If the second subscription
// fails, the first one is rolled back (best effort) so no half-subscribed state remains. Any
// unexpected panic is converted into an error.
func (p *KernelPlugin) subscribeEventHandlers() (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic while subscribing plugin event handlers: %v", r)
		}
	}()
	if subErr := p.bus.Subscribe(EventBusTopicRuntime, p.runtimeEventHandler); subErr != nil {
		return fmt.Errorf("subscribe topic [%s]: %w", EventBusTopicRuntime, subErr)
	}
	if subErr := p.bus.Subscribe(EventBusTopicPlugin, p.pluginEventHandler); subErr != nil {
		// Best effort rollback; the subscription error is the one worth reporting.
		_ = p.bus.Unsubscribe(EventBusTopicRuntime, p.runtimeEventHandler)
		return fmt.Errorf("subscribe topic [%s]: %w", EventBusTopicPlugin, subErr)
	}
	return nil
}
// unsubscribeEventHandlers removes the handlers installed by subscribeEventHandlers.
//
// Both topics are always attempted, even if the first fails. Failures are returned as a single
// error that wraps every cause (%w), instead of relying on lo.Must0. lo.Must0 panics with a
// string, so the previous `r.(error)` in recover would itself panic. Any unexpected panic is
// converted into an error.
func (p *KernelPlugin) unsubscribeEventHandlers() (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic while unsubscribing plugin event handlers: %v", r)
		}
	}()
	runtimeErr := p.bus.Unsubscribe(EventBusTopicRuntime, p.runtimeEventHandler)
	pluginErr := p.bus.Unsubscribe(EventBusTopicPlugin, p.pluginEventHandler)
	switch {
	case runtimeErr != nil && pluginErr != nil:
		err = fmt.Errorf("unsubscribe topic [%s]: %w; unsubscribe topic [%s]: %w",
			EventBusTopicRuntime, runtimeErr, EventBusTopicPlugin, pluginErr)
	case runtimeErr != nil:
		err = fmt.Errorf("unsubscribe topic [%s]: %w", EventBusTopicRuntime, runtimeErr)
	case pluginErr != nil:
		err = fmt.Errorf("unsubscribe topic [%s]: %w", EventBusTopicPlugin, pluginErr)
	}
	return err
}
// GetRpcMethodsInfo lists every registered RPC method as a name/descriptions pair.
// The order follows sync.Map iteration and is not stable. It returns nil when nothing is registered.
func (p *KernelPlugin) GetRpcMethodsInfo() (methods []*RpcMethodInfo) {
	collect := func(_, value any) bool {
		method, ok := value.(*RpcMethod)
		if !ok {
			return true
		}
		info := &RpcMethodInfo{Name: method.Name, Descriptions: method.Descriptions}
		methods = append(methods, info)
		return true
	}
	p.rpcMethods.Range(collect)
	return methods
}
// BroadcastNotification sends a JSON-RPC 2.0 notification (no ID) for method/params to every
// tracked WebSocket that is flagged as an RPC connection. It blocks until every asynchronous write
// has completed. Encoding failures and per-connection write failures are logged as warnings.
func (p *KernelPlugin) BroadcastNotification(method string, params util.Optional[any]) {
	data, err := json.Marshal(JsonRpcRequest{
		JsonRpc: JsonRpcVersion,
		Method:  method,
		Params:  params,
	})
	if err != nil {
		logging.LogWarnf("[plugin:%s] broadcast marshal: %s", p.Name, err)
		return
	}

	// Snapshot the RPC connections so no lock is held while writing.
	p.socketsMu.RLock()
	targets := make([]*gws.Conn, 0, len(p.sockets))
	for conn, isRpc := range p.sockets {
		if !isRpc {
			continue
		}
		targets = append(targets, conn)
	}
	p.socketsMu.RUnlock()

	var pending sync.WaitGroup
	for _, target := range targets {
		pending.Add(1)
		// WriteAsync is asynchronous, so every connection gets its own copy of the payload.
		payload := make([]byte, len(data))
		copy(payload, data)
		target.WriteAsync(gws.OpcodeText, payload, func(writeErr error) {
			defer pending.Done()
			if writeErr != nil {
				logging.LogWarnf("[plugin:%s] RPC WebSocket notification write failed: %s", p.Name, writeErr)
			}
		})
	}
	pending.Wait()
}
// dispatchRpcRequests dispatches a batch of JSON-RPC requests concurrently and returns the
// responses in request order. Each entry is handled as follows:
//   - a request that failed parsing or validation (request.Error set) yields its error response;
//   - a nil entry or an entry without a Request yields nil;
//   - a notification is dispatched in the background and yields nil (no response);
//   - any other request is dispatched concurrently, and its response is collected before returning.
//
// The nil check on request.Request happens before IsNotification is called on it.
func (p *KernelPlugin) dispatchRpcRequests(requests []*JsonRpcProcessingRequest) []*JsonRpcProcessingResponse {
	responses := make([]*JsonRpcProcessingResponse, len(requests))
	var wg sync.WaitGroup
	for i, request := range requests {
		switch {
		case request == nil:
			// Nothing to dispatch; leave responses[i] nil.
		case request.Error != nil:
			responses[i] = &JsonRpcProcessingResponse{Error: request.Error}
		case request.Request == nil:
			// Nothing to dispatch; leave responses[i] nil.
		case request.Request.IsNotification():
			// Fire and forget: notifications never produce a response.
			go p.dispatchRpcRequest(request.Request)
		default:
			wg.Add(1)
			go func(index int, req *JsonRpcRequest) {
				defer wg.Done()
				responses[index] = p.dispatchRpcRequest(req)
			}(i, request.Request)
		}
	}
	wg.Wait()
	return responses
}
// dispatchRpcRequest validates a single JSON-RPC request and routes it to the plugin's registered
// JS method.
//
// Notifications (no ID) never produce a response: invalid ones are dropped, and valid ones are
// executed in a background goroutine while nil is returned immediately. For regular requests the
// method is called synchronously. The returned value carries either the result or an error
// response, together with the request ID.
func (p *KernelPlugin) dispatchRpcRequest(request *JsonRpcRequest) *JsonRpcProcessingResponse {
	validationErr := request.Validate()
	notification := request.IsNotification()

	if validationErr != nil {
		if notification {
			return nil
		}
		return &JsonRpcProcessingResponse{
			Error: &JsonRpcErrorResponse{JsonRpc: JsonRpcVersion, Error: validationErr, ID: request.ID},
		}
	}

	if notification {
		go p.callRpcMethod(request.Method, request.Params.Value)
		return nil
	}

	result, callErr := p.callRpcMethod(request.Method, request.Params.Value)
	if callErr != nil {
		return &JsonRpcProcessingResponse{
			Error: &JsonRpcErrorResponse{JsonRpc: JsonRpcVersion, Error: callErr, ID: request.ID},
		}
	}
	return &JsonRpcProcessingResponse{
		Response: &JsonRpcRequestResponse{JsonRpc: JsonRpcVersion, Result: result, ID: request.ID},
	}
}
// callRpcMethod invokes the JS function registered under method on the plugin's runtime worker
// (not via the event bus) and waits for its result. invokeFunction is called with awaiting
// enabled, as in invokeHook, so a returned Promise is presumably awaited.
//
// params is mapped onto JS arguments as follows:
//   - an array is spread into positional arguments;
//   - any other non-null value is passed as the single argument;
//   - null/undefined means no arguments.
//
// It returns the JS result exported to Go, or a JSON-RPC error when:
//   - the plugin is not running;
//   - the method is unknown or malformed;
//   - the worker cannot schedule the call (previously this deadlocked forever on <-done);
//   - the plugin context is cancelled while waiting;
//   - the JS function throws or rejects;
//   - a panic occurs.
func (p *KernelPlugin) callRpcMethod(method string, params any) (rpcResult any, rpcError *JsonRpcError) {
	defer func() {
		if r := recover(); r != nil {
			logging.LogDebugf("[plugin:%s] panic in RPC method [%s]: %v", p.Name, method, r)
			rpcError = &JsonRpcError{
				Code:    JsonRpcErrorCodeInternalError,
				Message: JsonRpcErrorInternalError.Message,
				Data:    fmt.Sprintf("goja panic in RPC method [%s]: %v", method, r),
			}
		}
	}()
	internalError := func(data string) *JsonRpcError {
		return &JsonRpcError{
			Code:    JsonRpcErrorCodeInternalError,
			Message: JsonRpcErrorInternalError.Message,
			Data:    data,
		}
	}

	if p.State() != PluginStateRunning {
		return nil, internalError(fmt.Sprintf("plugin [%s] not running (state: [%s])", p.Name, p.State()))
	}
	value, ok := p.rpcMethods.Load(method)
	if !ok {
		return nil, &JsonRpcError{
			Code:    JsonRpcErrorCodeMethodNotFound,
			Message: JsonRpcErrorMethodNotFound.Message,
			Data:    fmt.Sprintf("method [%s] not found in plugin [%s]", method, p.Name),
		}
	}
	rpcMethod, ok := value.(*RpcMethod)
	if !ok {
		return nil, internalError(fmt.Sprintf("invalid method type for [%s]", method))
	}

	done := make(chan *TaskResult, 1)
	runErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
		rpcParams := []goja.Value{}
		jsParams := rt.ToValue(params)
		switch {
		case isJsArray(rt, jsParams):
			// Spread an array into positional arguments for variadic JS functions.
			rt.ForOf(jsParams, func(cur goja.Value) bool {
				rpcParams = append(rpcParams, cur)
				return true
			})
		case !goja.IsUndefined(jsParams) && !goja.IsNull(jsParams):
			// Any other defined value is passed as a single argument.
			rpcParams = append(rpcParams, jsParams)
		}
		// null/undefined params: call with no arguments.
		invokeFunction(func(_ *goja.Runtime, result *CallResult) {
			done <- result.TaskResult()
		}, rt, true, rpcMethod.Method, rt.GlobalObject(), rpcParams...)
		return
	}, nil)
	if runErr != nil {
		// The task was never scheduled, so nothing will ever be sent on done.
		return nil, internalError(fmt.Sprintf("error running plugin runtime worker for method %q: %v", method, runErr))
	}

	var outcome *TaskResult
	select {
	case outcome = <-done:
	case <-p.context.Done():
		return nil, internalError(fmt.Sprintf("plugin [%s] stopped while invoking method %q", p.Name, method))
	}
	if outcome.err != nil {
		return nil, internalError(fmt.Sprintf("error invoking method %q: %v", method, outcome.err))
	}
	return outcome.value, nil
}
// TrackRpcSocket records conn as an RPC WebSocket connection of this plugin, so that it receives
// BroadcastNotification messages. A nil conn is ignored.
func (p *KernelPlugin) TrackRpcSocket(conn *gws.Conn) {
	if conn != nil {
		p.socketsMu.Lock()
		defer p.socketsMu.Unlock()
		p.sockets[conn] = true
	}
}
// UntrackRpcSocket forgets conn, whatever kind of tracked connection it was. Nil or unknown
// connections are ignored.
func (p *KernelPlugin) UntrackRpcSocket(conn *gws.Conn) {
	if conn != nil {
		p.socketsMu.Lock()
		defer p.socketsMu.Unlock()
		delete(p.sockets, conn)
	}
}
// startStorageWatch runs the storage watcher loop. It is meant to be started in its own goroutine
// and returns when the plugin context is cancelled or the watcher's channels close; the watcher is
// closed on return.
//
// Every create/write/rename/remove event is published to the runtime as an "fs-notify" event,
// with the path made relative to the plugin storage directory. fsnotify.Op is a bitmask, so
// events are matched with a bit test: an event carrying several ops is still forwarded, while
// Chmod-only events are ignored. An event whose relative path cannot be computed is logged and
// skipped; it no longer stops the whole watcher.
func (p *KernelPlugin) startStorageWatch() {
	if p.watcher == nil {
		return
	}
	defer p.watcher.Close()

	const forwardedOps = fsnotify.Create | fsnotify.Write | fsnotify.Rename | fsnotify.Remove
	for {
		select {
		case <-p.context.Done():
			return
		case event, ok := <-p.watcher.Events:
			if !ok {
				return
			}
			if event.Op&forwardedOps == 0 {
				continue
			}
			path, relErr := filepath.Rel(p.storageDir, event.Name)
			if relErr != nil {
				logging.LogErrorf("[plugin:%s] failed to get relative storage path for [%s]: %v", p.Name, event.Name, relErr)
				continue
			}
			p.bus.Publish(EventBusTopicRuntime, createEventMessage("fs-notify", R{
				"operation": event.Op.String(),
				"path":      path,
			}))
		case err, ok := <-p.watcher.Errors:
			if !ok {
				return
			}
			logging.LogErrorf("[plugin:%s] storage watcher error: %s", p.Name, err)
		}
	}
}
// addStorageWatch starts watching path (a storage file or directory) for changes. It fails if the
// plugin has no watcher; otherwise it returns the watcher's own error, if any.
func (p *KernelPlugin) addStorageWatch(path string) error {
	if w := p.watcher; w != nil {
		return w.Add(path)
	}
	return fmt.Errorf("fsnotify watcher not initialized")
}
// removeStorageWatch stops watching path (a storage file or directory). It fails if the plugin
// has no watcher; otherwise it returns the watcher's own error, if any.
func (p *KernelPlugin) removeStorageWatch(path string) error {
	if w := p.watcher; w != nil {
		return w.Remove(path)
	}
	return fmt.Errorf("fsnotify watcher not initialized")
}
// invokeHook calls the lifecycle hook globalThis.siyuan.plugin.lifecycle[name] (e.g. "onload") on
// the plugin's runtime worker, with `this` bound to the lifecycle object. It blocks until the hook
// completes; invokeFunction is called with awaiting enabled, so a returned Promise is awaited.
//
// Nothing is returned. Failures are logged instead; they include:
//   - a missing lifecycle object;
//   - a hook that is not a function (so a plugin without this hook also produces a log entry);
//   - a worker scheduling failure;
//   - a thrown exception or rejected Promise;
//   - a panic.
//
// NOTE(review): the final wait on `done` has no timeout or cancellation, so a hook whose Promise
// never settles blocks the caller (start/stop) indefinitely.
func (p *KernelPlugin) invokeHook(name string) {
	var err error
	// Single exit point for reporting: convert a panic into err, then log whatever err holds.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("panic during lifecycle hook invocation: %v", r)
		}
		if err != nil {
			logging.LogErrorf("[plugin:%s] lifecycle hook [%q] error: %v", p.Name, name, err)
		}
	}()
	// Buffered so that whichever path produces the single result (hook completion, task failure,
	// or scheduling failure) can send without blocking.
	done := make(chan TaskResult, 1)
	runErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
		// Resolve globalThis.siyuan.plugin.lifecycle.
		lifecycle, err := getJsContextValue(rt, []any{"siyuan", "plugin", "lifecycle"})
		if err != nil {
			return
		}
		if lifecycle == nil {
			err = fmt.Errorf("globalThis.siyuan.plugin.lifecycle not found")
			return
		}
		pluginObj := lifecycle.ToObject(rt)
		if pluginObj == nil {
			err = fmt.Errorf("globalThis.siyuan.plugin.lifecycle is not an object")
			return
		}
		hookValue := pluginObj.Get(name)
		hook, ok := goja.AssertFunction(hookValue)
		if !ok {
			err = fmt.Errorf("globalThis.siyuan.plugin.lifecycle.%s not bound to a function", name)
			return
		}
		// The completion callback delivers the hook's result (or error) to the waiting caller.
		invokeFunction(func(_ *goja.Runtime, result *CallResult) {
			done <- *result.TaskResult()
		}, rt, true, hook, lifecycle)
		return
	}, func(_ *goja.Runtime, _ any, err error) {
		// The task returned an error before reaching invokeFunction, so no result was sent yet.
		if err != nil {
			done <- TaskResult{err: err}
		}
	})
	// The task was never scheduled; report the scheduling error instead.
	if runErr != nil {
		done <- TaskResult{err: runErr}
	}
	result := <-done
	if result.err != nil {
		err = result.err
	}
}
// handleHttpRequest dispatches an HTTP request to the plugin's JS handler for scope and returns
// the handler's response.
//
// The handler is looked up and invoked on the runtime worker, with the request converted to JS.
// invokeFunction is called with awaiting enabled, so a returned Promise is presumably awaited.
// The returned JS object is converted to an HttpResponse through a JSON round-trip. Before that,
// response.body.raw.data is pulled out as raw bytes (see extractHttpResponseRawData).
//
// Outcomes:
//   - handler finished: (response, nil), or (nil, err) for lookup, handler and conversion errors,
//     including a handler that returned undefined/null;
//   - client request context done first: (nil, nil) — nobody is left to answer;
//   - plugin context cancelled first (plugin stopped): (nil, err).
func (p *KernelPlugin) handleHttpRequest(c *gin.Context, request *Request, scope AccessScope) (response *HttpResponse, err error) {
	type handleResult FunctionResult[*HttpResponse]
	done := make(chan *handleResult, 1)
	runErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
		handler, handlerObj, getHandlerErr := getRequestHandler(rt, scope, RequestTypeHTTP)
		if getHandlerErr != nil {
			err = getHandlerErr
			return
		}
		jsRequest, convertErr := requestGoToJs(p, rt, request)
		if convertErr != nil {
			err = convertErr
			return
		}
		invokeFunction(func(rt *goja.Runtime, result *CallResult) {
			if result.Error != nil {
				done <- &handleResult{Error: result.Error}
				return
			}
			// Value.ToObject panics on undefined/null (e.g. a handler with no return statement),
			// so reject a missing response before converting it.
			if !isJsValueNotNull(result.Value) {
				done <- &handleResult{Error: fmt.Errorf("handler did not return an object")}
				return
			}
			responseObj := result.Value.ToObject(rt)
			if responseObj == nil {
				done <- &handleResult{Error: fmt.Errorf("handler did not return an object")}
				return
			}
			raw := extractHttpResponseRawData(rt, responseObj)
			// rt.ExportTo(responseObj, &response) panics with a nil pointer dereference on these
			// objects, so convert through a JSON round-trip instead.
			resultJson, marshalErr := responseObj.MarshalJSON()
			if marshalErr != nil {
				done <- &handleResult{Error: marshalErr}
				return
			}
			response := HttpResponse{}
			if unmarshalErr := json.Unmarshal(resultJson, &response); unmarshalErr != nil {
				done <- &handleResult{Error: fmt.Errorf("invalid response format: %v", unmarshalErr)}
				return
			}
			if raw != nil && response.Body != nil && response.Body.Raw != nil {
				response.Body.Raw.Data = *raw
			}
			done <- &handleResult{Value: &response}
		}, rt, true, handler, handlerObj, jsRequest)
		return
	}, func(_ *goja.Runtime, _ any, err error) {
		// The task failed before invokeFunction ran, so no result was sent yet.
		if err != nil {
			done <- &handleResult{Error: err}
		}
	})
	if runErr != nil {
		done <- &handleResult{Error: runErr}
	}
	select {
	case result := <-done:
		if result.Error != nil {
			err = result.Error
		} else {
			response = result.Value
		}
	case <-c.Request.Context().Done():
		// The client went away; there is nobody to send a response to.
	case <-p.context.Done():
		err = fmt.Errorf("plugin [%s] stopped while handling HTTP request", p.Name)
	}
	return
}

// extractHttpResponseRawData pulls response.body.raw.data (string | Buffer | ArrayBuffer) out of a
// JS response object as bytes and sets that property to null in place, so the rest of the object
// can be JSON-encoded. It returns nil, leaving the object untouched, when any step of the path is
// missing or the value cannot be converted.
func extractHttpResponseRawData(rt *goja.Runtime, responseObj *goja.Object) *[]byte {
	bodyValue := responseObj.Get("body")
	if !isJsValueNotNull(bodyValue) {
		return nil
	}
	bodyObj := bodyValue.ToObject(rt)
	if bodyObj == nil {
		return nil
	}
	rawValue := bodyObj.Get("raw")
	if !isJsValueNotNull(rawValue) {
		return nil
	}
	rawObj := rawValue.ToObject(rt)
	if rawObj == nil {
		return nil
	}
	dataValue := rawObj.Get("data")
	if !isJsValueNotNull(dataValue) {
		return nil
	}
	dataBytes, convertErr := jsValueToBytes(rt, dataValue)
	if convertErr != nil {
		return nil
	}
	// Best effort: if nulling fails, the JSON round-trip just carries the original value, and it is
	// overwritten with dataBytes afterwards anyway.
	_ = rawObj.Set("data", goja.Null())
	return &dataBytes
}
func (p *KernelPlugin) handleWebSocketRequest(c *gin.Context, request *Request, scope AccessScope) (err error) {
done := make(chan error, 1)
h := &WsEventHandler{p: p}
upgrader := gws.NewUpgrader(h, &gws.ServerOption{})
socket, upgradeErr := upgrader.Upgrade(c.Writer, c.Request)
if upgradeErr != nil {
err = upgradeErr
return
}
ctx, cancel := context.WithCancel(p.context)
var openOnce sync.Once
var closeOnce sync.Once
doOpen := func() {
go openOnce.Do(func() {
socket.ReadLoop()
cancel()
})
}
doClose := func() {
closeOnce.Do(func() {
socket.NetConn().Close()
cancel()
})
}
defer doClose()
var readyState atomic.Int64
var bufferedAmount atomic.Int64
readyState.Store(int64(WebSocketReadyStateConnecting))
runErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
handler, handlerObj, getHandlerErr := getRequestHandler(rt, scope, RequestTypeWS)
if getHandlerErr != nil {
err = getHandlerErr
return
}
jsRequest, convertErr := requestGoToJs(p, rt, request)
if convertErr != nil {
err = convertErr
return
}
jsRequestObj := jsRequest.ToObject(rt)
if jsRequestObj == nil {
err = fmt.Errorf("failed to convert request value to object")
return
}
port := rt.NewObject()
invokePortHook := func(_ *goja.Runtime, name string, args ...goja.Value) {
hook := port.Get(name)
if fn, ok := goja.AssertFunction(hook); ok {
if _, callErr := fn(port, args...); callErr != nil {
logging.LogErrorf("[plugin:%s] ws server port hook %q: %v", p.Name, name, callErr)
}
}
}
setProtocol := func(rt *goja.Runtime, protocol string) {
port.Set("protocol", rt.ToValue(protocol))
}
setPortReadyState := func(rt *goja.Runtime, state WebSocketState) {
readyState.Store(int64(state))
port.Set("readyState", rt.ToValue(state))
}
updatePortBufferedAmount := func(rt *goja.Runtime, delta int) {
bufferedAmount.Add(int64(delta))
port.Set("bufferedAmount", rt.ToValue(bufferedAmount.Load()))
}
manager := &WsManager{
BufferedAmount: &bufferedAmount,
InvokeHook: invokePortHook,
SetProtocol: setProtocol,
SetReadyState: setPortReadyState,
}
h.BindOnOpen(manager)
h.BindOnClose(manager)
h.BindOnPing(manager)
h.BindOnPong(manager)
h.BindOnMessage(manager)
port_open := rt.ToValue(func(openCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
openPromise, openResolve, openReject := rt.NewPromise()
openRunErr := p.worker.Run(func(rt *goja.Runtime) (_ any, _ error) {
doOpen()
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := openResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.open resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := openReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.open reject: %v", p.Name, rejectErr)
}
}
})
if openRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.open worker run: %v", p.Name, openRunErr)
}
return rt.ToValue(openPromise)
})
port_send := rt.ToValue(func(sendCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
sendPromise, sendResolve, sendReject := rt.NewPromise()
sendRunErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
var messageData []byte
var opcode gws.Opcode
if len(sendCall.Arguments) >= 1 {
data := sendCall.Argument(0)
if arrayBuffer, ok := data.Export().(goja.ArrayBuffer); ok {
opcode = gws.OpcodeBinary
b := arrayBuffer.Bytes()
messageData = make([]byte, len(b))
copy(messageData, b) // ArrayBuffer.Bytes() points into JS engine memory; copy before async send
} else {
opcode = gws.OpcodeText
messageData = []byte(data.String())
}
}
state := WebSocketState(readyState.Load())
if state == WebSocketReadyStateClosing || state == WebSocketReadyStateClosed {
err = fmt.Errorf("WebSocket is not open (state: %d)", state)
return
}
updatePortBufferedAmount(rt, len(messageData))
socket.WriteAsync(opcode, messageData, func(writeErr error) {
p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
if writeErr == nil {
updatePortBufferedAmount(rt, -len(messageData))
} else {
err = writeErr
}
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := sendResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := sendReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send reject: %v", p.Name, rejectErr)
}
}
})
})
return
}, func(rt *goja.Runtime, _ any, err error) {
if !lo.IsNil(err) {
if rejectErr := sendReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send reject: %v", p.Name, rejectErr)
}
}
})
if sendRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.send worker run: %v", p.Name, sendRunErr)
}
return rt.ToValue(sendPromise)
})
port_ping := rt.ToValue(func(pingCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
pingPromise, pingResolve, pingReject := rt.NewPromise()
pingRunErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
var pingData string
if len(pingCall.Arguments) > 0 && !goja.IsUndefined(pingCall.Argument(0)) {
pingData = pingCall.Argument(0).String()
}
err = socket.WritePing([]byte(pingData))
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := pingResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.ping resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := pingReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.ping reject: %v", p.Name, rejectErr)
}
}
})
if pingRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.ping worker run: %v", p.Name, pingRunErr)
}
return rt.ToValue(pingPromise)
})
port_pong := rt.ToValue(func(pongCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
pongPromise, pongResolve, pongReject := rt.NewPromise()
pongRunErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
var pongData string
if len(pongCall.Arguments) > 0 && !goja.IsUndefined(pongCall.Argument(0)) {
pongData = pongCall.Argument(0).String()
}
err = socket.WritePong([]byte(pongData))
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := pongResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.pong resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := pongReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.pong reject: %v", p.Name, rejectErr)
}
}
})
if pongRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.pong worker run: %v", p.Name, pongRunErr)
}
return rt.ToValue(pongPromise)
})
port_close := rt.ToValue(func(closeCall goja.FunctionCall, rt *goja.Runtime) goja.Value {
closePromise, closeResolve, closeReject := rt.NewPromise()
closeRunErr := p.worker.Run(func(rt *goja.Runtime) (result any, err error) {
code := uint16(1000)
var reason []byte
if isJsValueNotNull(closeCall.Argument(0)) {
code = uint16(closeCall.Argument(0).ToInteger())
}
if isJsValueNotNull(closeCall.Argument(1)) {
reason = []byte(closeCall.Argument(1).String())
}
setPortReadyState(rt, WebSocketReadyStateClosing)
err = socket.WriteClose(code, reason)
return
}, func(rt *goja.Runtime, result any, err error) {
if lo.IsNil(err) {
if resolveErr := closeResolve(result); resolveErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.close resolve: %v", p.Name, resolveErr)
}
} else {
if rejectErr := closeReject(rt.NewGoError(err)); rejectErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.close reject: %v", p.Name, rejectErr)
}
}
})
if closeRunErr != nil {
logging.LogErrorf("[plugin:%s] ws server port.close worker run: %v", p.Name, closeRunErr)
}
return rt.ToValue(closePromise)
})
lo.Must0(port.Set("binaryType", rt.ToValue("arraybuffer")))
lo.Must0(port.Set("bufferedAmount", rt.ToValue(bufferedAmount.Load())))
lo.Must0(port.Set("protocol", rt.ToValue("")))
lo.Must0(port.Set("readyState", rt.ToValue(readyState.Load())))
lo.Must0(port.Set("onopen", goja.Null()))
lo.Must0(port.Set("onmessage", goja.Null()))
lo.Must0(port.Set("onping", goja.Null()))
lo.Must0(port.Set("onpong", goja.Null()))
lo.Must0(port.Set("onclose", goja.Null()))
lo.Must0(port.Set("onerror", goja.Null()))
lo.Must0(port.Set("open", port_open))
lo.Must0(port.Set("send", port_send))
lo.Must0(port.Set("ping", port_ping))
lo.Must0(port.Set("pong", port_pong))
lo.Must0(port.Set("close", port_close))
lo.Must0(ObjectSeal(rt, port))
lo.Must0(jsRequestObj.Set("port", port))
invokeFunction(func(_ *goja.Runtime, result *CallResult) {
if result.Error != nil {
// If the handler throws an error, close the connection and return the error.
select {
case done <- result.Error:
case <-ctx.Done():
}
return
}
// Auto-open if the handler did not call port.open() explicitly.
doOpen()
}, rt, true, handler, handlerObj, jsRequest)
return
}, func(_ *goja.Runtime, _ any, err error) {
if err != nil {
// If there is an error during the worker run (e.g. runtime panic), close the connection and return the error.
select {
case done <- err:
case <-ctx.Done():
}
}
})
if runErr != nil {
select {
case done <- runErr:
case <-ctx.Done():
}
}
select {
case err = <-done:
case <-ctx.Done():
}
return
}
// handleServerSentEventRequest dispatches an SSE request to the plugin's JS handler
// and streams the events the handler sends until the stream ends.
//
// Before the handler is invoked, a sealed `port` object is attached to the JS request:
//   - port.send(event) queues an event. event.data is required; event.event and
//     event.id (strings) and event.retry (non-negative number) are optional.
//     Events sent after the stream has ended are dropped.
//   - port.close() ends the stream after every previously sent event is written.
//   - port.onopen / port.onclose, when set to functions, are called with an object
//     whose type is "open" / "close".
//
// onopen is scheduled after the handler returns successfully; onclose is scheduled
// once the stream context is cancelled, and only if the handler succeeded. The
// stream ends when port.close() is processed, the client disconnects, the plugin
// context is cancelled, or the handler (or the worker running it) reports an
// error, which is returned.
func (p *KernelPlugin) handleServerSentEventRequest(c *gin.Context, request *Request, scope AccessScope) (err error) {
	ctx, cancel := context.WithCancel(p.context)
	var closeOnce sync.Once
	doClose := func() {
		closeOnce.Do(cancel)
	}
	defer doClose()

	// Outbound queue: the JS worker produces, the loop at the bottom consumes.
	// A nil item is the sentinel queued by port.close(); since the queue is FIFO,
	// every event sent before close() is written before the stream ends.
	events := chanx.NewUnboundedChan[*sse.Event](ctx, 16)
	// enqueue hands an item to the writer loop. Once ctx is cancelled the queue
	// is no longer drained, so the item is dropped instead of blocking the JS
	// worker indefinitely.
	enqueue := func(e *sse.Event) {
		select {
		case events.In <- e:
		case <-ctx.Done():
		}
	}

	done := make(chan error, 1) // receives the handler / worker result
	reportDone := func(e error) {
		select {
		case done <- e:
		case <-ctx.Done():
		}
	}

	runErr := p.worker.Run(func(rt *goja.Runtime) (_ any, err error) {
		handler, handlerObj, getHandlerErr := getRequestHandler(rt, scope, RequestTypeSSE)
		if getHandlerErr != nil {
			err = getHandlerErr
			return
		}
		jsRequest, convertErr := requestGoToJs(p, rt, request)
		if convertErr != nil {
			err = convertErr
			return
		}
		jsRequestObj := jsRequest.ToObject(rt)
		if jsRequestObj == nil {
			err = fmt.Errorf("failed to convert request value to object")
			return
		}

		port := rt.NewObject()
		// dispatchPortEvent calls port[hook]({type: eventType}) if the hook is a
		// function; errors thrown by the hook are logged, not propagated.
		dispatchPortEvent := func(rt *goja.Runtime, hook, eventType string) {
			fn, ok := goja.AssertFunction(port.Get(hook))
			if !ok {
				return
			}
			event := rt.NewObject()
			_ = event.Set("type", rt.ToValue(eventType))
			if _, callErr := fn(port, event); callErr != nil {
				logging.LogErrorf("[plugin:%s] sse server port hook %q: %v", p.Name, hook, callErr)
			}
		}

		port_send := rt.ToValue(func(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
			if eventJs := call.Argument(0); isJsValueNotNull(eventJs) {
				if eventObj := eventJs.ToObject(rt); eventObj != nil {
					e := sse.Event{}
					if data := eventObj.Get("data"); isJsValueNotUndefined(data) {
						e.Data = data.Export()
					} else {
						panic(rt.NewGoError(fmt.Errorf("event.data is required")))
					}
					if event := eventObj.Get("event"); goja.IsString(event) {
						e.Event = event.String()
					}
					if id := eventObj.Get("id"); goja.IsString(id) {
						e.Id = id.String()
					}
					if retry := eventObj.Get("retry"); goja.IsNumber(retry) {
						// A negative value would wrap around when converted to uint.
						if ms := retry.ToInteger(); ms >= 0 {
							e.Retry = uint(ms)
						}
					}
					enqueue(&e)
					return goja.Undefined()
				}
			}
			panic(rt.NewGoError(fmt.Errorf("invalid event object")))
		})
		port_close := rt.ToValue(func(_ goja.FunctionCall, _ *goja.Runtime) goja.Value {
			// Queue the close sentinel instead of cancelling ctx directly, so that
			// events sent before close() are still written to the client.
			enqueue(nil)
			return goja.Undefined()
		})
		lo.Must0(port.Set("onopen", goja.Null()))
		lo.Must0(port.Set("onclose", goja.Null()))
		lo.Must0(port.Set("send", port_send))
		lo.Must0(port.Set("close", port_close))
		lo.Must0(ObjectSeal(rt, port))
		lo.Must0(jsRequestObj.Set("port", port))

		invokeFunction(func(_ *goja.Runtime, result *CallResult) {
			if result.Error != nil {
				reportDone(result.Error)
				return
			}
			// Schedule onopen; the worker's result for it is reported through done
			// (nil keeps the stream open, an error ends it).
			if openRunErr := p.worker.Run(func(rt *goja.Runtime) (_ any, _ error) {
				dispatchPortEvent(rt, "onopen", "open")
				return
			}, func(_ *goja.Runtime, _ any, err error) {
				reportDone(err)
			}); openRunErr != nil {
				reportDone(openRunErr)
			}
			// The onclose watcher is started only when the handler succeeded, and
			// after onopen has been scheduled. It exits once ctx is cancelled,
			// which the deferred doClose guarantees.
			go func() {
				<-ctx.Done()
				// Best effort: if the worker cannot run the hook, it is skipped.
				_ = p.worker.Run(func(rt *goja.Runtime) (_ any, _ error) {
					dispatchPortEvent(rt, "onclose", "close")
					return
				}, nil)
			}()
		}, rt, true, handler, handlerObj, jsRequest)
		return
	}, func(_ *goja.Runtime, _ any, err error) {
		if err != nil {
			reportDone(err)
		}
	})
	if runErr != nil {
		return runErr
	}

	for {
		select {
		case e, ok := <-events.Out:
			// Out is closed once ctx is cancelled (receiving from it would yield a
			// zero event); nil is the port.close() sentinel.
			if !ok || e == nil {
				return
			}
			c.Render(-1, *e)
			c.Writer.Flush()
		case <-ctx.Done():
			return
		case <-c.Request.Context().Done():
			return
		case handlerErr := <-done:
			if handlerErr != nil {
				err = handlerErr
				return
			}
			// Handler and onopen completed; keep streaming until port.close(),
			// client disconnect, or plugin shutdown.
		}
	}
}