diff --git a/apps/mail/components/ui/ai-sidebar.tsx b/apps/mail/components/ui/ai-sidebar.tsx index cd3872ebb..4358e808a 100644 --- a/apps/mail/components/ui/ai-sidebar.tsx +++ b/apps/mail/components/ui/ai-sidebar.tsx @@ -352,6 +352,9 @@ function AISidebar({ className }: AISidebarProps) { }); const chatState = useAgentChat({ + getInitialMessages: async () => { + return []; + }, agent, maxSteps: 10, body: { @@ -470,7 +473,7 @@ function AISidebar({ className }: AISidebarProps) {
}, }); +const buildGmailSearchQuery = () => + tool({ + description: 'Build Gmail search query using AI assistance', + parameters: z.object({ + query: z.string(), + }), + execute: async ({ query }) => { + const result = await generateText({ + model: anthropic(env.OPENAI_MODEL || 'claude-3-5-haiku-latest'), + system: GmailSearchAssistantSystemPrompt(), + prompt: query, + }); + return result.text; + }, + }); + export const tools = async (agent: ZeroAgent, connectionId: string) => { return { [Tools.GetThread]: getEmail(agent), @@ -399,6 +416,7 @@ export const tools = async (agent: ZeroAgent, connectionId: string) => { query: z.string().describe('The query to search the web for'), }), }), + [Tools.BuildGmailSearchQuery]: buildGmailSearchQuery(), [Tools.InboxRag]: tool({ description: 'Search the inbox for emails using natural language. Returns only an array of threadIds.', @@ -406,6 +424,5 @@ export const tools = async (agent: ZeroAgent, connectionId: string) => { query: z.string().describe('The query to search the inbox for'), }), }), - // ...(await getGoogleTools(connectionId)), }; }; diff --git a/apps/server/src/routes/chat.ts b/apps/server/src/routes/chat.ts index 66457b77b..700539505 100644 --- a/apps/server/src/routes/chat.ts +++ b/apps/server/src/routes/chat.ts @@ -13,8 +13,6 @@ */ import { streamText, - generateObject, - tool, type StreamTextOnFinishCallback, createDataStreamResponse, generateText, @@ -25,6 +23,7 @@ import { GmailSearchAssistantSystemPrompt, AiChatPrompt, } from '../lib/prompts'; +import { DurableObjectOAuthClientProvider } from 'agents/mcp/do-oauth-client-provider'; import { EPrompts, type IOutgoingMessage, type ParsedMessage } from '../types'; import type { IGetThreadResponse, MailManager } from '../lib/driver/types'; import { connectionToDriver, getZeroAgent } from '../lib/server-utils'; @@ -32,6 +31,7 @@ import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { type Connection, type WSMessage } from 'agents'; import { ToolOrchestrator } from './agent/orchestrator'; import type { CreateDraftData } from '../lib/schemas'; +import { withRetry } from '../lib/gmail-rate-limit'; import { env, RpcTarget } from 'cloudflare:workers'; import { AIChatAgent } from 'agents/ai-chat-agent'; import { tools as authTools } from './agent/tools'; @@ -44,8 +44,6 @@ import { getPrompt } from '../lib/brain'; import { FOLDERS } from '../lib/utils'; import { and, eq } from 'drizzle-orm'; import { McpAgent } from 'agents/mcp'; - -import { withRetry } from '../lib/gmail-rate-limit'; import { createDb } from '../db'; import { Effect } from 'effect'; import { z } from 'zod'; @@ -164,10 +162,6 @@ export class AgentRpcDO extends RpcTarget { return await this.mainDo.bulkArchive(threadIds); } - async buildGmailSearchQuery(query: string) { - return await this.mainDo.buildGmailSearchQuery(query); - } - async rawListThreads(params: { folder: string; query?: string; @@ -364,6 +358,22 @@ export class ZeroAgent extends AIChatAgent { return new AgentRpcDO(this, connectionId); } + async registerZeroMCP() { + await this.mcp.connect(env.VITE_PUBLIC_BACKEND_URL + '/sse', { + transport: { + authProvider: new DurableObjectOAuthClientProvider( + this.ctx.storage, + 'zero-mcp', + env.VITE_PUBLIC_BACKEND_URL, + ), + }, + }); + } + + onStart(): void | Promise { + // this.registerZeroMCP(); + } + private getDataStreamResponse( onFinish: StreamTextOnFinishCallback<{}>, _?: { @@ -383,11 +393,12 @@ export class ZeroAgent extends AIChatAgent { } } const orchestrator = new ToolOrchestrator(dataStream, connectionId); + // const mcpTools = await this.mcp.unstable_getAITools(); + const rawTools = { ...(await authTools(this, connectionId)), - buildGmailSearchQuery, }; - const tools = orchestrator.processTools(rawTools); + const tools = orchestrator.processTools({}); const processedMessages = await processToolCalls( { messages: this.messages, @@ -401,8 +412,11 @@ export class ZeroAgent extends AIChatAgent { model: anthropic(env.OPENAI_MODEL || 'claude-3-5-haiku-latest'), maxSteps: 10, messages: processedMessages, - tools, + tools: rawTools, onFinish, + onError: (error) => { + console.error('Error in streamText', error); + }, system: await getPrompt(getPromptName(connectionId, EPrompts.Chat), AiChatPrompt('')), }); @@ -587,7 +601,6 @@ export class ZeroAgent extends AIChatAgent { return this.tryCatchChat(async () => { for await (const chunk of response.body!) { const body = decoder.decode(chunk); - console.log('reply', body); this.broadcastChatMessage({ id, @@ -743,15 +756,6 @@ export class ZeroAgent extends AIChatAgent { }); } - async buildGmailSearchQuery(query: string) { - const result = await generateText({ - model: anthropic(env.OPENAI_MODEL || 'claude-3-5-haiku-latest'), - system: GmailSearchAssistantSystemPrompt(), - prompt: query, - }); - return result.text; - } - async updateLabel( id: string, label: { name: string; color?: { backgroundColor: string; textColor: string } }, @@ -1347,44 +1351,59 @@ export class ZeroMCP extends McpAgent { throw new Error('Unauthorized'); } this.activeConnectionId = _connection.id; - const agent = await getZeroAgent(_connection.id); + this.server.registerTool( + 'getConnections', + { + description: + 'Use this tool to get all connections for the user. This helps you know what accounts(connections) the user has available.', + inputSchema: {}, + }, + async () => { + const connections = await db.query.connection.findMany({ + where: eq(connection.userId, this.props.userId), + }); + return { + content: connections.map((c) => ({ + type: 'text', + text: `Email: ${c.email} | Provider: ${c.providerId}`, + })), + }; + }, + ); - this.server.tool('getConnections', async () => { - const connections = await db.query.connection.findMany({ - where: eq(connection.userId, this.props.userId), - }); - return { - content: connections.map((c) => ({ - type: 'text', - text: `Email: ${c.email} | Provider: ${c.providerId}`, - })), - }; - }); + this.server.registerTool( + 'getActiveConnection', + { + description: 'Get the currently active email connection', + }, + async () => { + if (!this.activeConnectionId) { + throw new Error('No active connection'); + } + const _connection = await db.query.connection.findFirst({ + where: eq(connection.id, this.activeConnectionId), + }); + if (!_connection) { + throw new Error('Connection not found'); + } + return { + content: [ + { + type: 'text' as const, + text: `Email: ${_connection.email} | Provider: ${_connection.providerId}`, + }, + ], + }; + }, + ); - this.server.tool('getActiveConnection', async () => { - if (!this.activeConnectionId) { - throw new Error('No active connection'); - } - const _connection = await db.query.connection.findFirst({ - where: eq(connection.id, this.activeConnectionId), - }); - if (!_connection) { - throw new Error('Connection not found'); - } - return { - content: [ - { - type: 'text' as const, - text: `Email: ${_connection.email} | Provider: ${_connection.providerId}`, - }, - ], - }; - }); - - this.server.tool( + this.server.registerTool( 'setActiveConnection', { - email: z.string(), + description: 'Set the active email connection by email address', + inputSchema: { + email: z.string(), + }, }, async (s) => { const _connection = await db.query.connection.findFirst({ @@ -1405,10 +1424,13 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool( + this.server.registerTool( 'buildGmailSearchQuery', { - query: z.string(), + description: 'Build Gmail search query using AI assistance', + inputSchema: { + query: z.string(), + }, }, async (s) => { const result = await generateText({ @@ -1427,14 +1449,19 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool( + const agent = await getZeroAgent(_connection.id); + + this.server.registerTool( 'listThreads', { - folder: z.string().default(FOLDERS.INBOX), - query: z.string().optional(), - maxResults: z.number().optional().default(5), - labelIds: z.array(z.string()).optional(), - pageToken: z.string().optional(), + description: 'List email threads with optional filters and pagination', + inputSchema: { + folder: z.string().default(FOLDERS.INBOX), + query: z.string().optional(), + maxResults: z.number().optional().default(5), + labelIds: z.array(z.string()).optional(), + pageToken: z.string().optional(), + }, }, async (s) => { const result = await agent.listThreads({ @@ -1472,10 +1499,13 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool( + this.server.registerTool( 'getThread', { - threadId: z.string(), + description: 'Get detailed information about a specific email thread', + inputSchema: { + threadId: z.string(), + }, }, async (s) => { const thread = await agent.getThread(s.threadId); @@ -1490,7 +1520,7 @@ export class ZeroMCP extends McpAgent { }, { type: 'text' as const, - text: `Latest Message Sender: ${thread.latest?.sender}`, + text: `Latest Message Sender: ${thread.latest?.sender.name} <${thread.latest?.sender.email}>`, }, { type: 'text' as const, @@ -1501,40 +1531,19 @@ export class ZeroMCP extends McpAgent { text: `Thread ID: ${s.threadId}`, }, ]; - const response = await env.VECTORIZE.getByIds([s.threadId]); - if (response.length && response?.[0]?.metadata?.['summary']) { - const content = response[0].metadata['summary'] as string; - const shortResponse = await env.AI.run('@cf/facebook/bart-large-cnn', { - input_text: content, - }); - return { - content: [ - ...initialResponse, - { - type: 'text', - text: `Subject: ${thread.latest?.subject}`, - }, - { - type: 'text', - text: `Long Summary: ${content}`, - }, - { - type: 'text', - text: `Short Summary: ${shortResponse.summary}`, - }, - ], - }; - } return { content: initialResponse, }; }, ); - this.server.tool( + this.server.registerTool( 'markThreadsRead', { - threadIds: z.array(z.string()), + description: 'Mark email threads as read', + inputSchema: { + threadIds: z.array(z.string()), + }, }, async (s) => { await agent.modifyLabels(s.threadIds, [], ['UNREAD']); @@ -1549,10 +1558,13 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool( + this.server.registerTool( 'markThreadsUnread', { - threadIds: z.array(z.string()), + description: 'Mark email threads as unread', + inputSchema: { + threadIds: z.array(z.string()), + }, }, async (s) => { await agent.modifyLabels(s.threadIds, ['UNREAD'], []); @@ -1567,12 +1579,15 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool( + this.server.registerTool( 'modifyLabels', { - threadIds: z.array(z.string()), - addLabelIds: z.array(z.string()), - removeLabelIds: z.array(z.string()), + description: 'Add or remove labels from email threads', + inputSchema: { + threadIds: z.array(z.string()), + addLabelIds: z.array(z.string()), + removeLabelIds: z.array(z.string()), + }, }, async (s) => { await agent.modifyLabels(s.threadIds, s.addLabelIds, s.removeLabelIds); @@ -1587,35 +1602,49 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool('getCurrentDate', async () => { - return { - content: [ - { - type: 'text', - text: getCurrentDateContext(), - }, - ], - }; - }); + this.server.registerTool( + 'getCurrentDate', + { + description: 'Get the current date and time', + inputSchema: z.object({}).shape, + }, + async () => { + return { + content: [ + { + type: 'text', + text: getCurrentDateContext(), + }, + ], + }; + }, + ); - this.server.tool('getUserLabels', async () => { - const labels = await agent.getUserLabels(); - return { - content: [ - { - type: 'text', - text: labels - .map((label) => `Name: ${label.name} ID: ${label.id} Color: ${label.color}`) - .join('\n'), - }, - ], - }; - }); + this.server.registerTool( + 'getUserLabels', + { description: 'Get all available labels for the user' }, + async () => { + const labels = await agent.getUserLabels(); + return { + content: [ + { + type: 'text', + text: labels + .map((label) => `Name: ${label.name} ID: ${label.id} Color: ${label.color}`) + .join('\n'), + }, + ], + }; + }, + ); - this.server.tool( + this.server.registerTool( 'getLabel', { - id: z.string(), + description: 'Get details about a specific label', + inputSchema: { + id: z.string(), + }, }, async (s) => { const label = await agent.getLabel(s.id); @@ -1634,12 +1663,15 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool( + this.server.registerTool( 'createLabel', { - name: z.string(), - backgroundColor: z.string().optional(), - textColor: z.string().optional(), + description: 'Create a new email label', + inputSchema: { + name: z.string(), + backgroundColor: z.string().optional(), + textColor: z.string().optional(), + }, }, async (s) => { try { @@ -1675,83 +1707,71 @@ export class ZeroMCP extends McpAgent { }, ); - this.server.tool( - 'bulkDelete', - { - threadIds: z.array(z.string()), - }, - async (s) => { - try { - await agent.modifyLabels(s.threadIds, ['TRASH'], ['INBOX']); - return { - content: [ - { - type: 'text', - text: 'Threads moved to trash', - }, - ], - }; - } catch (e) { - console.error(e); - return { - content: [ - { - type: 'text', - text: 'Failed to move threads to trash', - }, - ], - }; - } - }, - ); + // this.server.registerTool( + // 'bulkDelete', + // { + // description: 'Move multiple threads to trash', + // inputSchema: { + // threadIds: z.array(z.string()), + // }, + // }, + // async (s) => { + // try { + // await agent.modifyLabels(s.threadIds, ['TRASH'], ['INBOX']); + // return { + // content: [ + // { + // type: 'text', + // text: 'Threads moved to trash', + // }, + // ], + // }; + // } catch (e) { + // console.error(e); + // return { + // content: [ + // { + // type: 'text', + // text: 'Failed to move threads to trash', + // }, + // ], + // }; + // } + // }, + // ); - this.server.tool( - 'bulkArchive', - { - threadIds: z.array(z.string()), - }, - async (s) => { - try { - await agent.modifyLabels(s.threadIds, [], ['INBOX']); - return { - content: [ - { - type: 'text', - text: 'Threads archived', - }, - ], - }; - } catch (e) { - console.error(e); - return { - content: [ - { - type: 'text', - text: 'Failed to archive threads', - }, - ], - }; - } - }, - ); + // this.server.registerTool( + // 'bulkArchive', + // { + // description: 'Archive multiple email threads', + // inputSchema: { + // threadIds: z.array(z.string()), + // }, + // }, + // async (s) => { + // try { + // await agent.modifyLabels(s.threadIds, [], ['INBOX']); + // return { + // content: [ + // { + // type: 'text', + // text: 'Threads archived', + // }, + // ], + // }; + // } catch (e) { + // console.error(e); + // return { + // content: [ + // { + // type: 'text', + // text: 'Failed to archive threads', + // }, + // ], + // }; + // } + // }, + // ); this.ctx.waitUntil(conn.end()); } } - -const buildGmailSearchQuery = tool({ - description: 'Build a Gmail search query', - parameters: z.object({ - query: z.string().describe('The search query to build, provided in natural language'), - }), - execute: async ({ query }) => { - const result = await generateObject({ - model: openai(env.OPENAI_MODEL || 'gpt-4o'), - system: GmailSearchAssistantSystemPrompt(), - prompt: query, - schema: z.object({ - query: z.string(), - }), - }); - return result.object; - }, -}); diff --git a/apps/server/src/types.ts b/apps/server/src/types.ts index 4104084b9..ae5abc42b 100644 --- a/apps/server/src/types.ts +++ b/apps/server/src/types.ts @@ -226,6 +226,7 @@ export enum Tools { AskZeroThread = 'askZeroThread', WebSearch = 'webSearch', InboxRag = 'inboxRag', + BuildGmailSearchQuery = 'buildGmailSearchQuery', } export type AppContext = Context<{ Bindings: Env }>; diff --git a/apps/server/wrangler.jsonc b/apps/server/wrangler.jsonc index 407d989a8..123eb715d 100644 --- a/apps/server/wrangler.jsonc +++ b/apps/server/wrangler.jsonc @@ -101,7 +101,7 @@ "VOICE_SECRET": "1234567890", "GOOGLE_S_ACCOUNT": "{}", "DROP_AGENT_TABLES": "false", - "THREAD_SYNC_MAX_COUNT": "40", + "THREAD_SYNC_MAX_COUNT": "5", "THREAD_SYNC_LOOP": "false", "DISABLE_WORKFLOWS": "true", "AUTORAG_ID": "",