Jelajahi Sumber

Associate a VectorStore when creating a thread

Shun Miyazawa 1 tahun lalu
induk
melakukan
fd60a5e2cf

+ 1 - 1
apps/app/src/features/openai/client/components/AiAssistant/AiAssistantChatSidebar/AiAssistantChatSidebar.tsx

@@ -96,7 +96,7 @@ const AiAssistantChatSidebarSubstance: React.FC<AiAssistantChatSidebarSubstanceP
     let currentThreadId = threadId;
     let currentThreadId = threadId;
     if (threadId == null) {
     if (threadId == null) {
       try {
       try {
-        const res = await apiv3Post('/openai/thread');
+        const res = await apiv3Post('/openai/thread', { aiAssistantId: aiAssistantData?._id });
         const thread = res.data.thread;
         const thread = res.data.thread;
 
 
         setThreadId(thread.id);
         setThreadId(thread.id);

+ 16 - 5
apps/app/src/features/openai/server/routes/thread.ts

@@ -17,7 +17,12 @@ import { certifyAiService } from './middlewares/certify-ai-service';
 
 
 const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 const logger = loggerFactory('growi:routes:apiv3:openai:thread');
 
 
-type CreateThreadReq = Request<undefined, ApiV3Response, { threadId?: string }> & { user: IUserHasId };
+type ReqBody = {
+  aiAssistantId: string,
+  threadId?: string,
+}
+
+type CreateThreadReq = Request<undefined, ApiV3Response, ReqBody> & { user: IUserHasId };
 
 
 type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
 type CreateThreadFactory = (crowi: Crowi) => RequestHandler[];
 
 
@@ -25,6 +30,7 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
   const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
   const loginRequiredStrictly = require('~/server/middlewares/login-required')(crowi);
 
 
   const validator: ValidationChain[] = [
   const validator: ValidationChain[] = [
+    body('aiAssistantId').isMongoId().withMessage('aiAssistantId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
     body('threadId').optional().isString().withMessage('threadId must be string'),
   ];
   ];
 
 
@@ -38,10 +44,15 @@ export const createThreadHandlersFactory: CreateThreadFactory = (crowi) => {
       }
       }
 
 
       try {
       try {
-        const filterdThreadId = req.body.threadId != null ? filterXSS(req.body.threadId) : undefined;
-        // const vectorStore = await openaiService?.getOrCreateVectorStoreForPublicScope();
-        // const thread = await openaiService?.getOrCreateThread(req.user._id, vectorStore?.vectorStoreId, filterdThreadId);
-        return res.apiv3({ });
+        const { aiAssistantId, threadId } = req.body;
+
+        // リクエストした user が AiAssistant の owner or shareScope に含まれているかチェックする
+        const vectorStoreRelation = await openaiService.getVectorStoreRelation(aiAssistantId);
+
+        const filterdThreadId = threadId != null ? filterXSS(threadId) : undefined;
+
+        const thread = await openaiService.getOrCreateThread(req.user._id, vectorStoreRelation, filterdThreadId);
+        return res.apiv3({ thread });
       }
       }
       catch (err) {
       catch (err) {
         logger.error(err);
         logger.error(err);

+ 15 - 5
apps/app/src/features/openai/server/services/openai.ts

@@ -61,10 +61,11 @@ const convertPathPatternsToRegExp = (pagePathPatterns: string[]): Array<string |
 
 
 
 
 export interface IOpenaiService {
 export interface IOpenaiService {
-  getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
+  getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread | undefined>;
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   // getOrCreateVectorStoreForPublicScope(): Promise<VectorStoreDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
   deleteObsolatedVectorStoreRelations(): Promise<void> // for CronJob
+  getVectorStoreRelation(aiAssistantId: string): Promise<VectorStoreDocument>
   createVectorStoreFile(vectorStoreRelation: VectorStoreDocument, pages: PageDocument[]): Promise<void>;
   createVectorStoreFile(vectorStoreRelation: VectorStoreDocument, pages: PageDocument[]): Promise<void>;
   deleteVectorStoreFile(vectorStoreRelationId: Types.ObjectId, pageId: Types.ObjectId): Promise<void>;
   deleteVectorStoreFile(vectorStoreRelationId: Types.ObjectId, pageId: Types.ObjectId): Promise<void>;
   deleteObsoleteVectorStoreFile(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsoleteVectorStoreFile(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
@@ -82,11 +83,11 @@ class OpenaiService implements IOpenaiService {
     return getClient({ openaiServiceType });
     return getClient({ openaiServiceType });
   }
   }
 
 
-  public async getOrCreateThread(userId: string, vectorStoreId?: string, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
-    if (vectorStoreId != null && threadId == null) {
+  public async getOrCreateThread(userId: string, vectorStoreRelation: VectorStoreDocument, threadId?: string): Promise<OpenAI.Beta.Threads.Thread> {
+    if (threadId == null) {
       try {
       try {
-        const thread = await this.client.createThread(vectorStoreId);
-        await ThreadRelationModel.create({ userId, threadId: thread.id });
+        const thread = await this.client.createThread(vectorStoreRelation.vectorStoreId);
+        await ThreadRelationModel.create({ userId, threadId: thread.id, vectorStore: vectorStoreRelation._id });
         return thread;
         return thread;
       }
       }
       catch (err) {
       catch (err) {
@@ -172,6 +173,15 @@ class OpenaiService implements IOpenaiService {
   //   return newVectorStoreDocument;
   //   return newVectorStoreDocument;
   // }
   // }
 
 
+  async getVectorStoreRelation(aiAssistantId: string): Promise<VectorStoreDocument> {
+    const aiAssistant = await AiAssistantModel.findById({ _id: aiAssistantId }).populate('vectorStore');
+    if (aiAssistant == null) {
+      throw createError(404, 'AiAssistant document does not exist');
+    }
+
+    return aiAssistant.vectorStore as VectorStoreDocument;
+  }
+
   private async createVectorStore(name: string): Promise<VectorStoreDocument> {
   private async createVectorStore(name: string): Promise<VectorStoreDocument> {
     try {
     try {
       const newVectorStore = await this.client.createVectorStore(name);
       const newVectorStore = await this.client.createVectorStore(name);