diff --git a/apps/server/src/lib/auth.ts b/apps/server/src/lib/auth.ts index 2b1bc9391..f4c2dad42 100644 --- a/apps/server/src/lib/auth.ts +++ b/apps/server/src/lib/auth.ts @@ -122,7 +122,7 @@ const connectionHandlerHook = async (account: Account) => { expiresAt: new Date(Date.now() + (account.accessTokenExpiresAt?.getTime() || 3600000)), }; - const db = getZeroDB(account.userId); + const db = await getZeroDB(account.userId); const [result] = await db.createConnection( account.providerId as EProviders, userInfo.address, @@ -189,7 +189,7 @@ export const createAuth = () => { }, beforeDelete: async (user, request) => { if (!request) throw new APIError('BAD_REQUEST', { message: 'Request object is missing' }); - const db = getZeroDB(user.id); + const db = await getZeroDB(user.id); const connections = await db.findManyConnections(); const context = getContext(); try { @@ -287,7 +287,7 @@ export const createAuth = () => { const newSession = ctx.context.newSession; if (newSession) { // Check if user already has settings - const db = getZeroDB(newSession.user.id); + const db = await getZeroDB(newSession.user.id); const existingSettings = await db.findUserSettings(); if (!existingSettings) { diff --git a/apps/server/src/lib/driver/utils.ts b/apps/server/src/lib/driver/utils.ts index 6b26a4521..618e8ed30 100644 --- a/apps/server/src/lib/driver/utils.ts +++ b/apps/server/src/lib/driver/utils.ts @@ -3,7 +3,6 @@ import { getContext } from 'hono/context-storage'; import type { gmail_v1 } from '@googleapis/gmail'; import type { HonoContext } from '../../ctx'; - import { toByteArray } from 'base64-js'; export const FatalErrors = ['invalid_grant']; @@ -15,7 +14,7 @@ export const deleteActiveConnection = async () => { if (!session) return console.log('No session found'); try { await c.var.auth.api.signOut({ headers: c.req.raw.headers }); - const db = getZeroDB(session.user.id); + const db = await getZeroDB(session.user.id); await db.deleteActiveConnection(activeConnection.id); } catch (error) { console.error('Server: Error deleting connection:', error); diff --git a/apps/server/src/lib/notes-manager.ts b/apps/server/src/lib/notes-manager.ts index 9f77dcd92..9c2f8d685 100644 --- a/apps/server/src/lib/notes-manager.ts +++ b/apps/server/src/lib/notes-manager.ts @@ -5,7 +5,7 @@ export class NotesManager { constructor() {} async getThreadNotes(userId: string, threadId: string): Promise<(typeof note.$inferSelect)[]> { - const db = getZeroDB(userId); + const db = await getZeroDB(userId); return await db.findManyNotesByThreadId(threadId); } @@ -16,8 +16,8 @@ export class NotesManager { color: string = 'default', isPinned: boolean = false, ): Promise { - try{ - const db = getZeroDB(userId); + try { + const db = await getZeroDB(userId); const highestOrder = await db.findHighestNoteOrder(); const id = crypto.randomUUID(); @@ -46,7 +46,7 @@ export class NotesManager { Omit >, ): Promise { - const db = getZeroDB(userId); + const db = await getZeroDB(userId); const existingNote = await db.findNoteById(noteId); if (!existingNote) { @@ -62,7 +62,7 @@ export class NotesManager { } async deleteNote(userId: string, noteId: string) { - const db = getZeroDB(userId); + const db = await getZeroDB(userId); try { await db.deleteNote(noteId); return true; @@ -82,7 +82,7 @@ export class NotesManager { const noteIds = notes.map((n) => n.id); - const db = getZeroDB(userId); + const db = await getZeroDB(userId); const userNotes = await db.findManyNotesByIds(noteIds); const foundNoteIds = new Set(userNotes.map((n) => n.id)); diff --git a/apps/server/src/lib/server-utils.ts b/apps/server/src/lib/server-utils.ts index fe723c8f4..0aed7bc36 100644 --- a/apps/server/src/lib/server-utils.ts +++ b/apps/server/src/lib/server-utils.ts @@ -4,9 +4,9 @@ import type { HonoContext } from '../ctx'; import { env } from 'cloudflare:workers'; import { createDriver } from './driver'; -export const getZeroDB = (userId: string) => { +export const getZeroDB = async (userId: string) => { const stub = env.ZERO_DB.get(env.ZERO_DB.idFromName(userId)); - const rpcTarget = stub.setMetaData(userId); + const rpcTarget = await stub.setMetaData(userId); return rpcTarget; }; @@ -17,12 +17,17 @@ export const getZeroAgent = async (connectionId: string) => { return rpcTarget; }; +export const getZeroSocketAgent = async (connectionId: string) => { + const stub = env.ZERO_AGENT.get(env.ZERO_AGENT.idFromName(connectionId)); + return stub; +}; + export const getActiveConnection = async () => { const c = getContext(); const { sessionUser } = c.var; if (!sessionUser) throw new Error('Session Not Found'); - const db = getZeroDB(sessionUser.id); + const db = await getZeroDB(sessionUser.id); const userData = await db.findUser(); diff --git a/apps/server/src/main.ts b/apps/server/src/main.ts index cfe648052..ea45b65a5 100644 --- a/apps/server/src/main.ts +++ b/apps/server/src/main.ts @@ -518,7 +518,7 @@ export default class extends WorkerEntrypoint { const userId = payload.sub; if (userId) { - const db = getZeroDB(userId); + const db = await getZeroDB(userId); c.set('sessionUser', await db.findUser()); } } diff --git a/apps/server/src/routes/agent/index.ts b/apps/server/src/routes/agent/index.ts index 6123dd71d..af4cda0ca 100644 --- a/apps/server/src/routes/agent/index.ts +++ b/apps/server/src/routes/agent/index.ts @@ -36,7 +36,7 @@ import { import type { MailManager, IGetThreadResponse, IGetThreadsResponse } from '../../lib/driver/types'; import { DurableObjectOAuthClientProvider } from 'agents/mcp/do-oauth-client-provider'; import { AiChatPrompt, GmailSearchAssistantSystemPrompt } from '../../lib/prompts'; -import { connectionToDriver } from '../../lib/server-utils'; +import { connectionToDriver, getZeroSocketAgent } from '../../lib/server-utils'; import type { CreateDraftData } from '../../lib/schemas'; import { withRetry } from '../../lib/gmail-rate-limit'; import { getPrompt } from '../../pipelines.effect'; @@ -58,8 +58,8 @@ import { Effect } from 'effect'; const decoder = new TextDecoder(); -const shouldDropTables = env.DROP_AGENT_TABLES === 'true'; -const maxCount = parseInt(env.THREAD_SYNC_MAX_COUNT || '10', 10); +const shouldDropTables = false; +const maxCount = 20; const shouldLoop = env.THREAD_SYNC_LOOP !== 'false'; export class ZeroDriver extends AIChatAgent { private foldersInSync: Map = new Map(); @@ -87,6 +87,7 @@ export class ZeroDriver extends AIChatAgent { async setMetaData(connectionId: string) { await this.setName(connectionId); + this.agent = await getZeroSocketAgent(connectionId); return new DriverRpcDO(this, connectionId); } @@ -154,7 +155,7 @@ export class ZeroDriver extends AIChatAgent { } async onConnect() { - await this.setupAuth(); + if (!this.driver) await this.setupAuth(); } public async setupAuth() { @@ -449,6 +450,13 @@ export class ZeroDriver extends AIChatAgent { return count[0]['COUNT(*)'] as number; } + async getFolderThreadCount(folder: string) { + const count = this.sql`SELECT COUNT(*) FROM threads WHERE EXISTS ( + SELECT 1 FROM json_each(latest_label_ids) WHERE value = ${folder} + )`; + return count[0]['COUNT(*)'] as number; + } + async syncThreads(folder: string) { if (!this.driver) { console.error('No driver available for syncThreads'); @@ -618,6 +626,16 @@ export class ZeroDriver extends AIChatAgent { const { labelIds = [], folder, q, maxResults = 50, pageToken } = params; try { + const folderThreadCount = (await this.count()).find((c) => c.label === folder)?.count; + const currentThreadCount = await this.getThreadCount(); + + console.log('folderThreadCount', folderThreadCount, folder); + console.log('currentThreadCount', currentThreadCount); + + if (folderThreadCount && folderThreadCount > currentThreadCount && folder) { + this.ctx.waitUntil(this.syncThreads(folder)); + } + // Build WHERE conditions const whereConditions: string[] = []; @@ -794,6 +812,8 @@ export class ZeroDriver extends AIChatAgent { `; if (!result || result.length === 0) { + const res = await this.queue('syncThread', { threadId: id }); + console.log('res', res); return { messages: [], latest: undefined, @@ -827,7 +847,9 @@ export class ZeroDriver extends AIChatAgent { async unsnoozeThreadsHandler(payload: ISnoozeBatch) { const { connectionId, threadIds, keyNames } = payload; try { - await this.setupAuth(); + if (!this.driver) { + await this.setupAuth(); + } if (threadIds.length) { await this.modifyLabels(threadIds, ['INBOX'], ['SNOOZED']); diff --git a/apps/server/src/routes/agent/rpc.ts b/apps/server/src/routes/agent/rpc.ts index 53fa4f650..c7f279ccd 100644 --- a/apps/server/src/routes/agent/rpc.ts +++ b/apps/server/src/routes/agent/rpc.ts @@ -220,7 +220,7 @@ export class DriverRpcDO extends RpcTarget { return await this.mainDo.searchThreads(params); } - async queue(callbackName: string, payload: unknown): Promise { + async queue(callbackName: keyof ZeroDriver, payload: unknown): Promise { const queueFn = this.mainDo.queue; if (typeof queueFn !== 'function') { throw new Error('queue method not implemented on mainDo'); diff --git a/apps/server/src/trpc/routes/connections.ts b/apps/server/src/trpc/routes/connections.ts index 0ddc6ff37..97a4c31a7 100644 --- a/apps/server/src/trpc/routes/connections.ts +++ b/apps/server/src/trpc/routes/connections.ts @@ -15,7 +15,7 @@ export const connectionsRouter = router({ ) .query(async ({ ctx }) => { const { sessionUser } = ctx; - const db = getZeroDB(sessionUser.id); + const db = await getZeroDB(sessionUser.id); const connections = await db.findManyConnections(); const disconnectedIds = connections @@ -41,7 +41,7 @@ export const connectionsRouter = router({ .mutation(async ({ input, ctx }) => { const { connectionId } = input; const user = ctx.sessionUser; - const db = getZeroDB(user.id); + const db = await getZeroDB(user.id); const foundConnection = await db.findUserConnection(connectionId); if (!foundConnection) throw new TRPCError({ code: 'NOT_FOUND' }); await db.updateUser({ defaultConnectionId: connectionId }); @@ -51,7 +51,7 @@ export const connectionsRouter = router({ .mutation(async ({ input, ctx }) => { const { connectionId } = input; const user = ctx.sessionUser; - const db = getZeroDB(user.id); + const db = await getZeroDB(user.id); await db.deleteConnection(connectionId); const activeConnection = await getActiveConnection(); diff --git a/apps/server/src/trpc/routes/mail.ts b/apps/server/src/trpc/routes/mail.ts index f43fe8a10..fee41e77c 100644 --- a/apps/server/src/trpc/routes/mail.ts +++ b/apps/server/src/trpc/routes/mail.ts @@ -1,19 +1,19 @@ -import { updateWritingStyleMatrix } from '../../services/writing-style-service'; -import { activeDriverProcedure, router, privateProcedure } from '../trpc'; import { IGetThreadResponseSchema, IGetThreadsResponseSchema, type IGetThreadsResponse, } from '../../lib/driver/types'; +import { updateWritingStyleMatrix } from '../../services/writing-style-service'; +import { activeDriverProcedure, router, privateProcedure } from '../trpc'; import { processEmailHtml } from '../../lib/email-processor'; import { defaultPageSize, FOLDERS } from '../../lib/utils'; import { serializedFileSchema } from '../../lib/schemas'; import type { DeleteAllSpamResponse } from '../../types'; import { getZeroAgent } from '../../lib/server-utils'; +import { env } from 'cloudflare:workers'; import { TRPCError } from '@trpc/server'; import { z } from 'zod'; -import { env } from 'cloudflare:workers'; const senderSchema = z.object({ name: z.string().optional(), @@ -75,12 +75,16 @@ export const mailRouter = router({ const { activeConnection } = ctx; const agent = await getZeroAgent(activeConnection.id); + console.debug('[listThreads] input:', { folder, maxResults, cursor, q, labelIds }); + if (folder === FOLDERS.DRAFT) { + console.debug('[listThreads] Listing drafts'); const drafts = await agent.listDrafts({ q, maxResults, pageToken: cursor, }); + console.debug('[listThreads] Drafts result:', drafts); return drafts; } @@ -89,7 +93,7 @@ export const mailRouter = router({ let threadsResponse: IGetThreadsResponse; if (q) { - // When searching, leverage the driver's raw search for best accuracy + console.debug('[listThreads] Performing search with query:', q); threadsResponse = await agent.rawListThreads({ folder, query: q, @@ -97,10 +101,11 @@ export const mailRouter = router({ labelIds, pageToken: cursor, }); + console.debug('[listThreads] Search result:', threadsResponse); } else { - // Normal listing – include explicit folder label so that label filters work together const folderLabelId = getFolderLabelId(folder); const labelIdsToUse = folderLabelId ? [...labelIds, folderLabelId] : labelIds; + console.debug('[listThreads] Listing with labelIds:', labelIdsToUse, 'for folder:', folder); threadsResponse = await agent.listThreads({ folder, @@ -108,12 +113,15 @@ export const mailRouter = router({ maxResults, pageToken: cursor, }); + console.debug('[listThreads] List result:', threadsResponse); } if (folder === FOLDERS.SNOOZED) { const nowTs = Date.now(); const filtered: ThreadItem[] = []; + console.debug('[listThreads] Filtering snoozed threads at', new Date(nowTs).toISOString()); + await Promise.all( threadsResponse.threads.map(async (t: ThreadItem) => { const keyName = `${t.id}__${activeConnection.id}`; @@ -130,7 +138,7 @@ export const mailRouter = router({ return; } - console.log('[UNSNOOZE_ON_ACCESS] Expired thread', t.id, { + console.debug('[UNSNOOZE_ON_ACCESS] Expired thread', t.id, { wakeAtIso, now: new Date(nowTs).toISOString(), }); @@ -145,7 +153,9 @@ export const mailRouter = router({ ); threadsResponse.threads = filtered; + console.debug('[listThreads] Snoozed threads after filtering:', filtered); } + console.debug('[listThreads] Returning threadsResponse:', threadsResponse); return threadsResponse; }), markAsRead: activeDriverProcedure @@ -494,7 +504,9 @@ export const mailRouter = router({ await agent.modifyLabels(input.ids, ['INBOX'], ['SNOOZED']); await Promise.all( - input.ids.map((threadId) => env.snoozed_emails.delete(`${threadId}__${activeConnection.id}`)), + input.ids.map((threadId) => + env.snoozed_emails.delete(`${threadId}__${activeConnection.id}`), + ), ); return { success: true }; }), diff --git a/apps/server/src/trpc/routes/settings.ts b/apps/server/src/trpc/routes/settings.ts index afafa6319..547a0247b 100644 --- a/apps/server/src/trpc/routes/settings.ts +++ b/apps/server/src/trpc/routes/settings.ts @@ -15,7 +15,7 @@ export const settingsRouter = router({ if (!ctx.sessionUser) return { settings: defaultUserSettings }; const { sessionUser } = ctx; - const db = getZeroDB(sessionUser.id); + const db = await getZeroDB(sessionUser.id); const result: any = await db.findUserSettings(); // Returning null here when there are no settings so we can use the default settings with timezone from the browser @@ -33,7 +33,7 @@ export const settingsRouter = router({ save: privateProcedure.input(userSettingsSchema.partial()).mutation(async ({ ctx, input }) => { const { sessionUser } = ctx; - const db = getZeroDB(sessionUser.id); + const db = await getZeroDB(sessionUser.id); const existingSettings: any = await db.findUserSettings(); if (existingSettings) { diff --git a/apps/server/src/trpc/routes/shortcut.ts b/apps/server/src/trpc/routes/shortcut.ts index 1cf34e48e..c06efc0cf 100644 --- a/apps/server/src/trpc/routes/shortcut.ts +++ b/apps/server/src/trpc/routes/shortcut.ts @@ -13,7 +13,7 @@ export const shortcutRouter = router({ .mutation(async ({ ctx, input }) => { const { sessionUser } = ctx; const { shortcuts } = input; - const db = getZeroDB(sessionUser.id); + const db = await getZeroDB(sessionUser.id); await db.insertUserHotkeys(shortcuts as any); }), }); diff --git a/apps/server/src/trpc/trpc.ts b/apps/server/src/trpc/trpc.ts index 5426e18ca..5f60481be 100644 --- a/apps/server/src/trpc/trpc.ts +++ b/apps/server/src/trpc/trpc.ts @@ -55,7 +55,7 @@ export const activeDriverProcedure = activeConnectionProcedure.use(async ({ ctx, if (!res.ok && res.error.message === 'invalid_grant') { // Remove the access token and refresh token - const db = getZeroDB(sessionUser.id); + const db = await getZeroDB(sessionUser.id); await db.updateConnection(activeConnection.id, { accessToken: null, refreshToken: null,