Просмотр исходного кода

Merge pull request #9716 from weseek/fix/162726-threads-disappear-on-vectorstore-rebuild

fix: Threads disappear on vectorstore rebuild
Yuki Takei 1 год назад
Родитель
Сommit
a20c149d2b

+ 2 - 2
apps/app/src/features/openai/interfaces/thread-relation.ts

@@ -1,10 +1,10 @@
 import type { IUser, Ref, HasObjectId } from '@growi/core';
 import type { IUser, Ref, HasObjectId } from '@growi/core';
 
 
-import type { IVectorStore } from './vector-store';
+import type { AiAssistant } from './ai-assistant';
 
 
 export interface IThreadRelation {
 export interface IThreadRelation {
   userId: Ref<IUser>
   userId: Ref<IUser>
-  vectorStore: Ref<IVectorStore>
+  aiAssistant: Ref<AiAssistant>
   threadId: string;
   threadId: string;
   title?: string;
   title?: string;
   expiredAt: Date;
   expiredAt: Date;

+ 2 - 2
apps/app/src/features/openai/server/models/thread-relation.ts

@@ -25,9 +25,9 @@ const schema = new Schema<ThreadRelationDocument, ThreadRelationModel>({
     ref: 'User',
     ref: 'User',
     required: true,
     required: true,
   },
   },
-  vectorStore: {
+  aiAssistant: {
     type: Schema.Types.ObjectId,
     type: Schema.Types.ObjectId,
-    ref: 'VectorStore',
+    ref: 'AiAssistant',
     required: true,
     required: true,
   },
   },
   threadId: {
   threadId: {

+ 1 - 2
apps/app/src/features/openai/server/routes/get-threads.ts

@@ -48,8 +48,7 @@ export const getThreadsFactory: GetThreadsFactory = (crowi) => {
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
         }
         }
 
 
-        const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
-        const threads = await openaiService.getThreads(vectorStoreRelation._id);
+        const threads = await openaiService.getThreadsByAiAssistantId(aiAssistantId);
 
 
         return res.apiv3({ threads });
         return res.apiv3({ threads });
       }
       }

+ 1 - 2
apps/app/src/features/openai/server/routes/thread.ts

@@ -50,8 +50,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
           return res.apiv3Err(new ErrorV3('The specified AI assistant is not usable'), 400);
         }
         }
 
 
-        const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
-        const thread = await openaiService.createThread(req.user._id, vectorStoreRelation, initialUserMessage);
+        const thread = await openaiService.createThread(req.user._id, aiAssistantId, initialUserMessage);
 
 
         return res.apiv3(thread);
         return res.apiv3(thread);
       }
       }

+ 10 - 0
apps/app/src/features/openai/server/services/client-delegator/azure-openai-client-delegator.ts

@@ -33,6 +33,16 @@ export class AzureOpenaiClientDelegator implements IOpenaiClientDelegator {
     });
     });
   }
   }
 
 
+  async updateThread(threadId: string, vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.update(threadId, {
+      tool_resources: {
+        file_search: {
+          vector_store_ids: [vectorStoreId],
+        },
+      },
+    });
+  }
+
   async retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread> {
   async retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread> {
     return this.client.beta.threads.retrieve(threadId);
     return this.client.beta.threads.retrieve(threadId);
   }
   }

+ 1 - 0
apps/app/src/features/openai/server/services/client-delegator/interfaces.ts

@@ -5,6 +5,7 @@ import type { MessageListParams } from '../../../interfaces/message';
 
 
 export interface IOpenaiClientDelegator {
 export interface IOpenaiClientDelegator {
   createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
   createThread(vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
+  updateThread(threadId: string, vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread>
   retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread>
   retrieveThread(threadId: string): Promise<OpenAI.Beta.Threads.Thread>
   deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted>
   deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted>
   getMessages(threadId: string, options?: MessageListParams): Promise<OpenAI.Beta.Threads.Messages.MessagesPage>
   getMessages(threadId: string, options?: MessageListParams): Promise<OpenAI.Beta.Threads.Messages.MessagesPage>

+ 10 - 0
apps/app/src/features/openai/server/services/client-delegator/openai-client-delegator.ts

@@ -38,6 +38,16 @@ export class OpenaiClientDelegator implements IOpenaiClientDelegator {
     return this.client.beta.threads.retrieve(threadId);
     return this.client.beta.threads.retrieve(threadId);
   }
   }
 
 
+  async updateThread(threadId: string, vectorStoreId: string): Promise<OpenAI.Beta.Threads.Thread> {
+    return this.client.beta.threads.update(threadId, {
+      tool_resources: {
+        file_search: {
+          vector_store_ids: [vectorStoreId],
+        },
+      },
+    });
+  }
+
   async deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted> {
   async deleteThread(threadId: string): Promise<OpenAI.Beta.Threads.ThreadDeleted> {
     return this.client.beta.threads.del(threadId);
     return this.client.beta.threads.del(threadId);
   }
   }

+ 3 - 3
apps/app/src/features/openai/server/services/normalize-data/normalize-thread-relation-expired-at/normalize-thread-relation-expired-at.integ.ts

@@ -15,7 +15,7 @@ describe('normalizeExpiredAtForThreadRelations', () => {
     const threadRelation = new ThreadRelation({
     const threadRelation = new ThreadRelation({
       userId: new Types.ObjectId(),
       userId: new Types.ObjectId(),
       threadId: 'test-thread',
       threadId: 'test-thread',
-      vectorStore: new Types.ObjectId(),
+      aiAssistant: new Types.ObjectId(),
       expiredAt: expiredDate,
       expiredAt: expiredDate,
     });
     });
     await threadRelation.save();
     await threadRelation.save();
@@ -37,7 +37,7 @@ describe('normalizeExpiredAtForThreadRelations', () => {
     const threadRelation = new ThreadRelation({
     const threadRelation = new ThreadRelation({
       userId: new Types.ObjectId(),
       userId: new Types.ObjectId(),
       threadId: 'test-thread-2',
       threadId: 'test-thread-2',
-      vectorStore: new Types.ObjectId(),
+      aiAssistant: new Types.ObjectId(),
       expiredAt: nonExpiredDate,
       expiredAt: nonExpiredDate,
     });
     });
     await threadRelation.save();
     await threadRelation.save();
@@ -57,7 +57,7 @@ describe('normalizeExpiredAtForThreadRelations', () => {
     const threadRelation = new ThreadRelation({
     const threadRelation = new ThreadRelation({
       userId: new Types.ObjectId(),
       userId: new Types.ObjectId(),
       threadId: 'test-thread-3',
       threadId: 'test-thread-3',
-      vectorStore: new Types.ObjectId(),
+      aiAssistant: new Types.ObjectId(),
       expiredAt: nonExpiredDate,
       expiredAt: nonExpiredDate,
     });
     });
     await threadRelation.save();
     await threadRelation.save();

+ 24 - 11
apps/app/src/features/openai/server/services/openai.ts

@@ -65,17 +65,13 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 };
 };
 
 
 export interface IOpenaiService {
 export interface IOpenaiService {
-  createThread(
-    userId: string, vectorStoreRelation: VectorStoreDocument, initialUserMessage: string
-  ): Promise<ThreadRelationDocument>;
-  getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]>
+  createThread(userId: string, aiAssistantId: string, initialUserMessage: string): Promise<ThreadRelationDocument>;
+  getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]>
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
   deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
   deleteVectorStore(vectorStoreRelationId: string): Promise<void>;
   deleteVectorStore(vectorStoreRelationId: string): Promise<void>;
   getMessageData(threadId: string, lang?: Lang, options?: MessageListParams): Promise<OpenAI.Beta.Threads.Messages.MessagesPage>;
   getMessageData(threadId: string, lang?: Lang, options?: MessageListParams): Promise<OpenAI.Beta.Threads.Messages.MessagesPage>;
-  getVectorStoreRelation(aiAssistantId: string): Promise<VectorStoreDocument>
-  getVectorStoreRelationsByPageIds(pageId: Types.ObjectId[]): Promise<VectorStoreDocument[]>;
   createVectorStoreFile(vectorStoreRelation: VectorStoreDocument, pages: PageDocument[]): Promise<void>;
   createVectorStoreFile(vectorStoreRelation: VectorStoreDocument, pages: PageDocument[]): Promise<void>;
   createVectorStoreFileOnPageCreate(pages: PageDocument[]): Promise<void>;
   createVectorStoreFileOnPageCreate(pages: PageDocument[]): Promise<void>;
   updateVectorStoreFileOnPageUpdate(page: HydratedDocument<PageDocument>): Promise<void>;
   updateVectorStoreFileOnPageUpdate(page: HydratedDocument<PageDocument>): Promise<void>;
@@ -122,7 +118,9 @@ class OpenaiService implements IOpenaiService {
     return threadTitle;
     return threadTitle;
   }
   }
 
 
-  async createThread(userId: string, vectorStoreRelation: VectorStoreDocument, initialUserMessage: string): Promise<ThreadRelationDocument> {
+  async createThread(userId: string, aiAssistantId: string, initialUserMessage: string): Promise<ThreadRelationDocument> {
+    const vectorStoreRelation = await this.getVectorStoreRelationByAiAssistantId(aiAssistantId);
+
     let threadTitle: string | null = null;
     let threadTitle: string | null = null;
     if (initialUserMessage != null) {
     if (initialUserMessage != null) {
       try {
       try {
@@ -137,8 +135,8 @@ class OpenaiService implements IOpenaiService {
       const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
       const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
       const threadRelation = await ThreadRelationModel.create({
       const threadRelation = await ThreadRelationModel.create({
         userId,
         userId,
+        aiAssistant: aiAssistantId,
         threadId: thread.id,
         threadId: thread.id,
-        vectorStore: vectorStoreRelation._id,
         title: threadTitle,
         title: threadTitle,
       });
       });
       return threadRelation;
       return threadRelation;
@@ -148,8 +146,21 @@ class OpenaiService implements IOpenaiService {
     }
     }
   }
   }
 
 
-  async getThreads(vectorStoreRelationId: string): Promise<ThreadRelationDocument[]> {
-    const threadRelations = await ThreadRelationModel.find({ vectorStore: vectorStoreRelationId });
+  async updateThreads(aiAssistantId: string, vectorStoreId: string): Promise<void> {
+    const threadRelations = await this.getThreadsByAiAssistantId(aiAssistantId);
+    for await (const threadRelation of threadRelations) {
+      try {
+        const updatedThreadResponse = await this.client.updateThread(threadRelation.threadId, vectorStoreId);
+        logger.debug('Update thread', updatedThreadResponse);
+      }
+      catch (err) {
+        logger.error(err);
+      }
+    }
+  }
+
+  async getThreadsByAiAssistantId(aiAssistantId: string): Promise<ThreadRelationDocument[]> {
+    const threadRelations = await ThreadRelationModel.find({ aiAssistant: aiAssistantId });
     return threadRelations;
     return threadRelations;
   }
   }
 
 
@@ -211,7 +222,7 @@ class OpenaiService implements IOpenaiService {
   }
   }
 
 
 
 
-  async getVectorStoreRelation(aiAssistantId: string): Promise<VectorStoreDocument> {
+  async getVectorStoreRelationByAiAssistantId(aiAssistantId: string): Promise<VectorStoreDocument> {
     const aiAssistant = await AiAssistantModel.findById({ _id: aiAssistantId }).populate('vectorStore');
     const aiAssistant = await AiAssistantModel.findById({ _id: aiAssistantId }).populate('vectorStore');
     if (aiAssistant == null) {
     if (aiAssistant == null) {
       throw createError(404, 'AiAssistant document does not exist');
       throw createError(404, 'AiAssistant document does not exist');
@@ -812,6 +823,8 @@ class OpenaiService implements IOpenaiService {
 
 
       newVectorStoreRelation = await this.createVectorStore(data.name);
       newVectorStoreRelation = await this.createVectorStore(data.name);
 
 
+      this.updateThreads(aiAssistantId, newVectorStoreRelation.vectorStoreId);
+
       // VectorStore creation process does not await
       // VectorStore creation process does not await
       this.createVectorStoreFileWithStream(newVectorStoreRelation, conditions);
       this.createVectorStoreFileWithStream(newVectorStoreRelation, conditions);
     }
     }