From e4148d331d7e54ffb7e8d473bdb6f78f67f4e693 Mon Sep 17 00:00:00 2001 From: Adam <13007539+MrgSub@users.noreply.github.com> Date: Sat, 19 Jul 2025 19:07:46 -0700 Subject: [PATCH] refactor durable objects (#1764) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Refactor Agent Architecture with Driver/Agent Split ## Description This PR refactors the agent architecture by splitting the ZeroAgent class into two separate classes: ZeroDriver and ZeroAgent. The ZeroDriver handles mail operations and database interactions, while ZeroAgent focuses on chat functionality. This separation of concerns improves code organization and maintainability. Key changes: - Created a new ZeroDriver class to handle mail operations and database interactions - Modified ZeroAgent to focus on chat functionality - Updated RPC target from AgentRpcDO to DriverRpcDO - Fixed variable declarations from `let` to `const` in mail-display.tsx - Updated environment configuration in wrangler.jsonc to include the new ZeroDriver class ## Type of Change - [x] ⚡ Performance improvement - [x] 🎨 UI/UX improvement ## Areas Affected - [x] Email Integration (Gmail, IMAP, etc.) - [x] Data Storage/Management - [x] API Endpoints - [x] Development Workflow ## Testing Done - [x] Manual testing performed ## Security Considerations - [x] No sensitive data is exposed - [x] Authentication checks are in place ## Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in complex areas - [x] My changes generate no new warnings ## Additional Notes This architectural change improves separation of concerns and should make the codebase more maintainable. The ZeroDriver handles all mail-related operations while ZeroAgent focuses on chat functionality, creating a cleaner division of responsibilities. --- apps/mail/components/mail/mail-display.tsx | 4 +- apps/server/src/lib/server-utils.ts | 4 +- apps/server/src/main.ts | 4 +- apps/server/src/pipelines.effect.ts | 2 +- apps/server/src/routes/agent/index.ts | 891 +++++++++++---------- apps/server/src/routes/agent/mcp.ts | 2 +- apps/server/src/routes/agent/rpc.ts | 11 +- apps/server/src/routes/agent/tools.ts | 77 +- apps/server/wrangler.jsonc | 24 + 9 files changed, 527 insertions(+), 492 deletions(-) diff --git a/apps/mail/components/mail/mail-display.tsx b/apps/mail/components/mail/mail-display.tsx index df2413da1..77c735dee 100644 --- a/apps/mail/components/mail/mail-display.tsx +++ b/apps/mail/components/mail/mail-display.tsx @@ -386,7 +386,7 @@ const downloadAttachment = async (attachment: { attachmentId: string; }) => { try { - let attachmentData = attachment.body; + const attachmentData = attachment.body; if (!attachmentData) { throw new Error('Attachment data not found'); @@ -475,7 +475,7 @@ const openAttachment = async (attachment: { attachmentId: string; }) => { try { - let attachmentData = attachment.body; + const attachmentData = attachment.body; if (!attachmentData) { throw new Error('Attachment data not found'); diff --git a/apps/server/src/lib/server-utils.ts b/apps/server/src/lib/server-utils.ts index 61d6a2dfd..fe723c8f4 100644 --- a/apps/server/src/lib/server-utils.ts +++ b/apps/server/src/lib/server-utils.ts @@ -11,9 +11,9 @@ export const getZeroDB = (userId: string) => { }; export const getZeroAgent = async (connectionId: string) => { - const stub = env.ZERO_AGENT.get(env.ZERO_AGENT.idFromName(connectionId)); + const stub = env.ZERO_DRIVER.get(env.ZERO_DRIVER.idFromName(connectionId)); const rpcTarget = await stub.setMetaData(connectionId); - await rpcTarget.setupAuth(connectionId); + await rpcTarget.setupAuth(); return rpcTarget; }; diff --git a/apps/server/src/main.ts b/apps/server/src/main.ts index 697f6cf86..a1f68224d 100644 --- a/apps/server/src/main.ts +++ b/apps/server/src/main.ts @@ -25,13 +25,13 @@ import { defaultUserSettings } from './lib/schemas'; import { createLocalJWKSet, jwtVerify } from 'jose'; import { routePartykitRequest } from 'partyserver'; +import { ZeroAgent, ZeroDriver } from './routes/agent'; import { enableBrainFunction } from './lib/brain'; import { trpcServer } from '@hono/trpc-server'; import { agentsMiddleware } from 'hono-agents'; import { ZeroMCP } from './routes/agent/mcp'; import { publicRouter } from './routes/auth'; import { autumnApi } from './routes/autumn'; -import { ZeroAgent } from './routes/agent'; import type { HonoContext } from './ctx'; import { createDb, type DB } from './db'; import { createAuth } from './lib/auth'; @@ -789,4 +789,4 @@ export default class extends WorkerEntrypoint { } } -export { ZeroAgent, ZeroMCP, ZeroDB }; +export { ZeroAgent, ZeroMCP, ZeroDB, ZeroDriver }; diff --git a/apps/server/src/pipelines.effect.ts b/apps/server/src/pipelines.effect.ts index 78208bfa6..0d32313b4 100644 --- a/apps/server/src/pipelines.effect.ts +++ b/apps/server/src/pipelines.effect.ts @@ -25,9 +25,9 @@ import { import { defaultLabels, EPrompts, EProviders, type ParsedMessage, type Sender } from './types'; import { getZeroAgent } from './lib/server-utils'; import { type gmail_v1 } from '@googleapis/gmail'; -import { connection, summary } from './db/schema'; import { getPromptName } from './pipelines'; import { env } from 'cloudflare:workers'; +import { connection } from './db/schema'; import { Effect, Console } from 'effect'; import * as cheerio from 'cheerio'; import { eq } from 'drizzle-orm'; diff --git a/apps/server/src/routes/agent/index.ts b/apps/server/src/routes/agent/index.ts index 3d7bfcf80..08058036d 100644 --- a/apps/server/src/routes/agent/index.ts +++ b/apps/server/src/routes/agent/index.ts @@ -31,7 +31,6 @@ import { EPrompts, type IOutgoingMessage, type ParsedMessage } from '../../types import type { MailManager, IGetThreadResponse } from '../../lib/driver/types'; import { connectionToDriver } from '../../lib/server-utils'; import type { CreateDraftData } from '../../lib/schemas'; -import type { Connection, WSMessage } from 'partyserver'; import { withRetry } from '../../lib/gmail-rate-limit'; import { getPrompt } from '../../pipelines.effect'; import { AIChatAgent } from 'agents/ai-chat-agent'; @@ -40,11 +39,13 @@ import { AiChatPrompt } from '../../lib/prompts'; import { getPromptName } from '../../pipelines'; import { anthropic } from '@ai-sdk/anthropic'; import { connection } from '../../db/schema'; +import type { WSMessage } from 'partyserver'; import { tools as authTools } from './tools'; import { processToolCalls } from './utils'; import { env } from 'cloudflare:workers'; +import type { Connection } from 'agents'; import { createDb } from '../../db'; -import { AgentRpcDO } from './rpc'; +import { DriverRpcDO } from './rpc'; import { eq } from 'drizzle-orm'; import { Effect } from 'effect'; const decoder = new TextDecoder(); @@ -53,16 +54,20 @@ const shouldDropTables = env.DROP_AGENT_TABLES === 'true'; const maxCount = parseInt(env.THREAD_SYNC_MAX_COUNT || '10', 10); const shouldLoop = env.THREAD_SYNC_LOOP !== 'false'; -export class ZeroAgent extends AIChatAgent { - private chatMessageAbortControllers: Map = new Map(); +export const getZeroDriver = (connectionId: string) => + env.ZERO_DRIVER.get(env.ZERO_DRIVER.idFromName(connectionId)); + +export class ZeroDriver extends AIChatAgent { private foldersInSync: Map = new Map(); private syncThreadsInProgress: Map = new Map(); - private currentFolder: string | null = 'inbox'; - driver: MailManager | null = null; + private driver: MailManager | null = null; + private agent: DurableObjectStub = env.ZERO_AGENT.get( + env.ZERO_AGENT.idFromName(this.name), + ); constructor(ctx: DurableObjectState, env: Env) { super(ctx, env); if (shouldDropTables) this.dropTables(); - this.sql` + void this.sql` CREATE TABLE IF NOT EXISTS threads ( id TEXT PRIMARY KEY, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -78,310 +83,94 @@ export class ZeroAgent extends AIChatAgent { `; } - async dropTables() { - return this.sql` - DROP TABLE IF EXISTS threads;`; - } - async setMetaData(connectionId: string) { await this.setName(connectionId); - return new AgentRpcDO(this, connectionId); + return new DriverRpcDO(this, connectionId); } - async registerZeroMCP() { - await this.mcp.connect(env.VITE_PUBLIC_BACKEND_URL + '/sse', { - transport: { - authProvider: new DurableObjectOAuthClientProvider( - this.ctx.storage, - 'zero-mcp', - env.VITE_PUBLIC_BACKEND_URL, - ), - }, - }); + async markAsRead(threadIds: string[]) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.markAsRead(threadIds); } - onStart(): void | Promise { - // this.registerZeroMCP(); + async markAsUnread(threadIds: string[]) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.markAsUnread(threadIds); } - private getDataStreamResponse( - onFinish: StreamTextOnFinishCallback<{}>, - _?: { - abortSignal: AbortSignal | undefined; - }, - ) { - const dataStreamResponse = createDataStreamResponse({ - execute: async (dataStream) => { - const connectionId = this.name; - if (connectionId === 'general') return; - if (!connectionId || !this.driver) { - console.log('Unauthorized no driver or connectionId [1]', connectionId, this.driver); - await this.setupAuth(connectionId); - if (!connectionId || !this.driver) { - console.log('Unauthorized no driver or connectionId', connectionId, this.driver); - throw new Error('Unauthorized no driver or connectionId [2]'); - } - } - const orchestrator = new ToolOrchestrator(dataStream, connectionId); - // const mcpTools = await this.mcp.unstable_getAITools(); - - const rawTools = { - ...(await authTools(this, connectionId)), - }; - const tools = orchestrator.processTools({}); - const processedMessages = await processToolCalls( - { - messages: this.messages, - dataStream, - tools, - }, - {}, - ); - - const result = streamText({ - model: anthropic(env.OPENAI_MODEL || 'claude-3-5-haiku-latest'), - maxSteps: 10, - messages: processedMessages, - tools: rawTools, - onFinish, - onError: (error) => { - console.error('Error in streamText', error); - }, - system: await getPrompt(getPromptName(connectionId, EPrompts.Chat), AiChatPrompt('')), - }); - - result.mergeIntoDataStream(dataStream); - }, - }); - - return dataStreamResponse; + async normalizeIds(ids: string[]) { + if (!this.driver) { + throw new Error('No driver available'); + } + return this.driver.normalizeIds(ids); } - public async setupAuth(connectionId: string) { - if (connectionId === 'general') return; + async sendDraft(id: string, data: IOutgoingMessage) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.sendDraft(id, data); + } + + async create(data: IOutgoingMessage) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.create(data); + } + + async delete(id: string) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.delete(id); + } + + async deleteAllSpam() { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.deleteAllSpam(); + } + + async getEmailAliases() { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.getEmailAliases(); + } + + async getMessageAttachments(messageId: string) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.driver.getMessageAttachments(messageId); + } + + async onConnect() { + await this.setupAuth(); + } + + public async setupAuth() { + if (this.name === 'general') return; if (!this.driver) { const { db, conn } = createDb(env.HYPERDRIVE.connectionString); const _connection = await db.query.connection.findFirst({ - where: eq(connection.id, connectionId), + where: eq(connection.id, this.name), }); if (_connection) this.driver = connectionToDriver(_connection); this.ctx.waitUntil(conn.end()); this.ctx.waitUntil(this.syncThreads('inbox')); - this.ctx.waitUntil(this.syncThreads('sent')); - this.ctx.waitUntil(this.syncThreads('spam')); - this.ctx.waitUntil(this.syncThreads('archive')); + // this.ctx.waitUntil(this.syncThreads('sent')); + // this.ctx.waitUntil(this.syncThreads('spam')); + // this.ctx.waitUntil(this.syncThreads('archive')); } } - private async tryCatchChat(fn: () => T | Promise) { - try { - return await fn(); - } catch (e) { - throw this.onError(e); - } - } - - private getAbortSignal(id: string): AbortSignal | undefined { - // Defensive check, since we're coercing message types at the moment - if (typeof id !== 'string') { - return undefined; - } - - if (!this.chatMessageAbortControllers.has(id)) { - this.chatMessageAbortControllers.set(id, new AbortController()); - } - - return this.chatMessageAbortControllers.get(id)?.signal; - } - - /** - * Remove an abort controller from the cache of pending message responses - */ - private removeAbortController(id: string) { - this.chatMessageAbortControllers.delete(id); - } - - private broadcastChatMessage(message: OutgoingMessage, exclude?: string[]) { - this.broadcast(JSON.stringify(message), exclude); - } - - private cancelChatRequest(id: string) { - if (this.chatMessageAbortControllers.has(id)) { - const abortController = this.chatMessageAbortControllers.get(id); - abortController?.abort(); - } - } - - async onMessage(connection: Connection, message: WSMessage) { - if (typeof message === 'string') { - let data: IncomingMessage; - try { - data = JSON.parse(message) as IncomingMessage; - } catch (error) { - console.warn(error); - // silently ignore invalid messages for now - // TODO: log errors with log levels - return; - } - switch (data.type) { - case IncomingMessageType.UseChatRequest: { - if (data.init.method !== 'POST') break; - - const { body } = data.init; - - const { messages } = JSON.parse(body as string); - this.broadcastChatMessage( - { - type: OutgoingMessageType.ChatMessages, - messages, - }, - [connection.id], - ); - await this.persistMessages(messages, [connection.id]); - - const chatMessageId = data.id; - const abortSignal = this.getAbortSignal(chatMessageId); - - return this.tryCatchChat(async () => { - const response = await this.onChatMessage( - async ({ response }) => { - const finalMessages = appendResponseMessages({ - messages, - responseMessages: response.messages, - }); - - await this.persistMessages(finalMessages, [connection.id]); - this.removeAbortController(chatMessageId); - }, - abortSignal ? { abortSignal } : undefined, - ); - - if (response) { - await this.reply(data.id, response); - } else { - console.warn( - `[AIChatAgent] onChatMessage returned no response for chatMessageId: ${chatMessageId}`, - ); - this.broadcastChatMessage( - { - id: data.id, - type: OutgoingMessageType.UseChatResponse, - body: 'No response was generated by the agent.', - done: true, - }, - [connection.id], - ); - } - }); - } - case IncomingMessageType.ChatClear: { - this.destroyAbortControllers(); - this.sql`delete from cf_ai_chat_agent_messages`; - this.messages = []; - this.broadcastChatMessage( - { - type: OutgoingMessageType.ChatClear, - }, - [connection.id], - ); - break; - } - case IncomingMessageType.ChatMessages: { - await this.persistMessages(data.messages, [connection.id]); - break; - } - case IncomingMessageType.ChatRequestCancel: { - this.cancelChatRequest(data.id); - break; - } - // case IncomingMessageType.Mail_List: { - // const result = await this.getThreadsFromDB({ - // labelIds: data.labelIds, - // folder: data.folder, - // q: data.query, - // max: data.maxResults, - // cursor: data.pageToken, - // }); - // this.currentFolder = data.folder; - // connection.send( - // JSON.stringify({ - // type: OutgoingMessageType.Mail_List, - // result, - // }), - // ); - // break; - // } - // case IncomingMessageType.Mail_Get: { - // const result = await this.getThreadFromDB(data.threadId); - // connection.send( - // JSON.stringify({ - // type: OutgoingMessageType.Mail_Get, - // result, - // threadId: data.threadId, - // }), - // ); - // break; - // } - } - } - } - - private async reply(id: string, response: Response) { - // now take chunks out from dataStreamResponse and send them to the client - return this.tryCatchChat(async () => { - for await (const chunk of response.body!) { - const body = decoder.decode(chunk); - - this.broadcastChatMessage({ - id, - type: OutgoingMessageType.UseChatResponse, - body, - done: false, - }); - } - - this.broadcastChatMessage({ - id, - type: OutgoingMessageType.UseChatResponse, - body: '', - done: true, - }); - }); - } - - async onConnect() { - await this.setupAuth(this.name); - } - - private destroyAbortControllers() { - for (const controller of this.chatMessageAbortControllers.values()) { - controller?.abort(); - } - this.chatMessageAbortControllers.clear(); - } - - async onChatMessage( - onFinish: StreamTextOnFinishCallback<{}>, - options?: { - abortSignal: AbortSignal | undefined; - }, - ) { - return this.getDataStreamResponse(onFinish, options); - } - - async listThreads(params: { - folder: string; - query?: string; - maxResults?: number; - labelIds?: string[]; - pageToken?: string; - }) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.getThreadsFromDB(params); - } - async rawListThreads(params: { folder: string; query?: string; @@ -532,98 +321,55 @@ export class ZeroAgent extends AIChatAgent { return await this.driver.count(); } - async list(params: { - folder: string; - query?: string; - maxResults?: number; - labelIds?: string[]; - pageToken?: string; - }) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.getThreadsFromDB(params); + private async listWithRetry(params: Parameters[0]) { + if (!this.driver) throw new Error('No driver available'); + + return Effect.runPromise(withRetry(Effect.tryPromise(() => this.driver!.list(params)))); } - async markAsRead(threadIds: string[]) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.markAsRead(threadIds); + private async getWithRetry(threadId: string): Promise { + if (!this.driver) throw new Error('No driver available'); + + return Effect.runPromise(withRetry(Effect.tryPromise(() => this.driver!.get(threadId)))); } - async markAsUnread(threadIds: string[]) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.markAsUnread(threadIds); + private getThreadKey(threadId: string) { + return `${this.name}/${threadId}.json`; } - async normalizeIds(ids: string[]) { - if (!this.driver) { - throw new Error('No driver available'); + async *streamThreads(folder: string) { + let pageToken: string | null = null; + let hasMore = true; + + while (hasMore) { + // Rate limiting delay + await new Promise((resolve) => setTimeout(resolve, 2000)); + + const result = await this.listWithRetry({ + folder, + maxResults: maxCount, // Smaller batches for streaming + pageToken: pageToken || undefined, + }); + + // Stream each thread individually + for (const thread of result.threads) { + yield thread; + } + + pageToken = result.nextPageToken; + hasMore = pageToken !== null && shouldLoop; } - return this.driver.normalizeIds(ids); } - async get(id: string) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.getThreadFromDB(id); - } - - async sendDraft(id: string, data: IOutgoingMessage) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.sendDraft(id, data); - } - - async create(data: IOutgoingMessage) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.create(data); - } - - async delete(id: string) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.delete(id); - } - - async deleteAllSpam() { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.deleteAllSpam(); - } - - async getEmailAliases() { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.getEmailAliases(); - } - - async getMessageAttachments(messageId: string) { - if (!this.driver) { - throw new Error('No driver available'); - } - return await this.driver.getMessageAttachments(messageId); - } - - async getThreadCount() { - const count = this.sql`SELECT COUNT(*) FROM threads`; - return count[0]['COUNT(*)'] as number; + async dropTables() { + return this.sql` + DROP TABLE IF EXISTS threads;`; } async syncThread({ threadId }: { threadId: string }) { if (this.name === 'general') return; if (!this.driver) { - await this.setupAuth(this.name); + await this.setupAuth(); } if (!this.driver) { @@ -652,14 +398,14 @@ export class ZeroAgent extends AIChatAgent { }, }); - this.sql` + void this.sql` INSERT OR REPLACE INTO threads ( - id, - thread_id, - provider_id, - latest_sender, - latest_received_on, - latest_subject, + id, + thread_id, + provider_id, + latest_sender, + latest_received_on, + latest_subject, latest_label_ids, updated_at ) VALUES ( @@ -673,12 +419,10 @@ export class ZeroAgent extends AIChatAgent { CURRENT_TIMESTAMP ) `; - if (this.currentFolder === 'inbox') { - this.broadcastChatMessage({ - type: OutgoingMessageType.Mail_Get, - threadId, - }); - } + this.agent.broadcastChatMessage({ + type: OutgoingMessageType.Mail_Get, + threadId, + }); this.syncThreadsInProgress.delete(threadId); console.log('Server: syncThread result', { threadId, @@ -697,47 +441,9 @@ export class ZeroAgent extends AIChatAgent { } } - getThreadKey(threadId: string) { - return `${this.name}/${threadId}.json`; - } - - private async listWithRetry(params: Parameters[0]) { - if (!this.driver) throw new Error('No driver available'); - - return Effect.runPromise(withRetry(Effect.tryPromise(() => this.driver!.list(params)))); - } - - private async getWithRetry(threadId: string): Promise { - if (!this.driver) throw new Error('No driver available'); - - return Effect.runPromise(withRetry(Effect.tryPromise(() => this.driver!.get(threadId)))); - } - - async *streamThreads(folder: string) { - let pageToken: string | null = null; - let hasMore = true; - let _pageCount = 0; - - while (hasMore) { - _pageCount++; - - // Rate limiting delay - await new Promise((resolve) => setTimeout(resolve, 2000)); - - const result = await this.listWithRetry({ - folder, - maxResults: maxCount, // Smaller batches for streaming - pageToken: pageToken || undefined, - }); - - // Stream each thread individually - for (const thread of result.threads) { - yield thread; - } - - pageToken = result.nextPageToken; - hasMore = pageToken !== null && shouldLoop; - } + async getThreadCount() { + const count = this.sql`SELECT COUNT(*) FROM threads`; + return count[0]['COUNT(*)'] as number; } async syncThreads(folder: string) { @@ -773,7 +479,7 @@ export class ZeroAgent extends AIChatAgent { } // // Broadcast progress after each thread - // this.broadcastChatMessage({ + // this.agent.broadcastChatMessage({ // type: OutgoingMessageType.Mail_List, // folder, // }); @@ -786,7 +492,7 @@ export class ZeroAgent extends AIChatAgent { } finally { console.log('Setting isSyncing to false'); this.foldersInSync.delete(folder); - this.broadcastChatMessage({ + this.agent.broadcastChatMessage({ type: OutgoingMessageType.Mail_List, folder, }); @@ -923,7 +629,7 @@ export class ZeroAgent extends AIChatAgent { if (whereConditions.length === 0) { // No conditions - result = await this.sql` + result = this.sql` SELECT id, latest_received_on FROM threads ORDER BY latest_received_on DESC @@ -934,7 +640,7 @@ export class ZeroAgent extends AIChatAgent { const condition = whereConditions[0]; if (condition.includes('latest_received_on <')) { const cursorValue = pageToken!; - result = await this.sql` + result = this.sql` SELECT id, latest_received_on FROM threads WHERE latest_received_on < ${cursorValue} @@ -944,7 +650,7 @@ export class ZeroAgent extends AIChatAgent { } else if (folder) { // Folder condition const folderLabel = folder.toUpperCase(); - result = await this.sql` + result = this.sql` SELECT id, latest_received_on FROM threads WHERE EXISTS ( @@ -956,7 +662,7 @@ export class ZeroAgent extends AIChatAgent { } else { // Single label condition const labelId = labelIds[0]; - result = await this.sql` + result = this.sql` SELECT id, latest_received_on FROM threads WHERE EXISTS ( @@ -971,7 +677,7 @@ export class ZeroAgent extends AIChatAgent { if (folder && labelIds.length === 0 && pageToken) { // Folder + cursor const folderLabel = folder.toUpperCase(); - result = await this.sql` + result = this.sql` SELECT id, latest_received_on FROM threads WHERE EXISTS ( @@ -983,7 +689,7 @@ export class ZeroAgent extends AIChatAgent { } else if (labelIds.length === 1 && pageToken && !folder) { // Single label + cursor const labelId = labelIds[0]; - result = await this.sql` + result = this.sql` SELECT id, latest_received_on FROM threads WHERE EXISTS ( @@ -995,7 +701,7 @@ export class ZeroAgent extends AIChatAgent { } else { // For now, fallback to just cursor if complex combinations const cursorValue = pageToken || ''; - result = await this.sql` + result = this.sql` SELECT id, latest_received_on FROM threads WHERE latest_received_on < ${cursorValue} @@ -1005,20 +711,26 @@ export class ZeroAgent extends AIChatAgent { } } - const threads = result.map((row: any) => ({ - id: row.id, - historyId: null, - })); + if (result?.length) { + const threads = result.map((row) => ({ + id: row.id, + historyId: null, + })); - // Use latest_received_on for pagination cursor - const nextPageToken = - threads.length === maxResults && result.length > 0 - ? result[result.length - 1].latest_received_on - : null; + // Use latest_received_on for pagination cursor + const nextPageToken = + threads.length === maxResults && result.length > 0 + ? result[result.length - 1].latest_received_on + : null; + return { + threads, + nextPageToken, + }; + } return { - threads, - nextPageToken, + threads: [], + nextPageToken: '', }; } catch (error) { console.error('Failed to get threads from database:', error); @@ -1051,7 +763,7 @@ export class ZeroAgent extends AIChatAgent { await this.syncThread({ threadId: id }); return this.getThreadFromDB(id, true); } - const row = result[0] as any; + const row = result[0] as { latest_label_ids: string }; const storedThread = await env.THREADS_BUCKET.get(this.getThreadKey(id)); const messages: ParsedMessage[] = storedThread @@ -1072,4 +784,295 @@ export class ZeroAgent extends AIChatAgent { throw error; } } + + async listThreads(params: { + folder: string; + query?: string; + maxResults?: number; + labelIds?: string[]; + pageToken?: string; + }) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.getThreadsFromDB(params); + } + + async list(params: { + folder: string; + query?: string; + maxResults?: number; + labelIds?: string[]; + pageToken?: string; + }) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.getThreadsFromDB(params); + } + + async get(id: string) { + if (!this.driver) { + throw new Error('No driver available'); + } + return await this.getThreadFromDB(id); + } +} + +export class ZeroAgent extends AIChatAgent { + private chatMessageAbortControllers: Map = new Map(); + private driver: DurableObjectStub = getZeroDriver(this.name); + + async registerZeroMCP() { + await this.mcp.connect(env.VITE_PUBLIC_BACKEND_URL + '/sse', { + transport: { + authProvider: new DurableObjectOAuthClientProvider( + this.ctx.storage, + 'zero-mcp', + env.VITE_PUBLIC_BACKEND_URL, + ), + }, + }); + } + + onStart(): void | Promise { + // this.registerZeroMCP(); + } + + private getDataStreamResponse( + onFinish: StreamTextOnFinishCallback<{}>, + _?: { + abortSignal: AbortSignal | undefined; + }, + ) { + const dataStreamResponse = createDataStreamResponse({ + execute: async (dataStream) => { + if (this.name === 'general') return; + const connectionId = this.name; + const orchestrator = new ToolOrchestrator(dataStream, connectionId); + // const mcpTools = await this.mcp.unstable_getAITools(); + + const rawTools = { + ...(await authTools(connectionId)), + }; + const tools = orchestrator.processTools({}); + const processedMessages = await processToolCalls( + { + messages: this.messages, + dataStream, + tools, + }, + {}, + ); + + const result = streamText({ + model: anthropic(env.OPENAI_MODEL || 'claude-3-5-haiku-latest'), + maxSteps: 10, + messages: processedMessages, + tools: rawTools, + onFinish, + onError: (error) => { + console.error('Error in streamText', error); + }, + system: await getPrompt(getPromptName(connectionId, EPrompts.Chat), AiChatPrompt('')), + }); + + result.mergeIntoDataStream(dataStream); + }, + }); + + return dataStreamResponse; + } + + private async tryCatchChat(fn: () => T | Promise) { + try { + return await fn(); + } catch (e) { + throw this.onError(e); + } + } + + private getAbortSignal(id: string): AbortSignal | undefined { + // Defensive check, since we're coercing message types at the moment + if (typeof id !== 'string') { + return undefined; + } + + if (!this.chatMessageAbortControllers.has(id)) { + this.chatMessageAbortControllers.set(id, new AbortController()); + } + + return this.chatMessageAbortControllers.get(id)?.signal; + } + + /** + * Remove an abort controller from the cache of pending message responses + */ + private removeAbortController(id: string) { + this.chatMessageAbortControllers.delete(id); + } + + broadcastChatMessage(message: OutgoingMessage, exclude?: string[]) { + this.broadcast(JSON.stringify(message), exclude); + } + + private cancelChatRequest(id: string) { + if (this.chatMessageAbortControllers.has(id)) { + const abortController = this.chatMessageAbortControllers.get(id); + abortController?.abort(); + } + } + + async onMessage(connection: Connection, message: WSMessage) { + if (typeof message === 'string') { + let data: IncomingMessage; + try { + data = JSON.parse(message) as IncomingMessage; + } catch (error) { + console.warn(error); + // silently ignore invalid messages for now + // TODO: log errors with log levels + return; + } + switch (data.type) { + case IncomingMessageType.UseChatRequest: { + if (data.init.method !== 'POST') break; + + const { body } = data.init; + + const { messages } = JSON.parse(body as string); + this.broadcastChatMessage( + { + type: OutgoingMessageType.ChatMessages, + messages, + }, + [connection.id], + ); + await this.persistMessages(messages, [connection.id]); + + const chatMessageId = data.id; + const abortSignal = this.getAbortSignal(chatMessageId); + + return this.tryCatchChat(async () => { + const response = await this.onChatMessage( + async ({ response }) => { + const finalMessages = appendResponseMessages({ + messages, + responseMessages: response.messages, + }); + + await this.persistMessages(finalMessages, [connection.id]); + this.removeAbortController(chatMessageId); + }, + abortSignal ? { abortSignal } : undefined, + ); + + if (response) { + await this.reply(data.id, response); + } else { + console.warn( + `[AIChatAgent] onChatMessage returned no response for chatMessageId: ${chatMessageId}`, + ); + this.broadcastChatMessage( + { + id: data.id, + type: OutgoingMessageType.UseChatResponse, + body: 'No response was generated by the agent.', + done: true, + }, + [connection.id], + ); + } + }); + } + case IncomingMessageType.ChatClear: { + this.destroyAbortControllers(); + void this.sql`delete from cf_ai_chat_agent_messages`; + this.messages = []; + this.broadcastChatMessage( + { + type: OutgoingMessageType.ChatClear, + }, + [connection.id], + ); + break; + } + case IncomingMessageType.ChatMessages: { + await this.persistMessages(data.messages, [connection.id]); + break; + } + case IncomingMessageType.ChatRequestCancel: { + this.cancelChatRequest(data.id); + break; + } + // case IncomingMessageType.Mail_List: { + // const result = await this.getThreadsFromDB({ + // labelIds: data.labelIds, + // folder: data.folder, + // q: data.query, + // max: data.maxResults, + // cursor: data.pageToken, + // }); + // this.currentFolder = data.folder; + // connection.send( + // JSON.stringify({ + // type: OutgoingMessageType.Mail_List, + // result, + // }), + // ); + // break; + // } + // case IncomingMessageType.Mail_Get: { + // const result = await this.getThreadFromDB(data.threadId); + // connection.send( + // JSON.stringify({ + // type: OutgoingMessageType.Mail_Get, + // result, + // threadId: data.threadId, + // }), + // ); + // break; + // } + } + } + } + + private async reply(id: string, response: Response) { + // now take chunks out from dataStreamResponse and send them to the client + return this.tryCatchChat(async () => { + for await (const chunk of response.body!) { + const body = decoder.decode(chunk); + + this.broadcastChatMessage({ + id, + type: OutgoingMessageType.UseChatResponse, + body, + done: false, + }); + } + + this.broadcastChatMessage({ + id, + type: OutgoingMessageType.UseChatResponse, + body: '', + done: true, + }); + }); + } + + private destroyAbortControllers() { + for (const controller of this.chatMessageAbortControllers.values()) { + controller?.abort(); + } + this.chatMessageAbortControllers.clear(); + } + + async onChatMessage( + onFinish: StreamTextOnFinishCallback<{}>, + options?: { + abortSignal: AbortSignal | undefined; + }, + ) { + return this.getDataStreamResponse(onFinish, options); + } } diff --git a/apps/server/src/routes/agent/mcp.ts b/apps/server/src/routes/agent/mcp.ts index 5f1b71379..19ed75147 100644 --- a/apps/server/src/routes/agent/mcp.ts +++ b/apps/server/src/routes/agent/mcp.ts @@ -166,7 +166,7 @@ export class ZeroMCP extends McpAgent, { use pageToken: s.pageToken, }); const content = await Promise.all( - result.threads.map(async (thread: any) => { + result.threads.map(async (thread) => { const loadedThread = await agent.getThread(thread.id); return [ { diff --git a/apps/server/src/routes/agent/rpc.ts b/apps/server/src/routes/agent/rpc.ts index 5cb4a1a5a..496d50f73 100644 --- a/apps/server/src/routes/agent/rpc.ts +++ b/apps/server/src/routes/agent/rpc.ts @@ -16,11 +16,11 @@ import type { CreateDraftData } from '../../lib/schemas'; import type { IOutgoingMessage } from '../../types'; import { RpcTarget } from 'cloudflare:workers'; -import { ZeroAgent } from '.'; +import { ZeroDriver } from '.'; -export class AgentRpcDO extends RpcTarget { +export class DriverRpcDO extends RpcTarget { constructor( - private mainDo: ZeroAgent, + private mainDo: ZeroDriver, private connectionId: string, ) { super(); @@ -176,9 +176,8 @@ export class AgentRpcDO extends RpcTarget { return await this.mainDo.getMessageAttachments(messageId); } - async setupAuth(connectionId: string) { - if (connectionId !== this.connectionId) console.warn('Oops, something doesnt add up.'); - return await this.mainDo.setupAuth(connectionId); + async setupAuth() { + return await this.mainDo.setupAuth(); } async broadcast(message: string) { diff --git a/apps/server/src/routes/agent/tools.ts b/apps/server/src/routes/agent/tools.ts index 08cb66f39..0b84e3638 100644 --- a/apps/server/src/routes/agent/tools.ts +++ b/apps/server/src/routes/agent/tools.ts @@ -3,9 +3,9 @@ import { perplexity } from '@ai-sdk/perplexity'; import { generateText, tool } from 'ai'; import { colors, GmailSearchAssistantSystemPrompt } from '../../lib/prompts'; +import { getZeroAgent } from '../../lib/server-utils'; import { anthropic } from '@ai-sdk/anthropic'; import { env } from 'cloudflare:workers'; -import type { ZeroAgent } from '../chat'; import { Tools } from '../../types'; import { z } from 'zod'; @@ -112,7 +112,7 @@ export const getEmbeddingVector = async ( * * The tag format must be exactly: */ -const getEmail = (_: ZeroAgent) => +const getEmail = () => tool({ description: 'Return a placeholder tag for a specific email thread by ID', parameters: z.object({ @@ -155,7 +155,7 @@ const composeEmailTool = (connectionId: string) => }, }); -// const listEmails = (agent: ZeroAgent) => +// const listEmails = (connectionId: string) => // tool({ // description: 'List emails in a specific folder', // parameters: z.object({ @@ -173,19 +173,20 @@ const composeEmailTool = (connectionId: string) => // }, // }); -const markAsRead = (agent: ZeroAgent) => +const markAsRead = (connectionId: string) => tool({ description: 'Mark emails as read', parameters: z.object({ threadIds: z.array(z.string()).describe('The IDs of the threads to mark as read'), }), execute: async ({ threadIds }) => { - await agent.markAsRead(threadIds); + const driver = await getZeroAgent(connectionId); + await driver.markAsRead(threadIds); return { threadIds, success: true }; }, }); -// const inboxRag = (agent: ZeroAgent, dataStream?: DataStreamWriter) => +// const inboxRag = (connectionId: string, dataStream?: DataStreamWriter) => // tool({ // description: 'Search the inbox for emails', // parameters: z.object({ @@ -197,19 +198,20 @@ const markAsRead = (agent: ZeroAgent) => // }, // }); -const markAsUnread = (agent: ZeroAgent) => +const markAsUnread = (connectionId: string) => tool({ description: 'Mark emails as unread', parameters: z.object({ threadIds: z.array(z.string()).describe('The IDs of the threads to mark as unread'), }), execute: async ({ threadIds }) => { - await agent.markAsUnread(threadIds); + const driver = await getZeroAgent(connectionId); + await driver.markAsUnread(threadIds); return { threadIds, success: true }; }, }); -const modifyLabels = (agent: ZeroAgent) => +const modifyLabels = (connectionId: string) => tool({ description: 'Modify labels on emails', parameters: z.object({ @@ -220,21 +222,23 @@ const modifyLabels = (agent: ZeroAgent) => }), }), execute: async ({ threadIds, options }) => { - await agent.modifyLabels(threadIds, options.addLabels, options.removeLabels); + const driver = await getZeroAgent(connectionId); + await driver.modifyLabels(threadIds, options.addLabels, options.removeLabels); return { threadIds, options, success: true }; }, }); -const getUserLabels = (agent: ZeroAgent) => +const getUserLabels = (connectionId: string) => tool({ description: 'Get all user labels', parameters: z.object({}), execute: async () => { - return await agent.getUserLabels(); + const driver = await getZeroAgent(connectionId); + return await driver.getUserLabels(); }, }); -const sendEmail = (agent: ZeroAgent) => +const sendEmail = (connectionId: string) => tool({ description: 'Send a new email', parameters: z.object({ @@ -268,16 +272,17 @@ const sendEmail = (agent: ZeroAgent) => }), execute: async (data) => { try { + const driver = await getZeroAgent(connectionId); const { draftId, ...mail } = data; if (draftId) { - await agent.sendDraft(draftId, { + await driver.sendDraft(draftId, { ...mail, attachments: [], headers: {}, }); } else { - await agent.create({ + await driver.create({ ...mail, attachments: [], headers: {}, @@ -294,7 +299,7 @@ const sendEmail = (agent: ZeroAgent) => }, }); -const createLabel = (agent: ZeroAgent) => +const createLabel = (connectionId: string) => tool({ description: 'Create a new label with custom colors, if it does nto exist already', parameters: z.object({ @@ -313,43 +318,47 @@ const createLabel = (agent: ZeroAgent) => }), }), execute: async ({ name, backgroundColor, textColor }) => { - await agent.createLabel({ name, color: { backgroundColor, textColor } }); + const driver = await getZeroAgent(connectionId); + await driver.createLabel({ name, color: { backgroundColor, textColor } }); return { name, backgroundColor, textColor, success: true }; }, }); -const bulkDelete = (agent: ZeroAgent) => +const bulkDelete = (connectionId: string) => tool({ description: 'Move multiple emails to trash by adding the TRASH label', parameters: z.object({ threadIds: z.array(z.string()).describe('Array of email IDs to move to trash'), }), execute: async ({ threadIds }) => { - await agent.modifyLabels(threadIds, ['TRASH'], []); + const driver = await getZeroAgent(connectionId); + await driver.modifyLabels(threadIds, ['TRASH'], []); return { threadIds, success: true }; }, }); -const bulkArchive = (agent: ZeroAgent) => +const bulkArchive = (connectionId: string) => tool({ description: 'Move multiple emails to the archive by removing the INBOX label', parameters: z.object({ threadIds: z.array(z.string()).describe('Array of email IDs to move to archive'), }), execute: async ({ threadIds }) => { - await agent.modifyLabels(threadIds, [], ['INBOX']); + const driver = await getZeroAgent(connectionId); + await driver.modifyLabels(threadIds, [], ['INBOX']); return { threadIds, success: true }; }, }); -const deleteLabel = (agent: ZeroAgent) => +const deleteLabel = (connectionId: string) => tool({ description: "Delete a label from the user's account", parameters: z.object({ id: z.string().describe('The ID of the label to delete'), }), execute: async ({ id }) => { - await agent.deleteLabel(id); + const driver = await getZeroAgent(connectionId); + await driver.deleteLabel(id); return { id, success: true }; }, }); @@ -397,19 +406,19 @@ const buildGmailSearchQuery = () => }, }); -export const tools = async (agent: ZeroAgent, connectionId: string) => { +export const tools = async (connectionId: string) => { return { - [Tools.GetThread]: getEmail(agent), + [Tools.GetThread]: getEmail(), [Tools.ComposeEmail]: composeEmailTool(connectionId), - [Tools.MarkThreadsRead]: markAsRead(agent), - [Tools.MarkThreadsUnread]: markAsUnread(agent), - [Tools.ModifyLabels]: modifyLabels(agent), - [Tools.GetUserLabels]: getUserLabels(agent), - [Tools.SendEmail]: sendEmail(agent), - [Tools.CreateLabel]: createLabel(agent), - [Tools.BulkDelete]: bulkDelete(agent), - [Tools.BulkArchive]: bulkArchive(agent), - [Tools.DeleteLabel]: deleteLabel(agent), + [Tools.MarkThreadsRead]: markAsRead(connectionId), + [Tools.MarkThreadsUnread]: markAsUnread(connectionId), + [Tools.ModifyLabels]: modifyLabels(connectionId), + [Tools.GetUserLabels]: getUserLabels(connectionId), + [Tools.SendEmail]: sendEmail(connectionId), + [Tools.CreateLabel]: createLabel(connectionId), + [Tools.BulkDelete]: bulkDelete(connectionId), + [Tools.BulkArchive]: bulkArchive(connectionId), + [Tools.DeleteLabel]: deleteLabel(connectionId), [Tools.WebSearch]: tool({ description: 'Search the web for information using Perplexity AI', parameters: z.object({ diff --git a/apps/server/wrangler.jsonc b/apps/server/wrangler.jsonc index 34cd860ac..8fc8c11d7 100644 --- a/apps/server/wrangler.jsonc +++ b/apps/server/wrangler.jsonc @@ -39,6 +39,10 @@ "name": "ZERO_DB", "class_name": "ZeroDB", }, + { + "name": "ZERO_DRIVER", + "class_name": "ZeroDriver", + }, ], }, "queues": { @@ -78,6 +82,10 @@ "tag": "v4", "deleted_classes": ["DurableMailbox"], }, + { + "tag": "v5", + "new_sqlite_classes": ["ZeroDriver"], + }, ], "observability": { @@ -167,6 +175,10 @@ "name": "ZERO_DB", "class_name": "ZeroDB", }, + { + "name": "ZERO_DRIVER", + "class_name": "ZeroDriver", + }, ], }, "r2_buckets": [ @@ -216,6 +228,10 @@ "tag": "v5", "deleted_classes": ["DurableMailbox"], }, + { + "tag": "v6", + "new_sqlite_classes": ["ZeroDriver"], + }, ], "observability": { "enabled": true, @@ -312,6 +328,10 @@ "name": "ZERO_DB", "class_name": "ZeroDB", }, + { + "name": "ZERO_DRIVER", + "class_name": "ZeroDriver", + }, ], }, "queues": { @@ -355,6 +375,10 @@ "tag": "v5", "deleted_classes": ["DurableMailbox"], }, + { + "tag": "v6", + "new_sqlite_classes": ["ZeroDriver"], + }, ], "vars": { "NODE_ENV": "production",