Просмотр исходного кода

Merge pull request #9710 from weseek/feat/162669-extend-thread-expiration-on-message-creation

feat: Extend thread expiration on message creation
Shun Miyazawa 1 год назад
Родитель
Сommit
2cca7c30c6

+ 6 - 1
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -125,11 +125,16 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
     let currentThreadId_ = currentThreadId;
     if (currentThreadId_ == null) {
       try {
-        const res = await apiv3Post<IThreadRelationHasId>('/openai/thread', { aiAssistantId: aiAssistantData._id, initialUserMessage: newUserMessage.content });
+        const res = await apiv3Post<IThreadRelationHasId>('/openai/thread', {
+          aiAssistantId: aiAssistantData._id,
+          initialUserMessage: newUserMessage.content,
+        });
+
         const thread = res.data;
 
         setCurrentThreadId(thread.threadId);
         setCurrentThreadTitle(thread.title);
+
         currentThreadId_ = thread.threadId;
 
         // No need to await because data is not used

+ 8 - 0
apps/app/src/features/openai/server/routes/message.ts

@@ -16,6 +16,7 @@ import loggerFactory from '~/utils/logger';
 import { shouldHideMessageKey } from '../../interfaces/message';
 import { MessageErrorCode, type StreamErrorCode } from '../../interfaces/message-error';
 import AiAssistantModel from '../models/ai-assistant';
+import ThreadRelationModel from '../models/thread-relation';
 import { openaiClient } from '../services/client';
 import { getStreamErrorCode } from '../services/getStreamErrorCode';
 import { getOpenaiService } from '../services/openai';
@@ -76,6 +77,13 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
         return res.apiv3Err(new ErrorV3('AI assistant not found'), 404);
       }
 
+      const threadRelation = await ThreadRelationModel.findOne({ threadId });
+      if (threadRelation == null) {
+        return res.apiv3Err(new ErrorV3('ThreadRelation not found'), 404);
+      }
+
+      threadRelation.updateThreadExpiration();
+
       let stream: AssistantStream;
 
       try {

+ 4 - 8
apps/app/src/features/openai/server/routes/thread.ts

@@ -3,7 +3,6 @@ import { ErrorV3 } from '@growi/core/dist/models';
 import type { Request, RequestHandler } from 'express';
 import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
-import { filterXSS } from 'xss';
 
 import type Crowi from '~/server/crowi';
 import { accessTokenParser } from '~/server/middlewares/access-token-parser';
@@ -19,8 +18,7 @@ const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 
 type ReqBody = {
   aiAssistantId: string,
-  threadId?: string,
-  initialUserMessage?: string,
+  initialUserMessage: string,
 }
 
 type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
@@ -32,8 +30,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
 
   const validator: ValidationChain[] = [
     body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
-    body('threadId').optional().isString().withMessage('threadId must be string'),
-    body('initialUserMessage').optional().isString().withMessage('initialUserMessage must be string'),
+    body('initialUserMessage').isString().withMessage('initialUserMessage must be string'),
   ];
 
   return [
@@ -46,17 +43,16 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
       }
 
       try {
-        const { aiAssistantId, threadId, initialUserMessage } = req.body;
+        const { aiAssistantId, initialUserMessage } = req.body;
 
         const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
         if (!isAiAssistantUsable) {
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
         }
 
-        const filteredThreadId = threadId != null ? filterXSS(threadId) : undefined;
         const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
+        const thread = await openaiService.createThread(req.user._id, vectorStoreRelation, initialUserMessage);
 
-        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId, initialUserMessage);
         return res.apiv3(thread);
       }
       catch (err) {

+ 16 - 42
apps/app/src/features/openai/server/services/openai.ts

@@ -65,11 +65,10 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 
 export interface IOpenaiService {
-  getOrCreateThread(
-    userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string
+  createThread(
+    userId: string, vectorStoreRelation: VectorStoreDocument, initialUserMessage: string
   ): Promise<ThreadRelationDocument>;
   getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
-  // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
@@ -83,8 +82,6 @@ export interface IOpenaiService {
   deleteVectorStoreFile(vectorStoreRelationId: Types.ObjectId, pageId: Types.ObjectId): Promise<void>;
   deleteVectorStoreFilesByPageIds(pageIds: Types.ObjectId[]): Promise<void>;
   deleteObsoleteVectorStoreFile(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
-  // rebuildVectorStoreAll(): Promise<void>;
-  // rebuildVectorStore(page: HydratedDocument<PageDocument>): Promise<void>;
   isAiAssistantUsable(aiAssistantId: string, user: IUserHasId): Promise<boolean>;
   createAiAssistant(data: Omit<AiAssistant, 'vectorStore'>): Promise<AiAssistantDocument>;
   updateAiAssistant(aiAssistantId: string, data: Omit<AiAssistant, 'vectorStore'>): Promise<AiAssistantDocument>;
@@ -125,53 +122,29 @@ class OpenaiService implements IOpenaiService {
     return threadTitle;
   }
 
-  async getOrCreateThread(
-      userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string,
-  ): Promise<ThreadRelationDocument> {
-    if (threadId == null) {
-      let threadTitle: string | null = null;
-      if (initialUserMessage != null) {
-        try {
-          threadTitle = await this.generateThreadTitle(initialUserMessage);
-        }
-        catch (err) {
-          logger.error(err);
-        }
-      }
-
+  async createThread(userId: string, vectorStoreRelation: VectorStoreDocument, initialUserMessage: string): Promise<ThreadRelationDocument> {
+    let threadTitle: string | null = null;
+    if (initialUserMessage != null) {
       try {
-        const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
-        const threadRelation = await ThreadRelationModel.create({
-          userId,
-          threadId: thread.id,
-          vectorStore: vectorStoreRelation._id,
-          title: threadTitle,
-        });
-        return threadRelation;
+        threadTitle = await this.generateThreadTitle(initialUserMessage);
       }
       catch (err) {
-        throw new Error(err);
+        logger.error(err);
       }
     }
 
-    const threadRelation = await ThreadRelationModel.findOne({ threadId });
-    if (threadRelation == null) {
-      throw new Error('ThreadRelation document is not exists');
-    }
-
-    // Check if a thread entity exists
-    // If the thread entity does not exist, the thread-relation document is deleted
     try {
-      const thread = await this.client.retrieveThread(threadRelation.threadId);
-
-      // Update expiration date if thread entity exists
-      await threadRelation.updateThreadExpiration();
-
+      const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
+      const threadRelation = await ThreadRelationModel.create({
+        userId,
+        threadId: thread.id,
+        vectorStore: vectorStoreRelation._id,
+        title: threadTitle,
+      });
       return threadRelation;
     }
     catch (err) {
-      await openaiApiErrorHandler(err, { notFoundError: async() => { await threadRelation.remove() } });
-      throw new Error(err);
+      throw err;
     }
   }
 
@@ -192,6 +165,7 @@ class OpenaiService implements IOpenaiService {
       await threadRelation.remove();
     }
     catch (err) {
+      await openaiApiErrorHandler(err, { notFoundError: async() => { await threadRelation.remove() } });
       throw err;
     }