ソースを参照

fix: openai thread idor

Ryotaro Nagahara 1 ヶ月 前
コミット
236549a47e

+ 4 - 2
apps/app/src/features/openai/server/routes/delete-thread.ts

@@ -68,8 +68,10 @@ export const deleteThreadFactory = (crowi: Crowi): RequestHandler[] => {
       }
 
       try {
-        const deletedThreadRelation =
-          await openaiService.deleteThread(threadRelationId);
+        const deletedThreadRelation = await openaiService.deleteThread(
+          threadRelationId,
+          user._id,
+        );
         return res.apiv3({ deletedThreadRelation });
       } catch (err) {
         logger.error(err);

+ 1 - 0
apps/app/src/features/openai/server/routes/edit/index.ts

@@ -246,6 +246,7 @@ export const postMessageToEditHandlersFactory = (
 
       const threadRelation = await ThreadRelationModel.findOne({
         threadId: { $eq: threadId },
+        userId: user._id,
       });
       if (threadRelation == null) {
         return res.apiv3Err(new ErrorV3('ThreadRelation not found'), 404);

+ 4 - 2
apps/app/src/features/openai/server/routes/get-threads.ts

@@ -70,8 +70,10 @@ export const getThreadsFactory = (crowi: Crowi): RequestHandler[] => {
           );
         }
 
-        const threads =
-          await openaiService.getThreadsByAiAssistantId(aiAssistantId);
+        const threads = await openaiService.getThreadsByAiAssistantId(
+          aiAssistantId,
+          user._id,
+        );
 
         return res.apiv3({ threads });
       } catch (err) {

+ 9 - 0
apps/app/src/features/openai/server/routes/message/get-messages.ts

@@ -12,6 +12,7 @@ import loginRequiredFactory from '~/server/middlewares/login-required';
 import type { ApiV3Response } from '~/server/routes/apiv3/interfaces/apiv3-response';
 import loggerFactory from '~/utils/logger';
 
+import ThreadRelationModel from '../../models/thread-relation';
 import { getOpenaiService } from '../../services/openai';
 import { certifyAiService } from '../middlewares/certify-ai-service';
 
@@ -81,6 +82,14 @@ export const getMessagesFactory = (crowi: Crowi): RequestHandler[] => {
           );
         }
 
+        const threadRelation = await ThreadRelationModel.findOne({
+          threadId,
+          userId: user._id,
+        });
+        if (threadRelation == null) {
+          return res.apiv3Err(new ErrorV3('Thread not found'), 404);
+        }
+
         const messages = await openaiService.getMessageData(
           threadId,
           user.lang,

+ 4 - 1
apps/app/src/features/openai/server/routes/message/post-message.ts

@@ -128,7 +128,10 @@ export const postMessageHandlersFactory = (crowi: Crowi): RequestHandler[] => {
         return res.apiv3Err(new ErrorV3('AI assistant not found'), 404);
       }
 
-      const threadRelation = await ThreadRelationModel.findOne({ threadId });
+      const threadRelation = await ThreadRelationModel.findOne({
+        threadId,
+        userId: user._id,
+      });
       if (threadRelation == null) {
         return res.apiv3Err(new ErrorV3('ThreadRelation not found'), 404);
       }

+ 16 - 3
apps/app/src/features/openai/server/services/openai.ts

@@ -98,8 +98,12 @@ export interface IOpenaiService {
   ): Promise<ThreadRelationDocument>;
   getThreadsByAiAssistantId(
     aiAssistantId: string,
+    userId: string,
   ): Promise<ThreadRelationDocument[]>;
-  deleteThread(threadRelationId: string): Promise<ThreadRelationDocument>;
+  deleteThread(
+    threadRelationId: string,
+    userId: string,
+  ): Promise<ThreadRelationDocument>;
   deleteExpiredThreads(limit: number, apiCallInterval: number): Promise<void>; // for CronJob
   deleteObsoletedVectorStoreRelations(): Promise<void>; // for CronJob
   deleteVectorStore(vectorStoreRelationId: string): Promise<void>;
@@ -274,7 +278,10 @@ class OpenaiService implements IOpenaiService {
     aiAssistantId: string,
     vectorStoreId: string,
   ): Promise<void> {
-    const threadRelations = await this.getThreadsByAiAssistantId(aiAssistantId);
+    const threadRelations = await ThreadRelationModel.find({
+      aiAssistant: aiAssistantId,
+      type: ThreadType.KNOWLEDGE,
+    }).sort({ updatedAt: -1 });
     for await (const threadRelation of threadRelations) {
       try {
         const updatedThreadResponse = await this.client.updateThread(
@@ -290,10 +297,12 @@ class OpenaiService implements IOpenaiService {
 
   async getThreadsByAiAssistantId(
     aiAssistantId: string,
+    userId: string,
     type: ThreadType = ThreadType.KNOWLEDGE,
   ): Promise<ThreadRelationDocument[]> {
     const threadRelations = await ThreadRelationModel.find({
       aiAssistant: aiAssistantId,
+      userId,
       type,
     }).sort({ updatedAt: -1 });
     return threadRelations;
@@ -301,8 +310,12 @@ class OpenaiService implements IOpenaiService {
 
   async deleteThread(
     threadRelationId: string,
+    userId: string,
   ): Promise<ThreadRelationDocument> {
-    const threadRelation = await ThreadRelationModel.findById(threadRelationId);
+    const threadRelation = await ThreadRelationModel.findOne({
+      _id: threadRelationId,
+      userId,
+    });
     if (threadRelation == null) {
       throw createError(404, 'ThreadRelation document does not exist');
     }