Просмотр исходного кода

Extend thread expiration time when composing a message

Shun Miyazawa 1 год назад
Родитель
Сommit
d1d9e2faef

+ 2 - 5
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -53,7 +53,6 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
     aiAssistantData, threadData, closeAiAssistantChatSidebar,
     aiAssistantData, threadData, closeAiAssistantChatSidebar,
   } = props;
   } = props;
 
 
-  const [isThreadVerified, setIsThreadVerified] = useState<boolean>(false);
   const [currentThreadTitle, setCurrentThreadTitle] = useState<string | undefined>(threadData?.title);
   const [currentThreadTitle, setCurrentThreadTitle] = useState<string | undefined>(threadData?.title);
   const [currentThreadId, setCurrentThreadId] = useState<string | undefined>(threadData?.threadId);
   const [currentThreadId, setCurrentThreadId] = useState<string | undefined>(threadData?.threadId);
   const [messageLogs, setMessageLogs] = useState<Message[]>([]);
   const [messageLogs, setMessageLogs] = useState<Message[]>([]);
@@ -124,17 +123,15 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
 
 
     // create thread
     // create thread
     let currentThreadId_ = currentThreadId;
     let currentThreadId_ = currentThreadId;
-    if (!isThreadVerified || currentThreadId_ == null) {
+    if (currentThreadId_ == null) {
       try {
       try {
         const res = await apiv3Post<IThreadRelationHasId>('/openai/thread', {
         const res = await apiv3Post<IThreadRelationHasId>('/openai/thread', {
-          threadId: currentThreadId_,
           aiAssistantId: aiAssistantData._id,
           aiAssistantId: aiAssistantData._id,
           initialUserMessage: newUserMessage.content,
           initialUserMessage: newUserMessage.content,
         });
         });
 
 
         const thread = res.data;
         const thread = res.data;
 
 
-        setIsThreadVerified(true);
         setCurrentThreadId(thread.threadId);
         setCurrentThreadId(thread.threadId);
         setCurrentThreadTitle(thread.title);
         setCurrentThreadTitle(thread.title);
 
 
@@ -233,7 +230,7 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
       form.setError('input', { type: 'manual', message: err.toString() });
       form.setError('input', { type: 'manual', message: err.toString() });
     }
     }
 
 
-  }, [isGenerating, messageLogs, form, currentThreadId, isThreadVerified, aiAssistantData._id, mutateThreadData, t, growiCloudUri]);
+  }, [isGenerating, messageLogs, form, currentThreadId, aiAssistantData._id, mutateThreadData, t, growiCloudUri]);
 
 
   const keyDownHandler = (event: KeyboardEvent<HTMLTextAreaElement>) => {
   const keyDownHandler = (event: KeyboardEvent<HTMLTextAreaElement>) => {
     if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {
     if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {

+ 8 - 0
apps/app/src/features/openai/server/routes/message.ts

@@ -16,6 +16,7 @@ import loggerFactory from '~/utils/logger';
 import { shouldHideMessageKey } from '../../interfaces/message';
 import { shouldHideMessageKey } from '../../interfaces/message';
 import { MessageErrorCode, type StreamErrorCode } from '../../interfaces/message-error';
 import { MessageErrorCode, type StreamErrorCode } from '../../interfaces/message-error';
 import AiAssistantModel from '../models/ai-assistant';
 import AiAssistantModel from '../models/ai-assistant';
+import ThreadRelationModel from '../models/thread-relation';
 import { openaiClient } from '../services/client';
 import { openaiClient } from '../services/client';
 import { getStreamErrorCode } from '../services/getStreamErrorCode';
 import { getStreamErrorCode } from '../services/getStreamErrorCode';
 import { getOpenaiService } from '../services/openai';
 import { getOpenaiService } from '../services/openai';
@@ -76,6 +77,13 @@ export const postMessageHandlersFactory: PostMessageHandlersFactory = (crowi) =>
         return res.apiv3Err(new ErrorV3('AI assistant not found'), 404);
         return res.apiv3Err(new ErrorV3('AI assistant not found'), 404);
       }
       }
 
 
+      const thread = await ThreadRelationModel.findOne({ threadId });
+      if (thread == null) {
+        return res.apiv3Err(new ErrorV3('Thread not found'), 404);
+      }
+
+      thread.updateThreadExpiration();
+
       let stream: AssistantStream;
       let stream: AssistantStream;
 
 
       try {
       try {

+ 4 - 8
apps/app/src/features/openai/server/routes/thread.ts

@@ -3,7 +3,6 @@ import { ErrorV3 } from '@growi/core/dist/models';
 import type { Request, RequestHandler } from 'express';
 import type { Request, RequestHandler } from 'express';
 import type { ValidationChain } from 'express-validator';
 import type { ValidationChain } from 'express-validator';
 import { body } from 'express-validator';
 import { body } from 'express-validator';
-import { filterXSS } from 'xss';
 
 
 import type Crowi from '~/server/crowi';
 import type Crowi from '~/server/crowi';
 import { accessTokenParser } from '~/server/middlewares/access-token-parser';
 import { accessTokenParser } from '~/server/middlewares/access-token-parser';
@@ -19,8 +18,7 @@ const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 
 
 type ReqBody = {
 type ReqBody = {
   aiAssistantId: string,
   aiAssistantId: string,
-  threadId?: string,
-  initialUserMessage?: string,
+  initialUserMessage: string,
 }
 }
 
 
 type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
 type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
@@ -32,8 +30,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
 
 
   const validator: ValidationChain[] = [
   const validator: ValidationChain[] = [
     body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
     body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
-    body('threadId').optional().isString().withMessage('threadId must be string'),
-    body('initialUserMessage').optional().isString().withMessage('initialUserMessage must be string'),
+    body('initialUserMessage').isString().withMessage('initialUserMessage must be string'),
   ];
   ];
 
 
   return [
   return [
@@ -46,17 +43,16 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
       }
       }
 
 
       try {
       try {
-        const { aiAssistantId, threadId, initialUserMessage } = req.body;
+        const { aiAssistantId, initialUserMessage } = req.body;
 
 
         const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
         const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
         if (!isAiAssistantUsable) {
         if (!isAiAssistantUsable) {
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
         }
         }
 
 
-        const filteredThreadId = threadId != null ? filterXSS(threadId) : undefined;
         const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
         const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
+        const thread = await openaiService.createThread(req.user._id, vectorStoreRelation, initialUserMessage);
 
 
-        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId, initialUserMessage);
         return res.apiv3(thread);
         return res.apiv3(thread);
       }
       }
       catch (err) {
       catch (err) {

+ 15 - 39
apps/app/src/features/openai/server/services/openai.ts

@@ -63,8 +63,8 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 };
 
 
 export interface IOpenaiService {
 export interface IOpenaiService {
-  getOrCreateThread(
-    userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string
+  createThread(
+    userId: string, vectorStoreRelation: VectorStoreDocument, initialUserMessage: string
   ): Promise<ThreadRelationDocument>;
   ): Promise<ThreadRelationDocument>;
   getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
   getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
@@ -122,53 +122,29 @@ class OpenaiService implements IOpenaiService {
     return threadTitle;
     return threadTitle;
   }
   }
 
 
-  async getOrCreateThread(
-      userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string,
-  ): Promise<ThreadRelationDocument> {
-    if (threadId == null) {
-      let threadTitle: string | null = null;
-      if (initialUserMessage != null) {
-        try {
-          threadTitle = await this.generateThreadTitle(initialUserMessage);
-        }
-        catch (err) {
-          logger.error(err);
-        }
-      }
-
+  async createThread(userId: string, vectorStoreRelation: VectorStoreDocument, initialUserMessage: string): Promise<ThreadRelationDocument> {
+    let threadTitle: string | null = null;
+    if (initialUserMessage != null) {
       try {
       try {
-        const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
-        const threadRelation = await ThreadRelationModel.create({
-          userId,
-          threadId: thread.id,
-          vectorStore: vectorStoreRelation._id,
-          title: threadTitle,
-        });
-        return threadRelation;
+        threadTitle = await this.generateThreadTitle(initialUserMessage);
       }
       }
       catch (err) {
       catch (err) {
-        throw new Error(err);
+        logger.error(err);
       }
       }
     }
     }
 
 
-    const threadRelation = await ThreadRelationModel.findOne({ threadId });
-    if (threadRelation == null) {
-      throw new Error('ThreadRelation document is not exists');
-    }
-
-    // Check if a thread entity exists
-    // If the thread entity does not exist, the thread-relation document is deleted
     try {
     try {
-      await this.client.retrieveThread(threadRelation.threadId);
-
-      // Update expiration date if thread entity exists
-      await threadRelation.updateThreadExpiration();
-
+      const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
+      const threadRelation = await ThreadRelationModel.create({
+        userId,
+        threadId: thread.id,
+        vectorStore: vectorStoreRelation._id,
+        title: threadTitle,
+      });
       return threadRelation;
       return threadRelation;
     }
     }
     catch (err) {
     catch (err) {
-      await openaiApiErrorHandler(err, { notFoundError: async() => { await threadRelation.remove() } });
-      throw new Error(err);
+      throw err;
     }
     }
   }
   }