From 08f3c64e1a4edc0216a42fc1d1180bea6c65ab50 Mon Sep 17 00:00:00 2001 From: bertybuttface <110790513+bertybuttface@users.noreply.github.com> Date: Fri, 27 Jan 2023 17:09:54 +0000 Subject: [PATCH] Add support for configurable context levels. --- README.md | 3 +++ src/env.ts | 4 +++- src/handlers.ts | 32 ++++++++++++++++++++++---------- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 830dcbc..73420c4 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,9 @@ OPENAI_LOGIN_TYPE=google # Set the next line to true if you are using a ChatGPT pro account. OPENAI_PRO=false +# Set the ChatGPT conversation context to 'thread', 'room' or 'both'. +CHATGPT_CONTEXT=thread + # Matrix Static Settings (required, see notes) # Defaults to "https://matrix.org" MATRIX_HOMESERVER_URL= diff --git a/src/env.ts b/src/env.ts index eeca625..8f4b0d3 100644 --- a/src/env.ts +++ b/src/env.ts @@ -26,6 +26,7 @@ export const { OPENAI_PASSWORD, OPENAI_LOGIN_TYPE, OPENAI_PRO, + CHATGPT_CONTEXT, CHATGPT_TIMEOUT } = parseEnv(process.env, { DATA_PATH: { schema: z.string().default("./storage"), description: "Set to /storage/ if using docker, ./storage if running without" }, @@ -50,5 +51,6 @@ export const { OPENAI_PASSWORD: { schema: z.string().min(1), description: "Set password of OpenAI's account" }, OPENAI_LOGIN_TYPE: { schema: z.enum(["google", "openai", "microsoft"]).default("google"), description: "Set authentication provider to 'google', 'openai' or 'microsoft'" }, OPENAI_PRO: { schema: z.boolean().default(false), description: "Set to true if you have a paid ChatGPT subscription." }, - CHATGPT_TIMEOUT: { schema: z.number().default(2 * 60 * 1000), description: "Set number of milliseconds to wait for CHATGPT responses" } + CHATGPT_TIMEOUT: { schema: z.number().default(2 * 60 * 1000), description: "Set number of milliseconds to wait for ChatGPT responses" }, + CHATGPT_CONTEXT: { schema: z.enum(["thread", "room", "both"]).default("thread"), description: "Set the ChatGPT conversation context to 'thread', 'room' or 'both'" } }); diff --git a/src/handlers.ts b/src/handlers.ts index b7f6d4e..635054f 100644 --- a/src/handlers.ts +++ b/src/handlers.ts @@ -1,6 +1,6 @@ import { ChatGPTAPIBrowser } from "chatgpt"; import { LogService, MatrixClient, UserID } from "matrix-bot-sdk"; -import { CHATGPT_TIMEOUT, MATRIX_DEFAULT_PREFIX_REPLY, MATRIX_DEFAULT_PREFIX, MATRIX_BLACKLIST, MATRIX_WHITELIST, MATRIX_RICH_TEXT, MATRIX_PREFIX_DM } from "./env.js"; +import { CHATGPT_CONTEXT, CHATGPT_TIMEOUT, MATRIX_DEFAULT_PREFIX_REPLY, MATRIX_DEFAULT_PREFIX, MATRIX_BLACKLIST, MATRIX_WHITELIST, MATRIX_RICH_TEXT, MATRIX_PREFIX_DM } from "./env.js"; import { RelatesTo, MessageEvent, StoredConversation, StoredConversationConfig } from "./interfaces.js"; import { sendChatGPTMessage, sendError, sendThreadReply } from "./utils.js"; @@ -44,8 +44,20 @@ export default class CommandHandler { return (!isReplyOrThread && relatesTo.event_id !== undefined) ? relatesTo.event_id : event.event_id; } - private async getStoredConversation(rootEventId: string): Promise { - const storedValue: string = await this.client.storageProvider.readValue('gpt-' + rootEventId) + private getStorageKey(event: MessageEvent, roomId: string): string { + const rootEventId: string = this.getRootEventId(event) + if (CHATGPT_CONTEXT == "room") { + return roomId + } else if (CHATGPT_CONTEXT == "thread") { + return rootEventId + } else { // CHATGPT_CONTEXT set to both. + return (rootEventId !== event.event_id) ? rootEventId : roomId; + } + } + + private async getStoredConversation(storageKey: string, roomId: string): Promise { + let storedValue: string = await this.client.storageProvider.readValue('gpt-' + storageKey) + if (storedValue == undefined && storageKey != roomId) storedValue = await this.client.storageProvider.readValue('gpt-' + roomId) return (storedValue !== undefined) ? JSON.parse(storedValue) : undefined; } @@ -91,8 +103,8 @@ export default class CommandHandler { try { if (this.shouldIgnore(event)) return; - const rootEventId = this.getRootEventId(event); - const storedConversation = await this.getStoredConversation(rootEventId); + const storageKey = this.getStorageKey(event, roomId); + const storedConversation = await this.getStoredConversation(storageKey, roomId); const config = this.getConfig(storedConversation); const shouldBePrefixed = await this.shouldBePrefixed(config, roomId, event) @@ -112,13 +124,13 @@ export default class CommandHandler { const result = await sendChatGPTMessage(this.chatGPT, await bodyWithoutPrefix, storedConversation); await Promise.all([ this.client.setTyping(roomId, false, 500), - sendThreadReply(this.client, roomId, rootEventId, `${result.response}`, MATRIX_RICH_TEXT) + sendThreadReply(this.client, roomId, this.getRootEventId(event), `${result.response}`, MATRIX_RICH_TEXT) ]); - await this.client.storageProvider.storeValue('gpt-' + rootEventId, JSON.stringify({ - conversationId: result.conversationId, messageId: result.messageId, - config: ((storedConversation !== undefined && storedConversation.config !== undefined) ? storedConversation.config : {}), - })); + const storedConfig = ((storedConversation !== undefined && storedConversation.config !== undefined) ? storedConversation.config : {}) + const configString: string = JSON.stringify({conversationId: result.conversationId, messageId: result.messageId, config: storedConfig}) + await this.client.storageProvider.storeValue('gpt-' + storageKey, configString); + if ((storageKey === roomId) && (CHATGPT_CONTEXT === "both")) await this.client.storageProvider.storeValue('gpt-' + event.event_id, configString); } catch (err) { console.error(err); }