Просмотр исходного кода

Add logic to auto-generate thread title

Shun Miyazawa 1 год назад
Родитель
Сommit
004b89a466

+ 1 - 1
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -122,7 +122,7 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
     let currentThreadId_ = currentThreadId;
     if (currentThreadId_ == null) {
       try {
-        const res = await apiv3Post('/openai/thread', { aiAssistantId: aiAssistantData._id });
+        const res = await apiv3Post('/openai/thread', { aiAssistantId: aiAssistantData._id, initialUserMessage: newUserMessage.content });
         const thread = res.data.thread;
 
         setCurrentThreadId(thread.id);

+ 4 - 2
apps/app/src/features/openai/server/routes/thread.ts

@@ -20,6 +20,7 @@ const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 type ReqBody = {
   aiAssistantId: string,
   threadId?: string,
+  initialUserMessage?: string,
 }
 
 type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
@@ -32,6 +33,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
   const validator: ValidationChain[] = [
     body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
+    body('initialUserMessage').optional().isString().withMessage('initialUserMessage must be string'),
   ];
 
   return [
@@ -44,7 +46,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
       }
 
       try {
-        const { aiAssistantId, threadId } = req.body;
+        const { aiAssistantId, threadId, initialUserMessage } = req.body;
 
         const isAiAssistantUsable = await openaiService.isAiAssistantUsable(aiAssistantId, req.user);
         if (!isAiAssistantUsable) {
@@ -54,7 +56,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
         const filteredThreadId = threadId != null ? filterXSS(threadId) : undefined;
         const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
 
-        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId);
+        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filteredThreadId, initialUserMessage);
         return res.apiv3({ thread });
       }
       catch (err) {

+ 47 - 3
apps/app/src/features/openai/server/services/openai.ts

@@ -63,7 +63,9 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 
 
 export interface IOpenaiService {
-  getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
+  getOrCreateThread(
+    userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string
+  ): Promise<OpenAI.Beta.Threads.Thread | undefined>;
   getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
@@ -93,11 +95,53 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
   }
 
-  public async getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
+  async generateThreadTitle(message: string): Promise<string | null> {
+    const systemMessage = [
+      'Create a brief title (max 5 words) from your message.',
+      'Response should only contain the title.',
+    ].join('');
+
+    const threadTitleCompletion = await this.client.chatCompletion({
+      model: 'gpt-4o-mini',
+      messages: [
+        {
+          role: 'system',
+          content: systemMessage,
+        },
+        {
+          role: 'user',
+          content: message,
+        },
+      ],
+    });
+
+    const threadTitle = threadTitleCompletion.choices[0].message.content;
+    return threadTitle;
+  }
+
+  async getOrCreateThread(
+      userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string, initialUserMessage?: string,
+  ): Promise<OpenAI.Beta.Threads.Thread> {
     if (threadId == null) {
+      let threadTitle: string | null = null;
+      if (initialUserMessage != null) {
+        try {
+          threadTitle = await this.generateThreadTitle(initialUserMessage);
+        }
+        catch (err) {
+          logger.error(err);
+        }
+      }
+
       try {
         const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
-        await ThreadRelationModel.create({ userId, threadId: thread.id, vectorStore: vectorStoreRelation._id });
+        await ThreadRelationModel.create({
+          userId,
+          threadId: thread.id,
+          vectorStore: vectorStoreRelation._id,
+          title: threadTitle,
+        });
+
         return thread;
       }
       catch (err) {